diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 69da91e..835c886 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,4 +15,4 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v5 - name: Run unit tests - run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py -v --import-mode=importlib + run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tests/test_pricing.py -v --import-mode=importlib diff --git a/Makefile b/Makefile index 902c2f6..0523afa 100644 --- a/Makefile +++ b/Makefile @@ -44,9 +44,6 @@ $(image_eif): $(image_tar) .PHONY: run run: $(image_eif) - -rm -f /tmp/network.sock - -kill $$(lsof -ti :2222 2>/dev/null) - -rm -f nohup.out nitro-cli terminate-enclave --all ./scripts/run-enclave.sh $(image_eif) diff --git a/measurements.txt b/measurements.txt index a1a6f5f..38132ac 100644 --- a/measurements.txt +++ b/measurements.txt @@ -1,8 +1,8 @@ { "Measurements": { "HashAlgorithm": "Sha384 { ... }", - "PCR0": "6640e0ede60135375ad19eb5e41a045ce49cf84bf6919cc1f7e61556c2566fa855fd4e64247b1bde07f216761a101b83", + "PCR0": "f2945c350e80c9431649dd691cb4ecebc11174c0201e22c151b0b11c0aa41074174b57cc17343444380f212baaf5752c", "PCR1": "4b4d5b3661b3efc12920900c80e126e4ce783c522de6c02a2a5bf7af3a2b9327b86776f188e4be1c1c404a129dbda493", - "PCR2": "45b95f6151f44a0d2c0bdb7d64c1861f49d8ef59e726b86ded25155481654c2bf0a58ee65ee64c364528994a9b260c5d" + "PCR2": "01d8b4615e5dc0417d57713b6d8d2c3b0ab3ec3e1848658c116a2986da35b47c04001c8acbea57858446d0bf0a35de68" } } diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index fbaca68..1b6485d 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -30,6 +30,7 @@ from x402.schemas import AssetAmount from x402.server import x402ResourceServerSync from x402.session import SessionStore +import types as _types import x402.http.middleware.flask as x402_flask from .util import dynamic_session_cost_calculator @@ -377,7 +378,7 @@ def _patched_read_body_bytes(environ): x402_flask._read_body_bytes = _patched_read_body_bytes -payment_middleware( +_payment_mw = payment_middleware( application, routes=routes, server=server, @@ -386,6 +387,78 @@ def _patched_read_body_bytes(environ): session_idle_timeout=100, session_cost_calculator=dynamic_session_cost_calculator, ) + + +def _strict_resolve_session_request_cost( + self, + *, + method: str, + path: str, + request_body_bytes: bytes, + response_body_bytes: bytes, + payment_payload: object, + payment_requirements: object, + status_code: int | None, + output_object: object = None, + is_streaming: bool = False, +) -> int: + """Replacement for PaymentMiddleware._resolve_session_request_cost. + + Identical to the upstream implementation except that exceptions raised by + the dynamic cost calculator are NOT caught. This means a request whose + cost cannot be determined (unknown model, missing usage data, etc.) will + result in a 500 error rather than silently falling back to the static cap + amount and charging the user an incorrect amount. + """ + from x402.http.middleware.flask import _parse_json_bytes as _x402_parse_json # noqa: PLC0415 + + default_cost = self._get_session_cost(payment_requirements) + if not self._should_charge_response(status_code): + return default_cost + if not callable(self._session_cost_calculator): + return default_cost + + request_object = _x402_parse_json(request_body_bytes) + response_object = ( + output_object + if output_object is not None + else _x402_parse_json(response_body_bytes) + ) + + callback_context = { + "method": method, + "path": path, + "status_code": status_code, + "is_streaming": is_streaming, + "request_body_bytes": request_body_bytes, + "response_body_bytes": response_body_bytes, + "request_json": request_object + if isinstance(request_object, (dict, list)) + else None, + "response_json": response_object + if isinstance(response_object, (dict, list)) + else None, + "response_object": response_object, + "payment_payload": payment_payload, + "payment_requirements": payment_requirements, + "default_cost": default_cost, + } + + # Do NOT catch exceptions here — let them propagate so the request fails + # with a 500 rather than silently charging the static fallback amount. + dynamic_cost = self._session_cost_calculator(callback_context) + if dynamic_cost is None: + raise ValueError( + f"dynamic_session_cost_calculator returned None for {method} {path}; " + "cannot determine request cost" + ) + return self._coerce_non_negative_int(dynamic_cost) + + +_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign] + _strict_resolve_session_request_cost, _payment_mw +) + logger.info("x402 payment middleware initialized") if __name__ == "__main__": diff --git a/tee_gateway/definitions.py b/tee_gateway/definitions.py index 0d62dfb..98657a9 100644 --- a/tee_gateway/definitions.py +++ b/tee_gateway/definitions.py @@ -71,8 +71,8 @@ # /v1/chat/completions — 0.01 OUSDC precheck (6 decimals: 10_000 = $0.01) CHAT_COMPLETIONS_USDC_AMOUNT: str = "10000" -# /v1/chat/completions — 0.05 OPG precheck (18 decimals) -CHAT_COMPLETIONS_OPG_AMOUNT: str = "50000000000000000" +# /v1/chat/completions — 0.1 OPG precheck (18 decimals) +CHAT_COMPLETIONS_OPG_AMOUNT: str = "100000000000000000" # /v1/completions — 0.01 USDC precheck (6 decimals: 10_000 = $0.01) COMPLETIONS_USDC_AMOUNT: str = "10000" diff --git a/tee_gateway/model_registry.py b/tee_gateway/model_registry.py index f1bbe24..4027215 100644 --- a/tee_gateway/model_registry.py +++ b/tee_gateway/model_registry.py @@ -87,24 +87,6 @@ class SupportedModel(Enum): input_price_usd=Decimal("0.000005"), output_price_usd=Decimal("0.000025"), ) - CLAUDE_3_7_SONNET = ModelConfig( - provider="anthropic", - api_name="claude-3-7-sonnet-latest", - input_price_usd=Decimal("0.000003"), - output_price_usd=Decimal("0.000015"), - ) - CLAUDE_3_5_HAIKU = ModelConfig( - provider="anthropic", - api_name="claude-3-5-haiku-latest", - input_price_usd=Decimal("0.000001"), - output_price_usd=Decimal("0.000005"), - ) - CLAUDE_4_0_SONNET = ModelConfig( - provider="anthropic", - api_name="claude-sonnet-4-0", - input_price_usd=Decimal("0.000003"), - output_price_usd=Decimal("0.000015"), - ) # ── Google Gemini ─────────────────────────────────────────────────── GEMINI_2_5_FLASH = ModelConfig( @@ -128,12 +110,6 @@ class SupportedModel(Enum): output_price_usd=Decimal("0.0000004"), thinking_budget=0, ) - GEMINI_3_PRO_PREVIEW = ModelConfig( - provider="google", - api_name="gemini-3-pro-preview", - input_price_usd=Decimal("0.000002"), - output_price_usd=Decimal("0.000012"), - ) GEMINI_3_FLASH_PREVIEW = ModelConfig( provider="google", api_name="gemini-3-flash-preview", @@ -166,6 +142,14 @@ class SupportedModel(Enum): input_price_usd=Decimal("0.0000002"), output_price_usd=Decimal("0.0000005"), ) + + # ── Legacy models (not in current SDK — retained for older SDK versions) ── + CLAUDE_4_0_SONNET = ModelConfig( + provider="anthropic", + api_name="claude-sonnet-4-0", + input_price_usd=Decimal("0.000003"), + output_price_usd=Decimal("0.000015"), + ) GROK_3_MINI = ModelConfig( provider="x-ai", api_name="grok-3-mini", @@ -178,12 +162,6 @@ class SupportedModel(Enum): input_price_usd=Decimal("0.000003"), output_price_usd=Decimal("0.000015"), ) - GROK_2 = ModelConfig( - provider="x-ai", - api_name="grok-2-latest", - input_price_usd=Decimal("0.000002"), - output_price_usd=Decimal("0.00001"), - ) # Canonical lookup: user-facing model name → SupportedModel @@ -202,14 +180,10 @@ class SupportedModel(Enum): "claude-haiku-4-5": SupportedModel.CLAUDE_HAIKU_4_5, "claude-opus-4-5": SupportedModel.CLAUDE_OPUS_4_5, "claude-opus-4-6": SupportedModel.CLAUDE_OPUS_4_6, - "claude-3.7-sonnet": SupportedModel.CLAUDE_3_7_SONNET, - "claude-3.5-haiku": SupportedModel.CLAUDE_3_5_HAIKU, - "claude-4.0-sonnet": SupportedModel.CLAUDE_4_0_SONNET, # Google "gemini-2.5-flash": SupportedModel.GEMINI_2_5_FLASH, "gemini-2.5-pro": SupportedModel.GEMINI_2_5_PRO, "gemini-2.5-flash-lite": SupportedModel.GEMINI_2_5_FLASH_LITE, - "gemini-3-pro-preview": SupportedModel.GEMINI_3_PRO_PREVIEW, "gemini-3-flash-preview": SupportedModel.GEMINI_3_FLASH_PREVIEW, # xAI "grok-4": SupportedModel.GROK_4, @@ -217,12 +191,13 @@ class SupportedModel(Enum): "grok-4-1-fast": SupportedModel.GROK_4_1_FAST, "grok-4.1-fast": SupportedModel.GROK_4_1_FAST, "grok-4-1-fast-non-reasoning": SupportedModel.GROK_4_1_FAST_NON_REASONING, - "grok-3-mini-beta": SupportedModel.GROK_3_MINI, + # Legacy — not in current SDK, retained for older SDK versions + "claude-sonnet-4-0": SupportedModel.CLAUDE_4_0_SONNET, + "claude-4.0-sonnet": SupportedModel.CLAUDE_4_0_SONNET, # alternate dot notation + "grok-3-mini-beta": SupportedModel.GROK_3_MINI, # old beta alias "grok-3-mini": SupportedModel.GROK_3_MINI, - "grok-3-beta": SupportedModel.GROK_3, + "grok-3-beta": SupportedModel.GROK_3, # old beta alias "grok-3": SupportedModel.GROK_3, - "grok-2-1212": SupportedModel.GROK_2, - "grok-2": SupportedModel.GROK_2, } # Build the rate card automatically from the enum (for backward compat with util.py) diff --git a/tee_gateway/util.py b/tee_gateway/util.py index 8821c61..47559d9 100644 --- a/tee_gateway/util.py +++ b/tee_gateway/util.py @@ -157,7 +157,6 @@ def _deserialize_dict(data, boxed_type): from tee_gateway.definitions import ( # noqa: E402 ASSET_DECIMALS_BY_ADDRESS, - DEFAULT_ASSET_DECIMALS, ) from tee_gateway.model_registry import get_model_config # noqa: E402 @@ -238,31 +237,49 @@ def _normalize_model_name(model: str | None) -> str | None: def _extract_usage_tokens( response_json: dict[str, Any] | None, -) -> tuple[int, int] | None: +) -> tuple[int, int]: + """Extract (input_tokens, output_tokens) from response JSON. + + Raises ValueError if usage data is missing or malformed — no silent fallback. + """ if not isinstance(response_json, dict): - return None + raise ValueError("response_json is not a dict; cannot extract usage tokens") usage = response_json.get("usage") if not isinstance(usage, dict): - return None + raise ValueError( + "response_json has no 'usage' dict; cannot extract usage tokens" + ) prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens")) completion_tokens = usage.get("completion_tokens", usage.get("output_tokens")) if prompt_tokens is None or completion_tokens is None: - return None + raise ValueError(f"usage dict is missing token counts: {usage!r}") try: return max(0, int(prompt_tokens)), max(0, int(completion_tokens)) - except (TypeError, ValueError): - return None + except (TypeError, ValueError) as exc: + raise ValueError(f"Could not parse token counts from usage: {usage!r}") from exc def _extract_model_from_context( request_json: dict[str, Any] | None, response_json: dict[str, Any] | None, -) -> str | None: - req_model = request_json.get("model") if isinstance(request_json, dict) else None - resp_model = response_json.get("model") if isinstance(response_json, dict) else None - return _normalize_model_name(req_model or resp_model) +) -> str: + """Extract and normalize model name from request JSON. + + Uses only the request model name — the response model field is ignored + because providers may return a versioned alias that differs from the + user-facing name. Raises ValueError if the model name is absent. + """ + if not isinstance(request_json, dict): + raise ValueError("request_json is not a dict; cannot extract model name") + req_model = request_json.get("model") + if not req_model: + raise ValueError("request_json has no 'model' field") + normalized = _normalize_model_name(req_model) + if not normalized: + raise ValueError(f"model name normalizes to empty string: {req_model!r}") + return normalized def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: @@ -272,15 +289,25 @@ def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int: if not asset and isinstance(req.get("price"), dict): asset = req["price"].get("asset") - if isinstance(asset, str): - return ASSET_DECIMALS_BY_ADDRESS.get(asset.lower(), DEFAULT_ASSET_DECIMALS) - return DEFAULT_ASSET_DECIMALS + if not isinstance(asset, str) or not asset: + raise ValueError( + f"payment_requirements has no recognizable asset address; " + f"cannot determine token decimals: {req!r}" + ) + + asset_lower = asset.lower() + if asset_lower not in ASSET_DECIMALS_BY_ADDRESS: + raise ValueError( + f"Unknown asset address {asset!r}; not in ASSET_DECIMALS_BY_ADDRESS. " + f"Add it to definitions.py before accepting payments with this token." + ) + return ASSET_DECIMALS_BY_ADDRESS[asset_lower] def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: """Compute UPTO per-request cost in token smallest units from actual usage. - Raises ValueError if the model is not in the registry (no silent fallback). + Raises ValueError on any missing or unrecognised input — no silent fallback. """ request_json = context.get("request_json") response_json = context.get("response_json") @@ -291,18 +318,11 @@ def dynamic_session_cost_calculator(context: dict[str, Any]) -> int: ) model = _extract_model_from_context(request_json, response_json) - if not model: - raise ValueError("Could not extract model name from request/response") # get_model_config raises ValueError for unknown models — no fallback cfg = get_model_config(model) - usage_tokens = _extract_usage_tokens(response_json) - if not usage_tokens: - logger.warning("No usage tokens in response for model=%s; charging zero", model) - return 0 - - input_tokens, output_tokens = usage_tokens + input_tokens, output_tokens = _extract_usage_tokens(response_json) input_rate = cfg.input_price_usd output_rate = cfg.output_price_usd diff --git a/tests/test_pricing.py b/tests/test_pricing.py new file mode 100644 index 0000000..37a2db8 --- /dev/null +++ b/tests/test_pricing.py @@ -0,0 +1,491 @@ +""" +Unit tests for dynamic pricing / cost calculation across all supported models. + +Tests verify that: + - Every user-facing model name resolves to the correct ModelConfig + - dynamic_session_cost_calculator produces the right amount in token + smallest-units for each provider and token currency (OPG / USDC) + - Edge cases (no usage, unknown model, bad context) are handled correctly +""" + +import unittest +from decimal import Decimal + +from tee_gateway.definitions import BASE_OPG_ADDRESS, USDC_ADDRESS +from tee_gateway.model_registry import ( + _MODEL_LOOKUP, + get_model_config, +) +from tee_gateway.util import dynamic_session_cost_calculator + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _opg_requirements() -> dict: + """Fake PaymentRequirements dict for OPG (18 decimals).""" + return {"asset": BASE_OPG_ADDRESS, "amount": "50000000000000000"} + + +def _usdc_requirements() -> dict: + """Fake PaymentRequirements dict for USDC (6 decimals).""" + return {"asset": USDC_ADDRESS, "amount": "10000"} + + +def _ctx(model: str, input_tokens: int, output_tokens: int, requirements=None) -> dict: + """Build a minimal calculator context.""" + return { + "request_json": {"model": model, "messages": []}, + "response_json": { + "model": model, + "usage": { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + }, + }, + "payment_requirements": requirements or _opg_requirements(), + } + + +def _expected_cost_opg(model: str, input_tokens: int, output_tokens: int) -> int: + """Compute expected cost in OPG smallest units (18 decimals, ROUND_CEILING).""" + from decimal import ROUND_CEILING + + cfg = get_model_config(model) + total_usd = ( + Decimal(input_tokens) * cfg.input_price_usd + + Decimal(output_tokens) * cfg.output_price_usd + ) + return int((total_usd * Decimal(10**18)).to_integral_value(rounding=ROUND_CEILING)) + + +def _expected_cost_usdc(model: str, input_tokens: int, output_tokens: int) -> int: + """Compute expected cost in USDC smallest units (6 decimals, ROUND_CEILING).""" + from decimal import ROUND_CEILING + + cfg = get_model_config(model) + total_usd = ( + Decimal(input_tokens) * cfg.input_price_usd + + Decimal(output_tokens) * cfg.output_price_usd + ) + return int((total_usd * Decimal(10**6)).to_integral_value(rounding=ROUND_CEILING)) + + +# --------------------------------------------------------------------------- +# Model registry tests +# --------------------------------------------------------------------------- + + +class TestModelRegistry(unittest.TestCase): + """All user-facing model names must resolve without error.""" + + def test_all_lookup_keys_resolve(self): + """Every key in _MODEL_LOOKUP must resolve to a valid ModelConfig.""" + for name, enum_val in _MODEL_LOOKUP.items(): + with self.subTest(model=name): + cfg = get_model_config(name) + self.assertIsNotNone(cfg) + self.assertIsNotNone(cfg.provider) + self.assertIsNotNone(cfg.api_name) + self.assertGreater(cfg.input_price_usd, 0) + self.assertGreater(cfg.output_price_usd, 0) + + # ── Anthropic Sonnet ──────────────────────────────────────────────────── + + def test_claude_sonnet_4_5_resolves(self): + cfg = get_model_config("claude-sonnet-4-5") + self.assertEqual(cfg.provider, "anthropic") + self.assertEqual(cfg.input_price_usd, Decimal("0.000003")) + self.assertEqual(cfg.output_price_usd, Decimal("0.000015")) + + def test_claude_sonnet_4_6_resolves(self): + cfg = get_model_config("claude-sonnet-4-6") + self.assertEqual(cfg.provider, "anthropic") + self.assertEqual(cfg.input_price_usd, Decimal("0.000003")) + self.assertEqual(cfg.output_price_usd, Decimal("0.000015")) + + def test_claude_sonnet_4_0_hyphen_resolves(self): + """claude-sonnet-4-0 (legacy) must still resolve for older SDK versions.""" + cfg = get_model_config("claude-sonnet-4-0") + self.assertEqual(cfg, get_model_config("claude-4.0-sonnet")) + self.assertEqual(cfg.provider, "anthropic") + + # ── Anthropic Haiku ───────────────────────────────────────────────────── + + def test_claude_haiku_4_5_resolves(self): + cfg = get_model_config("claude-haiku-4-5") + self.assertEqual(cfg.provider, "anthropic") + self.assertEqual(cfg.input_price_usd, Decimal("0.000001")) + self.assertEqual(cfg.output_price_usd, Decimal("0.000005")) + + # ── Anthropic Opus ────────────────────────────────────────────────────── + + def test_claude_opus_4_5_resolves(self): + cfg = get_model_config("claude-opus-4-5") + self.assertEqual(cfg.provider, "anthropic") + self.assertEqual(cfg.input_price_usd, Decimal("0.000005")) + self.assertEqual(cfg.output_price_usd, Decimal("0.000025")) + + def test_claude_opus_4_6_resolves(self): + cfg = get_model_config("claude-opus-4-6") + self.assertEqual(cfg.provider, "anthropic") + + # ── OpenAI ────────────────────────────────────────────────────────────── + + def test_gpt_4_1_resolves(self): + cfg = get_model_config("gpt-4.1") + self.assertEqual(cfg.provider, "openai") + self.assertEqual(cfg.input_price_usd, Decimal("0.000002")) + self.assertEqual(cfg.output_price_usd, Decimal("0.000008")) + + def test_gpt_4_1_full_date_resolves(self): + cfg = get_model_config("gpt-4.1-2025-04-14") + self.assertEqual(cfg, get_model_config("gpt-4.1")) + + def test_o4_mini_resolves(self): + cfg = get_model_config("o4-mini") + self.assertEqual(cfg.provider, "openai") + + def test_gpt_5_resolves(self): + cfg = get_model_config("gpt-5") + self.assertEqual(cfg.provider, "openai") + + def test_gpt_5_mini_resolves(self): + cfg = get_model_config("gpt-5-mini") + self.assertEqual(cfg.provider, "openai") + + def test_gpt_5_2_resolves(self): + cfg = get_model_config("gpt-5.2") + self.assertEqual(cfg.provider, "openai") + + # ── Google ────────────────────────────────────────────────────────────── + + def test_gemini_2_5_flash_resolves(self): + cfg = get_model_config("gemini-2.5-flash") + self.assertEqual(cfg.provider, "google") + self.assertEqual(cfg.input_price_usd, Decimal("0.0000003")) + + def test_gemini_2_5_pro_resolves(self): + cfg = get_model_config("gemini-2.5-pro") + self.assertEqual(cfg.provider, "google") + + def test_gemini_2_5_flash_lite_resolves(self): + cfg = get_model_config("gemini-2.5-flash-lite") + self.assertEqual(cfg.provider, "google") + + def test_gemini_3_flash_preview_resolves(self): + cfg = get_model_config("gemini-3-flash-preview") + self.assertEqual(cfg.provider, "google") + + # ── xAI Grok ──────────────────────────────────────────────────────────── + + def test_grok_4_resolves(self): + cfg = get_model_config("grok-4") + self.assertEqual(cfg.provider, "x-ai") + + def test_grok_4_fast_resolves(self): + cfg = get_model_config("grok-4-fast") + self.assertEqual(cfg.provider, "x-ai") + + def test_grok_4_1_fast_resolves(self): + cfg = get_model_config("grok-4-1-fast") + self.assertEqual(cfg.provider, "x-ai") + + def test_grok_4_1_fast_dot_notation_resolves(self): + cfg = get_model_config("grok-4.1-fast") + self.assertEqual(cfg, get_model_config("grok-4-1-fast")) + + def test_grok_3_mini_resolves(self): + cfg = get_model_config("grok-3-mini") + self.assertEqual(cfg.provider, "x-ai") + + def test_grok_3_resolves(self): + cfg = get_model_config("grok-3") + self.assertEqual(cfg.provider, "x-ai") + + # ── Errors ─────────────────────────────────────────────────────────────── + + def test_unknown_model_raises(self): + with self.assertRaises(ValueError): + get_model_config("gpt-4o") # not in registry + + def test_unknown_sonnet_variant_raises(self): + with self.assertRaises(ValueError): + get_model_config("claude-sonnet-99") + + +# --------------------------------------------------------------------------- +# Pricing calculation tests +# --------------------------------------------------------------------------- + + +class TestDynamicSessionCostCalculatorOPG(unittest.TestCase): + """dynamic_session_cost_calculator with OPG (18 decimals).""" + + def _calc(self, model, input_tokens, output_tokens): + return dynamic_session_cost_calculator( + _ctx(model, input_tokens, output_tokens, _opg_requirements()) + ) + + # ── OpenAI ────────────────────────────────────────────────────────────── + + def test_gpt_4_1_cost(self): + cost = self._calc("gpt-4.1", 1000, 500) + expected = _expected_cost_opg("gpt-4.1", 1000, 500) + self.assertEqual(cost, expected) + # 1000*0.000002 + 500*0.000008 = 0.002 + 0.004 = 0.006 USD = 6e15 wei + self.assertEqual(cost, 6_000_000_000_000_000) + + def test_gpt_5_mini_cost(self): + cost = self._calc("gpt-5-mini", 1000, 500) + expected = _expected_cost_opg("gpt-5-mini", 1000, 500) + self.assertEqual(cost, expected) + # 1000*0.00000025 + 500*0.000002 = 0.00025 + 0.001 = 0.00125 USD + self.assertEqual(cost, 1_250_000_000_000_000) + + def test_o4_mini_cost(self): + cost = self._calc("o4-mini", 2000, 1000) + expected = _expected_cost_opg("o4-mini", 2000, 1000) + self.assertEqual(cost, expected) + + # ── Anthropic Sonnet ──────────────────────────────────────────────────── + + def test_claude_sonnet_4_5_cost(self): + cost = self._calc("claude-sonnet-4-5", 1000, 500) + expected = _expected_cost_opg("claude-sonnet-4-5", 1000, 500) + self.assertEqual(cost, expected) + # 1000*0.000003 + 500*0.000015 = 0.003 + 0.0075 = 0.0105 USD = 10.5e15 wei + self.assertEqual(cost, 10_500_000_000_000_000) + + def test_claude_sonnet_4_6_cost(self): + cost = self._calc("claude-sonnet-4-6", 1000, 500) + self.assertEqual(cost, self._calc("claude-sonnet-4-5", 1000, 500)) + + def test_claude_sonnet_4_0_cost(self): + """claude-sonnet-4-0 (legacy) must produce correct pricing.""" + cost = self._calc("claude-sonnet-4-0", 1000, 500) + expected = _expected_cost_opg("claude-sonnet-4-0", 1000, 500) + self.assertEqual(cost, expected) + # Same price tier as claude-sonnet-4-5 + self.assertEqual(cost, 10_500_000_000_000_000) + + # ── Anthropic Haiku ───────────────────────────────────────────────────── + + def test_claude_haiku_4_5_cost(self): + cost = self._calc("claude-haiku-4-5", 1000, 500) + expected = _expected_cost_opg("claude-haiku-4-5", 1000, 500) + self.assertEqual(cost, expected) + # 1000*0.000001 + 500*0.000005 = 0.001 + 0.0025 = 0.0035 USD = 3.5e15 wei + self.assertEqual(cost, 3_500_000_000_000_000) + + # ── Anthropic Opus ────────────────────────────────────────────────────── + + def test_claude_opus_4_5_cost(self): + cost = self._calc("claude-opus-4-5", 1000, 500) + expected = _expected_cost_opg("claude-opus-4-5", 1000, 500) + self.assertEqual(cost, expected) + # 1000*0.000005 + 500*0.000025 = 0.005 + 0.0125 = 0.0175 USD = 17.5e15 wei + self.assertEqual(cost, 17_500_000_000_000_000) + + def test_claude_opus_4_6_cost(self): + cost = self._calc("claude-opus-4-6", 1000, 500) + self.assertEqual(cost, self._calc("claude-opus-4-5", 1000, 500)) + + # ── Google Gemini ──────────────────────────────────────────────────────── + + def test_gemini_2_5_flash_cost(self): + cost = self._calc("gemini-2.5-flash", 1000, 500) + expected = _expected_cost_opg("gemini-2.5-flash", 1000, 500) + self.assertEqual(cost, expected) + # 1000*0.0000003 + 500*0.0000025 = 0.0003 + 0.00125 = 0.00155 USD + self.assertEqual(cost, 1_550_000_000_000_000) + + def test_gemini_2_5_flash_lite_cost(self): + cost = self._calc("gemini-2.5-flash-lite", 1000, 500) + expected = _expected_cost_opg("gemini-2.5-flash-lite", 1000, 500) + self.assertEqual(cost, expected) + # 1000*0.0000001 + 500*0.0000004 = 0.0001 + 0.0002 = 0.0003 USD + self.assertEqual(cost, 300_000_000_000_000) + + def test_gemini_2_5_pro_cost(self): + cost = self._calc("gemini-2.5-pro", 1000, 500) + expected = _expected_cost_opg("gemini-2.5-pro", 1000, 500) + self.assertEqual(cost, expected) + + def test_gemini_3_flash_preview_cost(self): + cost = self._calc("gemini-3-flash-preview", 1000, 500) + expected = _expected_cost_opg("gemini-3-flash-preview", 1000, 500) + self.assertEqual(cost, expected) + + # ── xAI Grok ──────────────────────────────────────────────────────────── + + def test_grok_4_cost(self): + cost = self._calc("grok-4", 1000, 500) + expected = _expected_cost_opg("grok-4", 1000, 500) + self.assertEqual(cost, expected) + # Same pricing tier as claude-sonnet-4-5 + self.assertEqual(cost, 10_500_000_000_000_000) + + def test_grok_4_fast_cost(self): + cost = self._calc("grok-4-fast", 1000, 500) + expected = _expected_cost_opg("grok-4-fast", 1000, 500) + self.assertEqual(cost, expected) + # 1000*0.0000002 + 500*0.0000005 = 0.0002 + 0.00025 = 0.00045 USD + self.assertEqual(cost, 450_000_000_000_000) + + def test_grok_4_1_fast_cost(self): + cost = self._calc("grok-4-1-fast", 1000, 500) + self.assertEqual(cost, self._calc("grok-4-fast", 1000, 500)) + + def test_grok_3_mini_cost(self): + cost = self._calc("grok-3-mini", 1000, 500) + expected = _expected_cost_opg("grok-3-mini", 1000, 500) + self.assertEqual(cost, expected) + + def test_grok_3_cost(self): + cost = self._calc("grok-3", 1000, 500) + expected = _expected_cost_opg("grok-3", 1000, 500) + self.assertEqual(cost, expected) + + # ── Haiku is cheaper than Sonnet ──────────────────────────────────────── + + def test_haiku_cheaper_than_sonnet(self): + haiku = self._calc("claude-haiku-4-5", 1000, 1000) + sonnet = self._calc("claude-sonnet-4-5", 1000, 1000) + self.assertLess(haiku, sonnet) + + def test_gemini_flash_lite_cheaper_than_flash(self): + lite = self._calc("gemini-2.5-flash-lite", 1000, 1000) + flash = self._calc("gemini-2.5-flash", 1000, 1000) + self.assertLess(lite, flash) + + def test_grok_4_fast_cheaper_than_grok_4(self): + fast = self._calc("grok-4-fast", 1000, 1000) + full = self._calc("grok-4", 1000, 1000) + self.assertLess(fast, full) + + +class TestDynamicSessionCostCalculatorUSDC(unittest.TestCase): + """dynamic_session_cost_calculator with USDC (6 decimals).""" + + def _calc(self, model, input_tokens, output_tokens): + return dynamic_session_cost_calculator( + _ctx(model, input_tokens, output_tokens, _usdc_requirements()) + ) + + def test_gpt_4_1_usdc_cost(self): + cost = self._calc("gpt-4.1", 1000, 500) + expected = _expected_cost_usdc("gpt-4.1", 1000, 500) + self.assertEqual(cost, expected) + # 0.006 USD in USDC (6 decimals) = 6000 units + self.assertEqual(cost, 6000) + + def test_claude_sonnet_4_5_usdc_cost(self): + cost = self._calc("claude-sonnet-4-5", 1000, 500) + expected = _expected_cost_usdc("claude-sonnet-4-5", 1000, 500) + self.assertEqual(cost, expected) + # 0.0105 USD = 10500 units + self.assertEqual(cost, 10500) + + def test_gemini_flash_lite_usdc_cost(self): + cost = self._calc("gemini-2.5-flash-lite", 1000, 500) + expected = _expected_cost_usdc("gemini-2.5-flash-lite", 1000, 500) + self.assertEqual(cost, expected) + # 0.0003 USD = 300 units + self.assertEqual(cost, 300) + + +class TestDynamicSessionCostCalculatorEdgeCases(unittest.TestCase): + """Edge cases for dynamic_session_cost_calculator.""" + + def test_zero_tokens_returns_zero(self): + cost = dynamic_session_cost_calculator(_ctx("claude-sonnet-4-5", 0, 0)) + self.assertEqual(cost, 0) + + def test_missing_usage_raises(self): + ctx = { + "request_json": {"model": "claude-sonnet-4-5"}, + "response_json": {"model": "claude-sonnet-4-5"}, # no usage + "payment_requirements": _opg_requirements(), + } + with self.assertRaises(ValueError): + dynamic_session_cost_calculator(ctx) + + def test_unknown_asset_raises(self): + ctx = _ctx("claude-sonnet-4-5", 100, 100) + ctx["payment_requirements"] = {"asset": "0xdeadbeef", "amount": "1000"} + with self.assertRaises(ValueError): + dynamic_session_cost_calculator(ctx) + + def test_missing_asset_raises(self): + ctx = _ctx("claude-sonnet-4-5", 100, 100) + ctx["payment_requirements"] = {"amount": "1000"} # no asset + with self.assertRaises(ValueError): + dynamic_session_cost_calculator(ctx) + + def test_unknown_model_raises_value_error(self): + ctx = _ctx("gpt-4o", 100, 100) + with self.assertRaises(ValueError): + dynamic_session_cost_calculator(ctx) + + def test_missing_request_json_raises_value_error(self): + ctx = { + "request_json": None, + "response_json": { + "model": "claude-sonnet-4-5", + "usage": {"prompt_tokens": 100, "completion_tokens": 100}, + }, + "payment_requirements": _opg_requirements(), + } + with self.assertRaises(ValueError): + dynamic_session_cost_calculator(ctx) + + def test_model_from_request_takes_priority(self): + """request_json model name is used even if response_json has a different model.""" + ctx = { + "request_json": {"model": "claude-haiku-4-5"}, + "response_json": { + "model": "claude-sonnet-4-5", # response says Sonnet + "usage": {"prompt_tokens": 1000, "completion_tokens": 500}, + }, + "payment_requirements": _opg_requirements(), + } + cost = dynamic_session_cost_calculator(ctx) + # Should be priced as Haiku (from request), not Sonnet + haiku_cost = _expected_cost_opg("claude-haiku-4-5", 1000, 500) + self.assertEqual(cost, haiku_cost) + + def test_rounding_ceiling(self): + """Fractional token costs are always rounded UP.""" + # 1 output token of Haiku: 0.000005 USD = 5e12 wei — exact, no rounding needed + cost = dynamic_session_cost_calculator(_ctx("claude-haiku-4-5", 0, 1)) + self.assertEqual(cost, 5_000_000_000_000) + + # 1 input token of Gemini Flash Lite: 0.0000001 USD = 1e11 wei — exact + cost = dynamic_session_cost_calculator(_ctx("gemini-2.5-flash-lite", 1, 0)) + self.assertEqual(cost, 100_000_000_000) + + def test_model_name_case_insensitive(self): + """Model names are normalized to lowercase before lookup.""" + cost_lower = dynamic_session_cost_calculator( + _ctx("claude-sonnet-4-5", 100, 100) + ) + cost_upper = dynamic_session_cost_calculator( + _ctx("CLAUDE-SONNET-4-5", 100, 100) + ) + self.assertEqual(cost_lower, cost_upper) + + def test_sonnet_4_0_hyphen_vs_dot_same_cost(self): + """claude-sonnet-4-0 and claude-4.0-sonnet are the same model.""" + cost_hyphen = dynamic_session_cost_calculator( + _ctx("claude-sonnet-4-0", 1000, 500) + ) + cost_dot = dynamic_session_cost_calculator(_ctx("claude-4.0-sonnet", 1000, 500)) + self.assertEqual(cost_hyphen, cost_dot) + + +if __name__ == "__main__": + unittest.main()