Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
public record ModelCostData(String litellmProvider,
String inputCostPerToken,
String outputCostPerToken,
String outputCostPerVideoPerSecond,
String cacheCreationInputTokenCost,
String cacheReadInputTokenCost,
String mode,
boolean supportsVision) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public class CostService {
}

private static final ModelPrice DEFAULT_COST = new ModelPrice(new BigDecimal("0"),
new BigDecimal("0"), new BigDecimal("0"), new BigDecimal("0"), SpanCostCalculator::defaultCost);
new BigDecimal("0"), new BigDecimal("0"), new BigDecimal("0"), new BigDecimal("0"),
SpanCostCalculator::defaultCost);

public static BigDecimal calculateCost(@Nullable String modelName, @Nullable String provider,
@Nullable Map<String, Integer> usage, @Nullable JsonNode metadata) {
Expand Down Expand Up @@ -101,9 +102,15 @@ private static Map<String, ModelPrice> parseModelPrices() throws IOException {
BigDecimal cacheReadInputTokenPrice = Optional.ofNullable(modelCost.cacheReadInputTokenCost())
.map(BigDecimal::new)
.orElse(BigDecimal.ZERO);
BigDecimal videoOutputPrice = Optional.ofNullable(modelCost.outputCostPerVideoPerSecond())
.map(BigDecimal::new)
.orElse(BigDecimal.ZERO);
String mode = Optional.ofNullable(modelCost.mode()).orElse("");

BiFunction<ModelPrice, Map<String, Integer>, BigDecimal> calculator = SpanCostCalculator::defaultCost;
if (cacheCreationInputTokenPrice.compareTo(BigDecimal.ZERO) > 0
if ("video_generation".equalsIgnoreCase(mode) && videoOutputPrice.compareTo(BigDecimal.ZERO) > 0) {
calculator = SpanCostCalculator::videoGenerationCost;
} else if (cacheCreationInputTokenPrice.compareTo(BigDecimal.ZERO) > 0
|| cacheReadInputTokenPrice.compareTo(BigDecimal.ZERO) > 0) {
calculator = PROVIDERS_CACHE_COST_CALCULATOR.getOrDefault(provider,
SpanCostCalculator::textGenerationCost);
Expand All @@ -114,7 +121,7 @@ private static Map<String, ModelPrice> parseModelPrices() throws IOException {
parsedModelPrices.put(
createModelProviderKey(parseModelName(modelName), PROVIDERS_MAPPING.get(provider)),
new ModelPrice(inputPrice, outputPrice, cacheCreationInputTokenPrice,
cacheReadInputTokenPrice, calculator));
cacheReadInputTokenPrice, videoOutputPrice, calculator));
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ public record ModelPrice(
@NonNull BigDecimal outputPrice,
@NonNull BigDecimal cacheCreationInputTokenPrice,
@NonNull BigDecimal cacheReadInputTokenPrice,
@NonNull BigDecimal videoOutputPrice,
@NonNull BiFunction<ModelPrice, Map<String, Integer>, BigDecimal> calculator) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

@UtilityClass
class SpanCostCalculator {
private static final String VIDEO_DURATION_KEY = "video_duration_seconds";

public static BigDecimal textGenerationCost(ModelPrice modelPrice, Map<String, Integer> usage) {
return modelPrice.inputPrice().multiply(BigDecimal.valueOf(usage.getOrDefault("prompt_tokens", 0)))
.add(modelPrice.outputPrice()
Expand Down Expand Up @@ -83,4 +85,12 @@ private static BigDecimal textGenerationWithCachedTokensNotIncludedInCost(ModelP
public static BigDecimal defaultCost(ModelPrice modelPrice, Map<String, Integer> usage) {
return BigDecimal.ZERO;
}

public static BigDecimal videoGenerationCost(ModelPrice modelPrice, Map<String, Integer> usage) {
int durationSeconds = usage.getOrDefault(VIDEO_DURATION_KEY, 0);
if (durationSeconds <= 0) {
return BigDecimal.ZERO;
}
return modelPrice.videoOutputPrice().multiply(BigDecimal.valueOf(durationSeconds));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.comet.opik.domain.cost;

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

class CostServiceTest {

@Test
void calculateCostForVideoGenerationUsesDuration() {
BigDecimal cost = CostService.calculateCost("sora-2", "openai",
Map.of("video_duration_seconds", 4), null);

assertThat(cost).isEqualByComparingTo("0.4");
}
}
Loading