Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ public record AggregationData(
BigDecimal totalEstimatedCost,
BigDecimal totalEstimatedCostAvg,
PercentageValues duration,
List<FeedbackScoreAverage> feedbackScores) {
List<FeedbackScoreAverage> feedbackScores,
List<FeedbackScoreAverage> experimentScores) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public record Experiment(
@JsonView({
Experiment.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({Experiment.View.Public.class, Experiment.View.Write.class}) ExperimentStatus status,
@JsonView({Experiment.View.Public.class, Experiment.View.Write.class}) List<ExperimentScore> experimentScores,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: Regarding validations, the list is fine (should be optional, like you have it), but you should probably add the following inner validations to the ExperimentScore object: List<@NotNull @Valid ExperimentScore> experimentScores.

@JsonView({Experiment.View.Public.class,
Experiment.View.Write.class}) @Schema(deprecated = true) PromptVersionLink promptVersion,
@JsonView({Experiment.View.Public.class, Experiment.View.Write.class}) List<PromptVersionLink> promptVersions){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ public record ExperimentGroupAggregationItem(
BigDecimal totalEstimatedCost,
BigDecimal totalEstimatedCostAvg,
PercentageValues duration, // p50, p90, p99 from DB
List<FeedbackScoreAverage> feedbackScores) { // name -> average value from DB
List<FeedbackScoreAverage> feedbackScores, // name -> average value from DB
List<FeedbackScoreAverage> experimentScores) { // name -> average value from DB
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import lombok.Builder;

import java.math.BigDecimal;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record ExperimentScore(
@NotBlank String name,
@NotNull BigDecimal value) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Pattern;
import lombok.Builder;

import java.util.List;

import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK;

@Builder(toBuilder = true)
Expand All @@ -17,5 +20,6 @@ public record ExperimentUpdate(
@Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String name,
JsonNode metadata,
ExperimentType type,
@Schema(description = "The status of the experiment") ExperimentStatus status) {
@Schema(description = "The status of the experiment") ExperimentStatus status,
@Valid List<ExperimentScore> experimentScores) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: related to the other comment, the outer valid won't do much, instead you should move it to the inner ExperimentScore object: List<@NotNull @Valid ExperimentScore> experimentScores.

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.annotation.Nullable;
import lombok.Builder;

import java.util.List;
Expand All @@ -12,6 +13,9 @@
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record FeedbackScoreNames(List<ScoreName> scores) {

public record ScoreName(String name) {
public record ScoreName(String name, @Nullable String type) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a JSON payload, so for protecting the service it's important that you add:

@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)

As convention, we normally assume that something not validated is Nullable, so you can remove the annotation.

Finally, I recommend also adding @Builder(toBuilder = true), which will make the class more usable in the future.

public ScoreName(String name) {
this(name, null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static com.comet.opik.api.sorting.SortableFields.CREATED_AT;
import static com.comet.opik.api.sorting.SortableFields.CREATED_BY;
import static com.comet.opik.api.sorting.SortableFields.DURATION_AGG;
import static com.comet.opik.api.sorting.SortableFields.EXPERIMENT_METRICS;
import static com.comet.opik.api.sorting.SortableFields.FEEDBACK_SCORES;
import static com.comet.opik.api.sorting.SortableFields.ID;
import static com.comet.opik.api.sorting.SortableFields.LAST_UPDATED_AT;
Expand All @@ -29,6 +30,7 @@ public List<String> getSortableFields() {
TOTAL_ESTIMATED_COST,
TOTAL_ESTIMATED_COST_AVG,
FEEDBACK_SCORES,
EXPERIMENT_METRICS,
DURATION_AGG);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class SortableFields {
public static final String LLM_SPAN_COUNT = "llm_span_count";
public static final String TRACE_COUNT = "trace_count";
public static final String FEEDBACK_SCORES = "feedback_scores.*";
public static final String EXPERIMENT_METRICS = "experiment_scores.*";
public static final String NUMBER_OF_MESSAGES = "number_of_messages";
public static final String STATUS = "status";
public static final String VERSION_COUNT = "version_count";
Expand All @@ -51,4 +52,4 @@ public class SortableFields {
public static final String COMMENTS = "comments";
public static final String ENABLED = "enabled";
public static final String SAMPLING_RATE = "sampling_rate";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.comet.opik.api.ExperimentGroupAggregationItem;
import com.comet.opik.api.ExperimentGroupCriteria;
import com.comet.opik.api.ExperimentGroupItem;
import com.comet.opik.api.ExperimentScore;
import com.comet.opik.api.ExperimentSearchCriteria;
import com.comet.opik.api.ExperimentStatus;
import com.comet.opik.api.ExperimentStreamRequest;
Expand All @@ -17,6 +18,8 @@
import com.comet.opik.domain.filter.FilterQueryBuilder;
import com.comet.opik.domain.filter.FilterStrategy;
import com.comet.opik.domain.sorting.SortingQueryBuilder;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.core.type.TypeReference;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import io.opentelemetry.instrumentation.annotations.WithSpan;
Expand Down Expand Up @@ -87,7 +90,8 @@ INSERT INTO experiments (
prompt_versions,
type,
optimization_id,
status
status,
experiment_scores
)
SELECT
if(
Expand All @@ -106,7 +110,8 @@ INSERT INTO experiments (
new.prompt_versions,
new.type,
new.optimization_id,
new.status
new.status,
new.experiment_scores
FROM (
SELECT
:id AS id,
Expand All @@ -121,7 +126,8 @@ INSERT INTO experiments (
mapFromArrays(:prompt_ids, :prompt_version_ids) AS prompt_versions,
:type AS type,
:optimization_id AS optimization_id,
:status AS status
:status AS status,
:experiment_scores AS experiment_scores
) AS new
LEFT JOIN (
SELECT
Expand Down Expand Up @@ -357,6 +363,7 @@ WHERE ei.trace_id IN (SELECT id FROM trace_final)
e.optimization_id as optimization_id,
e.type as type,
e.status as status,
e.experiment_scores as experiment_scores,
fs.feedback_scores as feedback_scores,
ed.trace_count as trace_count,
ed.duration_values AS duration,
Expand Down Expand Up @@ -564,26 +571,50 @@ HAVING length(fs.name) > 0
) as fs_avg
GROUP BY experiment_id
),
experiment_scores_agg AS (
SELECT
experiment_id,
mapFromArrays(
groupArray(name),
groupArray(value)
) AS experiment_scores
FROM (
SELECT
e.id AS experiment_id,
JSONExtractString(score, 'name') AS name,
JSONExtractFloat(score, 'value') AS value
Comment on lines +584 to +585
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this, we generally use a type neutral function: JSON_VALUE. See:

I recommend switching to that.

Please test to make sure I'm not wrong in my feedback here.

FROM experiments_final AS e
ARRAY JOIN JSONExtractArrayRaw(e.experiment_scores) AS score
WHERE e.experiment_scores IS NOT NULL
AND e.experiment_scores != ''
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: is empty condition seems to be redundant with the length one. I'd keep the same.

AND length(e.experiment_scores) > 2
AND length(JSONExtractString(score, 'name')) > 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same about JSON_VALUE

) AS es
GROUP BY experiment_id
),
experiments_full AS (
SELECT
e.id as id,
e.dataset_id AS dataset_id,
e.metadata AS metadata,
fs.feedback_scores as feedback_scores,
es.experiment_scores as experiment_scores,
ed.trace_count as trace_count,
ed.duration_values AS duration,
ed.total_estimated_cost_sum as total_estimated_cost,
ed.total_estimated_cost_avg as total_estimated_cost_avg
FROM experiments_final AS e
LEFT JOIN experiment_durations AS ed ON e.id = ed.experiment_id
LEFT JOIN feedback_scores_agg AS fs ON e.id = fs.experiment_id
LEFT JOIN experiment_scores_agg AS es ON e.id = es.experiment_id
)
SELECT
count(DISTINCT id) as experiment_count,
sum(trace_count) as trace_count,
sum(total_estimated_cost) as total_estimated_cost,
avg(total_estimated_cost_avg) as total_estimated_cost_avg,
avgMap(feedback_scores) as feedback_scores,
avgMap(experiment_scores) as experiment_scores,
avgMap(duration) as duration,
<groupSelects>
FROM experiments_full
Expand Down Expand Up @@ -684,6 +715,7 @@ INSERT INTO experiments (
type,
optimization_id,
status,
experiment_scores,
created_at,
last_updated_at
)
Expand All @@ -701,6 +733,7 @@ INSERT INTO experiments (
<if(type)> :type <else> type <endif> as type,
optimization_id,
<if(status)> :status <else> status <endif> as status,
<if(experiment_scores)> :experiment_scores <else> experiment_scores <endif> as experiment_scores,
created_at,
now64(9) as last_updated_at
FROM experiments
Expand Down Expand Up @@ -729,7 +762,11 @@ private Publisher<? extends Result> insert(Experiment experiment, Connection con
.bind("metadata", getStringOrDefault(experiment.metadata()))
.bind("type", Optional.ofNullable(experiment.type()).orElse(ExperimentType.REGULAR).getValue())
.bind("optimization_id", experiment.optimizationId() != null ? experiment.optimizationId() : "")
.bind("status", Optional.ofNullable(experiment.status()).orElse(ExperimentStatus.COMPLETED).getValue());
.bind("status", Optional.ofNullable(experiment.status()).orElse(ExperimentStatus.COMPLETED).getValue())
.bind("experiment_scores", Optional.ofNullable(experiment.experimentScores())
.filter(scores -> !scores.isEmpty())
.map(JsonUtils::writeValueAsString)
.orElse(""));

if (experiment.promptVersion() != null) {
statement.bind("prompt_version_id", experiment.promptVersion().id());
Expand Down Expand Up @@ -849,6 +886,7 @@ private Publisher<Experiment> mapToDto(Result result) {
.orElse(null))
.type(ExperimentType.fromString(row.get("type", String.class)))
.status(ExperimentStatus.fromString(row.get("status", String.class)))
.experimentScores(getExperimentScores(row))
.build();
});
}
Expand Down Expand Up @@ -923,6 +961,39 @@ public static List<FeedbackScoreAverage> getFeedbackScores(Row row) {
return feedbackScoresAvg.isEmpty() ? null : feedbackScoresAvg;
}

public static List<FeedbackScoreAverage> getExperimentScoresAggregation(Row row) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a copy and past of getFeedbackScores. Just parameterise the function for the row name and reuse it.

List<FeedbackScoreAverage> experimentScoresAvg = Optional
.ofNullable(row.get("experiment_scores", Map.class))
.map(map -> (Map<String, ? extends Number>) map)
.orElse(Map.of())
.entrySet()
.stream()
.map(scores -> {
return new FeedbackScoreAverage(scores.getKey(),
BigDecimal.valueOf(scores.getValue().doubleValue()).setScale(SCALE,
RoundingMode.HALF_EVEN));
})
.toList();

return experimentScoresAvg.isEmpty() ? null : experimentScoresAvg;
}

public static List<ExperimentScore> getExperimentScores(Row row) {
String experimentScoresJson = row.get("experiment_scores", String.class);
if (experimentScoresJson == null || experimentScoresJson.isBlank()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: use StringUtils is blank for a null safe check.

return null;
}
try {
List<ExperimentScore> scores = JsonUtils.readValue(experimentScoresJson,
new TypeReference<List<ExperimentScore>>() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please create a static final constant for this type reference within ExperimentScore model. It has some performance cost to instantiate every time. It's thread safe.

});
return scores == null || scores.isEmpty() ? null : scores;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor and same: you can use StringUtils.isBlank.

} catch (Exception e) {
log.warn("Failed to deserialize experiment_scores from JSON: {}", experimentScoresJson, e);
return null;
}
}

@WithSpan
Mono<ExperimentPage> find(
int page, int size, @NonNull ExperimentSearchCriteria experimentSearchCriteria) {
Expand Down Expand Up @@ -1265,6 +1336,7 @@ private Publisher<ExperimentGroupAggregationItem> mapExperimentGroupAggregationI
.totalEstimatedCostAvg(getCostValue(row, "total_estimated_cost_avg"))
.duration(getDuration(row))
.feedbackScores(getFeedbackScores(row))
.experimentScores(getExperimentScoresAggregation(row))
.build();
});
}
Expand Down Expand Up @@ -1309,6 +1381,10 @@ private ST buildUpdateTemplate(ExperimentUpdate experimentUpdate, String update)
template.add("status", experimentUpdate.status().getValue());
}

if (experimentUpdate.experimentScores() != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: for all these conditionals here based on experimentScores, better use CollectionUtils.isNotEmpty.

template.add("experiment_scores", true);
}

return template;
}

Expand All @@ -1328,6 +1404,13 @@ private void bindUpdateParams(ExperimentUpdate experimentUpdate, Statement state
if (experimentUpdate.status() != null) {
statement.bind("status", experimentUpdate.status().getValue());
}

if (experimentUpdate.experimentScores() != null) {
statement.bind("experiment_scores", Optional.of(experimentUpdate.experimentScores())
.filter(scores -> !scores.isEmpty())
.map(JsonUtils::writeValueAsString)
.orElse(""));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ private void buildNestedGroupsWithAggregations(Map<String, GroupContentWithAggre
.totalEstimatedCostAvg(BigDecimal.ZERO)
.duration(null)
.feedbackScores(List.of())
.experimentScores(List.of())
.build())
.groups(new HashMap<>())
.build();
Expand Down Expand Up @@ -140,6 +141,10 @@ private AggregationData calculateAggregatedChildrenValues(
Map<String, BigDecimal> feedbackScoreSums = new HashMap<>();
Map<String, Long> feedbackScoreCounts = new HashMap<>();

// For experiment scores - group by name and calculate weighted averages
Map<String, BigDecimal> experimentScoreSums = new HashMap<>();
Map<String, Long> experimentScoreCounts = new HashMap<>();

for (GroupContentWithAggregations child : childGroups.values()) {
AggregationData childAgg = child.aggregations();
long expCount = childAgg.experimentCount();
Expand Down Expand Up @@ -181,6 +186,19 @@ private AggregationData calculateAggregatedChildrenValues(
}
}
}

// For experiment scores (weighted average per name)
if (childAgg.experimentScores() != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, this is a copy and paste of an existing block. Please encapsulate in a function and reuse.

for (FeedbackScoreAverage score : childAgg.experimentScores()) {
String name = score.name();
BigDecimal value = score.value();

if (value != null && name != null) {
experimentScoreSums.merge(name, value.multiply(BigDecimal.valueOf(expCount)), BigDecimal::add);
experimentScoreCounts.merge(name, expCount, Long::sum);
}
}
}
}

// Calculate averages
Expand All @@ -196,6 +214,8 @@ private AggregationData calculateAggregatedChildrenValues(
: null;

List<FeedbackScoreAverage> avgFeedbackScores = buildAvgFeedbackScores(feedbackScoreSums, feedbackScoreCounts);
List<FeedbackScoreAverage> avgExperimentScores = buildAvgFeedbackScores(experimentScoreSums,
experimentScoreCounts);

// Build updated aggregation data
return AggregationData.builder()
Expand All @@ -205,6 +225,7 @@ private AggregationData calculateAggregatedChildrenValues(
.totalEstimatedCostAvg(avgCost)
.duration(avgDuration)
.feedbackScores(avgFeedbackScores)
.experimentScores(avgExperimentScores)
.build();
}

Expand All @@ -231,6 +252,7 @@ private AggregationData buildAggregationData(ExperimentGroupAggregationItem item
.totalEstimatedCostAvg(item.totalEstimatedCostAvg())
.duration(item.duration())
.feedbackScores(item.feedbackScores())
.experimentScores(item.experimentScores())
.build();
}

Expand Down
Loading