Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
public record ModelCostData(String litellmProvider,
String inputCostPerToken,
String outputCostPerToken,
String outputCostPerVideoPerSecond,
String cacheCreationInputTokenCost,
String cacheReadInputTokenCost,
String mode,
boolean supportsVision) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import javax.annotation.Nullable;

Expand Down Expand Up @@ -47,8 +49,7 @@ public class CostService {
}
}

private static final ModelPrice DEFAULT_COST = new ModelPrice(new BigDecimal("0"),
new BigDecimal("0"), new BigDecimal("0"), new BigDecimal("0"), SpanCostCalculator::defaultCost);
private static final ModelPrice DEFAULT_COST = ModelPrice.empty();

public static BigDecimal calculateCost(@Nullable String modelName, @Nullable String provider,
@Nullable Map<String, Integer> usage, @Nullable JsonNode metadata) {
Expand Down Expand Up @@ -101,20 +102,19 @@ private static Map<String, ModelPrice> parseModelPrices() throws IOException {
BigDecimal cacheReadInputTokenPrice = Optional.ofNullable(modelCost.cacheReadInputTokenCost())
.map(BigDecimal::new)
.orElse(BigDecimal.ZERO);
BigDecimal videoOutputPrice = Optional.ofNullable(modelCost.outputCostPerVideoPerSecond())
.map(BigDecimal::new)
.orElse(BigDecimal.ZERO);
ModelMode mode = ModelMode.fromValue(modelCost.mode());

BiFunction<ModelPrice, Map<String, Integer>, BigDecimal> calculator = SpanCostCalculator::defaultCost;
if (cacheCreationInputTokenPrice.compareTo(BigDecimal.ZERO) > 0
|| cacheReadInputTokenPrice.compareTo(BigDecimal.ZERO) > 0) {
calculator = PROVIDERS_CACHE_COST_CALCULATOR.getOrDefault(provider,
SpanCostCalculator::textGenerationCost);
} else if (inputPrice.compareTo(BigDecimal.ZERO) > 0 || outputPrice.compareTo(BigDecimal.ZERO) > 0) {
calculator = SpanCostCalculator::textGenerationCost;
}
BiFunction<ModelPrice, Map<String, Integer>, BigDecimal> calculator = resolveCalculator(provider, mode,
inputPrice, outputPrice, cacheCreationInputTokenPrice, cacheReadInputTokenPrice,
videoOutputPrice);

parsedModelPrices.put(
createModelProviderKey(parseModelName(modelName), PROVIDERS_MAPPING.get(provider)),
new ModelPrice(inputPrice, outputPrice, cacheCreationInputTokenPrice,
cacheReadInputTokenPrice, calculator));
cacheReadInputTokenPrice, videoOutputPrice, calculator));
}
});

Expand All @@ -138,4 +138,68 @@ private static boolean isValidModelProvider(String modelName, String provider) {

return true;
}

private static BiFunction<ModelPrice, Map<String, Integer>, BigDecimal> resolveCalculator(
String provider,
ModelMode mode,
BigDecimal inputPrice,
BigDecimal outputPrice,
BigDecimal cacheCreationInputTokenPrice,
BigDecimal cacheReadInputTokenPrice,
BigDecimal videoOutputPrice) {

if (mode.isVideoGeneration() && isPositive(videoOutputPrice)) {
return SpanCostCalculator::videoGenerationCost;
}

if (isPositive(cacheCreationInputTokenPrice) || isPositive(cacheReadInputTokenPrice)) {
return PROVIDERS_CACHE_COST_CALCULATOR.getOrDefault(provider, SpanCostCalculator::textGenerationCost);
}

if (isPositive(inputPrice) || isPositive(outputPrice)) {
return SpanCostCalculator::textGenerationCost;
}

return SpanCostCalculator::defaultCost;
}

private static boolean isPositive(BigDecimal value) {
return Optional.ofNullable(value).map(v -> v.compareTo(BigDecimal.ZERO) > 0).orElse(false);
}

@RequiredArgsConstructor
private enum ModelMode {
TEXT_GENERATION("text_generation"),
CHAT("chat"),
EMBEDDING("embedding"),
COMPLETION("completion"),
IMAGE_GENERATION("image_generation"),
AUDIO_TRANSCRIPTION("audio_transcription"),
AUDIO_SPEECH("audio_speech"),
MODERATION("moderation"),
RERANK("rerank"),
SEARCH("search"),
VIDEO_GENERATION("video_generation");

private static final ModelMode DEFAULT = TEXT_GENERATION;
private final String value;

static ModelMode fromValue(String rawValue) {
if (StringUtils.isBlank(rawValue)) {
return DEFAULT;
}

for (ModelMode mode : values()) {
if (mode.value.equalsIgnoreCase(rawValue)) {
return mode;
}
}

return DEFAULT;
}

boolean isVideoGeneration() {
return this == VIDEO_GENERATION;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
package com.comet.opik.domain.cost;

import lombok.Builder;
import lombok.NonNull;

import java.math.BigDecimal;
import java.util.Map;
import java.util.function.BiFunction;

@Builder(toBuilder = true)
public record ModelPrice(
@NonNull BigDecimal inputPrice,
@NonNull BigDecimal outputPrice,
@NonNull BigDecimal cacheCreationInputTokenPrice,
@NonNull BigDecimal cacheReadInputTokenPrice,
@NonNull BigDecimal videoOutputPrice,
@NonNull BiFunction<ModelPrice, Map<String, Integer>, BigDecimal> calculator) {

public static ModelPrice empty() {
return new ModelPrice(
BigDecimal.ZERO,
BigDecimal.ZERO,
BigDecimal.ZERO,
BigDecimal.ZERO,
BigDecimal.ZERO,
SpanCostCalculator::defaultCost);
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
package com.comet.opik.domain.cost;

import lombok.NonNull;
import lombok.experimental.UtilityClass;

import java.math.BigDecimal;
import java.util.Map;

@UtilityClass
class SpanCostCalculator {
public static BigDecimal textGenerationCost(ModelPrice modelPrice, Map<String, Integer> usage) {
private static final String VIDEO_DURATION_KEY = "video_duration_seconds";

public static BigDecimal textGenerationCost(@NonNull ModelPrice modelPrice, @NonNull Map<String, Integer> usage) {
return modelPrice.inputPrice().multiply(BigDecimal.valueOf(usage.getOrDefault("prompt_tokens", 0)))
.add(modelPrice.outputPrice()
.multiply(BigDecimal.valueOf(usage.getOrDefault("completion_tokens", 0))));
}

public static BigDecimal textGenerationWithCacheCostOpenAI(ModelPrice modelPrice, Map<String, Integer> usage) {
public static BigDecimal textGenerationWithCacheCostOpenAI(@NonNull ModelPrice modelPrice,
@NonNull Map<String, Integer> usage) {

// In OpenAI usage format, input tokens includes the cached input tokens, so we need to substract them to compute the correct input token count
// Don't generalize yet as other providers seems to separate the cached tokens from non-cached tokens
Expand All @@ -38,13 +42,15 @@ public static BigDecimal textGenerationWithCacheCostOpenAI(ModelPrice modelPrice
.add(modelPrice.cacheReadInputTokenPrice().multiply(BigDecimal.valueOf(cachedReadInputTokens)));
}

public static BigDecimal textGenerationWithCacheCostAnthropic(ModelPrice modelPrice, Map<String, Integer> usage) {
public static BigDecimal textGenerationWithCacheCostAnthropic(@NonNull ModelPrice modelPrice,
@NonNull Map<String, Integer> usage) {
return textGenerationWithCachedTokensNotIncludedInCost(modelPrice, usage, "original_usage.input_tokens",
"original_usage.output_tokens", "original_usage.cache_read_input_tokens",
"original_usage.cache_creation_input_tokens");
}

public static BigDecimal textGenerationWithCacheCostBedrock(ModelPrice modelPrice, Map<String, Integer> usage) {
public static BigDecimal textGenerationWithCacheCostBedrock(@NonNull ModelPrice modelPrice,
@NonNull Map<String, Integer> usage) {
return textGenerationWithCachedTokensNotIncludedInCost(modelPrice, usage, "original_usage.inputTokens",
"original_usage.outputTokens", "original_usage.cacheReadInputTokens",
"original_usage.cacheWriteInputTokens");
Expand All @@ -63,8 +69,8 @@ public static BigDecimal textGenerationWithCacheCostBedrock(ModelPrice modelPric
* @param cacheCreationInputTokensKey Key for cache creation tokens in usage map
* @return The calculated cost as a BigDecimal
*/
private static BigDecimal textGenerationWithCachedTokensNotIncludedInCost(ModelPrice modelPrice,
Map<String, Integer> usage,
private static BigDecimal textGenerationWithCachedTokensNotIncludedInCost(@NonNull ModelPrice modelPrice,
@NonNull Map<String, Integer> usage,
String inputTokensKey, String outputTokensKey, String cacheReadInputTokensKey,
String cacheCreationInputTokensKey) {

Expand All @@ -80,7 +86,21 @@ private static BigDecimal textGenerationWithCachedTokensNotIncludedInCost(ModelP
.multiply(BigDecimal.valueOf(usage.getOrDefault(cacheReadInputTokensKey, 0))));
}

public static BigDecimal defaultCost(ModelPrice modelPrice, Map<String, Integer> usage) {
public static BigDecimal defaultCost(@NonNull ModelPrice modelPrice, @NonNull Map<String, Integer> usage) {
return BigDecimal.ZERO;
}

public static BigDecimal videoGenerationCost(@NonNull ModelPrice modelPrice,
@NonNull Map<String, Integer> usage) {
int durationSeconds = usage.getOrDefault(VIDEO_DURATION_KEY, 0);
BigDecimal videoPrice = modelPrice.videoOutputPrice();
if (durationSeconds <= 0 || !isPositive(videoPrice)) {
return BigDecimal.ZERO;
}
return videoPrice.multiply(BigDecimal.valueOf(durationSeconds));
}

private static boolean isPositive(BigDecimal value) {
return value != null && value.compareTo(BigDecimal.ZERO) > 0;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.comet.opik.domain.cost;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

class CostServiceTest {

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

@Test
void calculateCostForVideoGenerationUsesDuration() {
BigDecimal cost = CostService.calculateCost("sora-2", "openai",
Map.of("video_duration_seconds", 4), null);

assertThat(cost).isEqualByComparingTo("0.4");
}

@Test
void calculateCostUsesCacheAwareCalculatorWhenCachePricesConfigured() {
Map<String, Integer> usage = Map.of(
"original_usage.inputTokens", 100,
"original_usage.outputTokens", 20,
"original_usage.cacheReadInputTokens", 10,
"original_usage.cacheWriteInputTokens", 5);

BigDecimal cost = CostService.calculateCost("anthropic.claude-3-5-haiku-20241022-v1:0", "bedrock", usage, null);

assertThat(cost).isEqualByComparingTo("0.0001658");
}

@Test
void calculateCostFallsBackToMetadataWhenNoMatchingModelFound() {
ObjectNode metadata = OBJECT_MAPPER.createObjectNode();
metadata.putObject("cost")
.put("currency", "USD")
.put("total_tokens", 0.42);

BigDecimal cost = CostService.calculateCost("unknown-model", "unknown", Map.of(), metadata);

assertThat(cost).isEqualByComparingTo("0.42");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.comet.opik.domain.cost;

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class SpanCostCalculatorTest {

@Test
void videoGenerationCostReturnsZeroWhenPriceIsZero() {
ModelPrice modelPrice = new ModelPrice(BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO,
BigDecimal.ZERO, SpanCostCalculator::defaultCost);

BigDecimal cost = SpanCostCalculator.videoGenerationCost(modelPrice, Map.of("video_duration_seconds", 10));

assertThat(cost).isZero();
}

@Test
void videoGenerationCostValidatesArguments() {
assertThatThrownBy(() -> SpanCostCalculator.videoGenerationCost(null, Map.of()))
.isInstanceOf(NullPointerException.class);
ModelPrice modelPrice = new ModelPrice(BigDecimal.ONE, BigDecimal.ONE, BigDecimal.ONE, BigDecimal.ONE,
BigDecimal.ONE, SpanCostCalculator::textGenerationCost);
assertThatThrownBy(() -> SpanCostCalculator.videoGenerationCost(modelPrice, null))
.isInstanceOf(NullPointerException.class);
}

@Test
void videoGenerationCostMultipliesDurationAndPrice() {
ModelPrice modelPrice = new ModelPrice(BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO,
new BigDecimal("0.5"), SpanCostCalculator::defaultCost);

BigDecimal cost = SpanCostCalculator.videoGenerationCost(modelPrice, Map.of("video_duration_seconds", 2));

assertThat(cost).isEqualByComparingTo("1.0");
}
}
Loading