-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[issue-3764] [FE] [BE] [Docs] Introduce experiment scores #3989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1d00408
81fe6c8
21c4aef
2f1231e
4f1b508
54510fa
7fede23
e8adcdf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| package com.comet.opik.api; | ||
|
|
||
| import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
| import com.fasterxml.jackson.databind.PropertyNamingStrategies; | ||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | ||
| import jakarta.validation.constraints.NotBlank; | ||
| import jakarta.validation.constraints.NotNull; | ||
| import lombok.Builder; | ||
|
|
||
| import java.math.BigDecimal; | ||
|
|
||
| @Builder(toBuilder = true) | ||
| @JsonIgnoreProperties(ignoreUnknown = true) | ||
| @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) | ||
| public record ExperimentScore( | ||
| @NotBlank String name, | ||
| @NotNull BigDecimal value) { | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,9 +5,12 @@ | |
| import com.fasterxml.jackson.databind.PropertyNamingStrategies; | ||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | ||
| import io.swagger.v3.oas.annotations.media.Schema; | ||
| import jakarta.validation.Valid; | ||
| import jakarta.validation.constraints.Pattern; | ||
| import lombok.Builder; | ||
|
|
||
| import java.util.List; | ||
|
|
||
| import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK; | ||
|
|
||
| @Builder(toBuilder = true) | ||
|
|
@@ -17,5 +20,6 @@ public record ExperimentUpdate( | |
| @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String name, | ||
| JsonNode metadata, | ||
| ExperimentType type, | ||
| @Schema(description = "The status of the experiment") ExperimentStatus status) { | ||
| @Schema(description = "The status of the experiment") ExperimentStatus status, | ||
| @Valid List<ExperimentScore> experimentScores) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor: related to the other comment, the outer valid won't do much, instead you should move it to the inner |
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
| import com.fasterxml.jackson.databind.PropertyNamingStrategies; | ||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | ||
| import jakarta.annotation.Nullable; | ||
| import lombok.Builder; | ||
|
|
||
| import java.util.List; | ||
|
|
@@ -12,6 +13,9 @@ | |
| @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) | ||
| public record FeedbackScoreNames(List<ScoreName> scores) { | ||
|
|
||
| public record ScoreName(String name) { | ||
| public record ScoreName(String name, @Nullable String type) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a JSON payload, so for protecting the service it's important that you add: As convention, we normally assume that something not validated is Nullable, so you can remove the annotation. Finally, I recommend also adding |
||
| public ScoreName(String name) { | ||
| this(name, null); | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| import com.comet.opik.api.ExperimentGroupAggregationItem; | ||
| import com.comet.opik.api.ExperimentGroupCriteria; | ||
| import com.comet.opik.api.ExperimentGroupItem; | ||
| import com.comet.opik.api.ExperimentScore; | ||
| import com.comet.opik.api.ExperimentSearchCriteria; | ||
| import com.comet.opik.api.ExperimentStatus; | ||
| import com.comet.opik.api.ExperimentStreamRequest; | ||
|
|
@@ -17,6 +18,8 @@ | |
| import com.comet.opik.domain.filter.FilterQueryBuilder; | ||
| import com.comet.opik.domain.filter.FilterStrategy; | ||
| import com.comet.opik.domain.sorting.SortingQueryBuilder; | ||
| import com.comet.opik.utils.JsonUtils; | ||
| import com.fasterxml.jackson.core.type.TypeReference; | ||
| import com.google.common.base.Function; | ||
| import com.google.common.base.Preconditions; | ||
| import io.opentelemetry.instrumentation.annotations.WithSpan; | ||
|
|
@@ -87,7 +90,8 @@ INSERT INTO experiments ( | |
| prompt_versions, | ||
| type, | ||
| optimization_id, | ||
| status | ||
| status, | ||
| experiment_scores | ||
| ) | ||
| SELECT | ||
| if( | ||
|
|
@@ -106,7 +110,8 @@ INSERT INTO experiments ( | |
| new.prompt_versions, | ||
| new.type, | ||
| new.optimization_id, | ||
| new.status | ||
| new.status, | ||
| new.experiment_scores | ||
| FROM ( | ||
| SELECT | ||
| :id AS id, | ||
|
|
@@ -121,7 +126,8 @@ INSERT INTO experiments ( | |
| mapFromArrays(:prompt_ids, :prompt_version_ids) AS prompt_versions, | ||
| :type AS type, | ||
| :optimization_id AS optimization_id, | ||
| :status AS status | ||
| :status AS status, | ||
| :experiment_scores AS experiment_scores | ||
| ) AS new | ||
| LEFT JOIN ( | ||
| SELECT | ||
|
|
@@ -357,6 +363,7 @@ WHERE ei.trace_id IN (SELECT id FROM trace_final) | |
| e.optimization_id as optimization_id, | ||
| e.type as type, | ||
| e.status as status, | ||
| e.experiment_scores as experiment_scores, | ||
| fs.feedback_scores as feedback_scores, | ||
| ed.trace_count as trace_count, | ||
| ed.duration_values AS duration, | ||
|
|
@@ -564,26 +571,50 @@ HAVING length(fs.name) > 0 | |
| ) as fs_avg | ||
| GROUP BY experiment_id | ||
| ), | ||
| experiment_scores_agg AS ( | ||
| SELECT | ||
| experiment_id, | ||
| mapFromArrays( | ||
| groupArray(name), | ||
| groupArray(value) | ||
| ) AS experiment_scores | ||
| FROM ( | ||
| SELECT | ||
| e.id AS experiment_id, | ||
| JSONExtractString(score, 'name') AS name, | ||
| JSONExtractFloat(score, 'value') AS value | ||
|
Comment on lines
+584
to
+585
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For this, we generally use a type neutral function: I recommend switching to that. Please test to make sure I'm not wrong in my feedback here. |
||
| FROM experiments_final AS e | ||
| ARRAY JOIN JSONExtractArrayRaw(e.experiment_scores) AS score | ||
| WHERE e.experiment_scores IS NOT NULL | ||
| AND e.experiment_scores != '' | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor: is empty condition seems to be redundant with the length one. I'd keep the same. |
||
| AND length(e.experiment_scores) > 2 | ||
| AND length(JSONExtractString(score, 'name')) > 0 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same about JSON_VALUE |
||
| ) AS es | ||
| GROUP BY experiment_id | ||
| ), | ||
| experiments_full AS ( | ||
| SELECT | ||
| e.id as id, | ||
| e.dataset_id AS dataset_id, | ||
| e.metadata AS metadata, | ||
| fs.feedback_scores as feedback_scores, | ||
| es.experiment_scores as experiment_scores, | ||
| ed.trace_count as trace_count, | ||
| ed.duration_values AS duration, | ||
| ed.total_estimated_cost_sum as total_estimated_cost, | ||
| ed.total_estimated_cost_avg as total_estimated_cost_avg | ||
| FROM experiments_final AS e | ||
| LEFT JOIN experiment_durations AS ed ON e.id = ed.experiment_id | ||
| LEFT JOIN feedback_scores_agg AS fs ON e.id = fs.experiment_id | ||
| LEFT JOIN experiment_scores_agg AS es ON e.id = es.experiment_id | ||
| ) | ||
| SELECT | ||
| count(DISTINCT id) as experiment_count, | ||
| sum(trace_count) as trace_count, | ||
| sum(total_estimated_cost) as total_estimated_cost, | ||
| avg(total_estimated_cost_avg) as total_estimated_cost_avg, | ||
| avgMap(feedback_scores) as feedback_scores, | ||
| avgMap(experiment_scores) as experiment_scores, | ||
| avgMap(duration) as duration, | ||
| <groupSelects> | ||
| FROM experiments_full | ||
|
|
@@ -684,6 +715,7 @@ INSERT INTO experiments ( | |
| type, | ||
| optimization_id, | ||
| status, | ||
| experiment_scores, | ||
| created_at, | ||
| last_updated_at | ||
| ) | ||
|
|
@@ -701,6 +733,7 @@ INSERT INTO experiments ( | |
| <if(type)> :type <else> type <endif> as type, | ||
| optimization_id, | ||
| <if(status)> :status <else> status <endif> as status, | ||
| <if(experiment_scores)> :experiment_scores <else> experiment_scores <endif> as experiment_scores, | ||
| created_at, | ||
| now64(9) as last_updated_at | ||
| FROM experiments | ||
|
|
@@ -729,7 +762,11 @@ private Publisher<? extends Result> insert(Experiment experiment, Connection con | |
| .bind("metadata", getStringOrDefault(experiment.metadata())) | ||
| .bind("type", Optional.ofNullable(experiment.type()).orElse(ExperimentType.REGULAR).getValue()) | ||
| .bind("optimization_id", experiment.optimizationId() != null ? experiment.optimizationId() : "") | ||
| .bind("status", Optional.ofNullable(experiment.status()).orElse(ExperimentStatus.COMPLETED).getValue()); | ||
| .bind("status", Optional.ofNullable(experiment.status()).orElse(ExperimentStatus.COMPLETED).getValue()) | ||
| .bind("experiment_scores", Optional.ofNullable(experiment.experimentScores()) | ||
| .filter(scores -> !scores.isEmpty()) | ||
| .map(JsonUtils::writeValueAsString) | ||
| .orElse("")); | ||
|
|
||
| if (experiment.promptVersion() != null) { | ||
| statement.bind("prompt_version_id", experiment.promptVersion().id()); | ||
|
|
@@ -849,6 +886,7 @@ private Publisher<Experiment> mapToDto(Result result) { | |
| .orElse(null)) | ||
| .type(ExperimentType.fromString(row.get("type", String.class))) | ||
| .status(ExperimentStatus.fromString(row.get("status", String.class))) | ||
| .experimentScores(getExperimentScores(row)) | ||
| .build(); | ||
| }); | ||
| } | ||
|
|
@@ -923,6 +961,39 @@ public static List<FeedbackScoreAverage> getFeedbackScores(Row row) { | |
| return feedbackScoresAvg.isEmpty() ? null : feedbackScoresAvg; | ||
| } | ||
|
|
||
| public static List<FeedbackScoreAverage> getExperimentScoresAggregation(Row row) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a copy and past of |
||
| List<FeedbackScoreAverage> experimentScoresAvg = Optional | ||
| .ofNullable(row.get("experiment_scores", Map.class)) | ||
| .map(map -> (Map<String, ? extends Number>) map) | ||
| .orElse(Map.of()) | ||
| .entrySet() | ||
| .stream() | ||
| .map(scores -> { | ||
| return new FeedbackScoreAverage(scores.getKey(), | ||
| BigDecimal.valueOf(scores.getValue().doubleValue()).setScale(SCALE, | ||
| RoundingMode.HALF_EVEN)); | ||
| }) | ||
| .toList(); | ||
|
|
||
| return experimentScoresAvg.isEmpty() ? null : experimentScoresAvg; | ||
| } | ||
|
|
||
| public static List<ExperimentScore> getExperimentScores(Row row) { | ||
| String experimentScoresJson = row.get("experiment_scores", String.class); | ||
| if (experimentScoresJson == null || experimentScoresJson.isBlank()) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor: use StringUtils is blank for a null safe check. |
||
| return null; | ||
| } | ||
| try { | ||
| List<ExperimentScore> scores = JsonUtils.readValue(experimentScoresJson, | ||
| new TypeReference<List<ExperimentScore>>() { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please create a static final constant for this type reference within |
||
| }); | ||
| return scores == null || scores.isEmpty() ? null : scores; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor and same: you can use StringUtils.isBlank. |
||
| } catch (Exception e) { | ||
| log.warn("Failed to deserialize experiment_scores from JSON: {}", experimentScoresJson, e); | ||
| return null; | ||
| } | ||
| } | ||
|
|
||
| @WithSpan | ||
| Mono<ExperimentPage> find( | ||
| int page, int size, @NonNull ExperimentSearchCriteria experimentSearchCriteria) { | ||
|
|
@@ -1265,6 +1336,7 @@ private Publisher<ExperimentGroupAggregationItem> mapExperimentGroupAggregationI | |
| .totalEstimatedCostAvg(getCostValue(row, "total_estimated_cost_avg")) | ||
| .duration(getDuration(row)) | ||
| .feedbackScores(getFeedbackScores(row)) | ||
| .experimentScores(getExperimentScoresAggregation(row)) | ||
| .build(); | ||
| }); | ||
| } | ||
|
|
@@ -1309,6 +1381,10 @@ private ST buildUpdateTemplate(ExperimentUpdate experimentUpdate, String update) | |
| template.add("status", experimentUpdate.status().getValue()); | ||
| } | ||
|
|
||
| if (experimentUpdate.experimentScores() != null) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor: for all these conditionals here based on |
||
| template.add("experiment_scores", true); | ||
| } | ||
|
|
||
| return template; | ||
| } | ||
|
|
||
|
|
@@ -1328,6 +1404,13 @@ private void bindUpdateParams(ExperimentUpdate experimentUpdate, Statement state | |
| if (experimentUpdate.status() != null) { | ||
| statement.bind("status", experimentUpdate.status().getValue()); | ||
| } | ||
|
|
||
| if (experimentUpdate.experimentScores() != null) { | ||
| statement.bind("experiment_scores", Optional.of(experimentUpdate.experimentScores()) | ||
| .filter(scores -> !scores.isEmpty()) | ||
| .map(JsonUtils::writeValueAsString) | ||
| .orElse("")); | ||
| } | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -94,6 +94,7 @@ private void buildNestedGroupsWithAggregations(Map<String, GroupContentWithAggre | |
| .totalEstimatedCostAvg(BigDecimal.ZERO) | ||
| .duration(null) | ||
| .feedbackScores(List.of()) | ||
| .experimentScores(List.of()) | ||
| .build()) | ||
| .groups(new HashMap<>()) | ||
| .build(); | ||
|
|
@@ -140,6 +141,10 @@ private AggregationData calculateAggregatedChildrenValues( | |
| Map<String, BigDecimal> feedbackScoreSums = new HashMap<>(); | ||
| Map<String, Long> feedbackScoreCounts = new HashMap<>(); | ||
|
|
||
| // For experiment scores - group by name and calculate weighted averages | ||
| Map<String, BigDecimal> experimentScoreSums = new HashMap<>(); | ||
| Map<String, Long> experimentScoreCounts = new HashMap<>(); | ||
|
|
||
| for (GroupContentWithAggregations child : childGroups.values()) { | ||
| AggregationData childAgg = child.aggregations(); | ||
| long expCount = childAgg.experimentCount(); | ||
|
|
@@ -181,6 +186,19 @@ private AggregationData calculateAggregatedChildrenValues( | |
| } | ||
| } | ||
| } | ||
|
|
||
| // For experiment scores (weighted average per name) | ||
| if (childAgg.experimentScores() != null) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same, this is a copy and paste of an existing block. Please encapsulate in a function and reuse. |
||
| for (FeedbackScoreAverage score : childAgg.experimentScores()) { | ||
| String name = score.name(); | ||
| BigDecimal value = score.value(); | ||
|
|
||
| if (value != null && name != null) { | ||
| experimentScoreSums.merge(name, value.multiply(BigDecimal.valueOf(expCount)), BigDecimal::add); | ||
| experimentScoreCounts.merge(name, expCount, Long::sum); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Calculate averages | ||
|
|
@@ -196,6 +214,8 @@ private AggregationData calculateAggregatedChildrenValues( | |
| : null; | ||
|
|
||
| List<FeedbackScoreAverage> avgFeedbackScores = buildAvgFeedbackScores(feedbackScoreSums, feedbackScoreCounts); | ||
| List<FeedbackScoreAverage> avgExperimentScores = buildAvgFeedbackScores(experimentScoreSums, | ||
| experimentScoreCounts); | ||
|
|
||
| // Build updated aggregation data | ||
| return AggregationData.builder() | ||
|
|
@@ -205,6 +225,7 @@ private AggregationData calculateAggregatedChildrenValues( | |
| .totalEstimatedCostAvg(avgCost) | ||
| .duration(avgDuration) | ||
| .feedbackScores(avgFeedbackScores) | ||
| .experimentScores(avgExperimentScores) | ||
| .build(); | ||
| } | ||
|
|
||
|
|
@@ -231,6 +252,7 @@ private AggregationData buildAggregationData(ExperimentGroupAggregationItem item | |
| .totalEstimatedCostAvg(item.totalEstimatedCostAvg()) | ||
| .duration(item.duration()) | ||
| .feedbackScores(item.feedbackScores()) | ||
| .experimentScores(item.experimentScores()) | ||
| .build(); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: Regarding validations, the list is fine (should be optional, like you have it), but you should probably add the following inner validations to the
ExperimentScoreobject:List<@NotNull @Valid ExperimentScore> experimentScores.