Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
- name: Run unit tests
run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py -v --import-mode=importlib
run: uv run --group test pytest tee_gateway/test/test_tool_forwarding.py tee_gateway/test/test_tee_core.py tests/test_pricing.py -v --import-mode=importlib
3 changes: 0 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ $(image_eif): $(image_tar)

.PHONY: run
run: $(image_eif)
-rm -f /tmp/network.sock
-kill $$(lsof -ti :2222 2>/dev/null)
-rm -f nohup.out
nitro-cli terminate-enclave --all
./scripts/run-enclave.sh $(image_eif)

Expand Down
4 changes: 2 additions & 2 deletions measurements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"Measurements": {
"HashAlgorithm": "Sha384 { ... }",
"PCR0": "6640e0ede60135375ad19eb5e41a045ce49cf84bf6919cc1f7e61556c2566fa855fd4e64247b1bde07f216761a101b83",
"PCR0": "f2945c350e80c9431649dd691cb4ecebc11174c0201e22c151b0b11c0aa41074174b57cc17343444380f212baaf5752c",
"PCR1": "4b4d5b3661b3efc12920900c80e126e4ce783c522de6c02a2a5bf7af3a2b9327b86776f188e4be1c1c404a129dbda493",
"PCR2": "45b95f6151f44a0d2c0bdb7d64c1861f49d8ef59e726b86ded25155481654c2bf0a58ee65ee64c364528994a9b260c5d"
"PCR2": "01d8b4615e5dc0417d57713b6d8d2c3b0ab3ec3e1848658c116a2986da35b47c04001c8acbea57858446d0bf0a35de68"
}
}
75 changes: 74 additions & 1 deletion tee_gateway/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from x402.schemas import AssetAmount
from x402.server import x402ResourceServerSync
from x402.session import SessionStore
import types as _types
import x402.http.middleware.flask as x402_flask

from .util import dynamic_session_cost_calculator
Expand Down Expand Up @@ -377,7 +378,7 @@ def _patched_read_body_bytes(environ):

x402_flask._read_body_bytes = _patched_read_body_bytes

payment_middleware(
_payment_mw = payment_middleware(
application,
routes=routes,
server=server,
Expand All @@ -386,6 +387,78 @@ def _patched_read_body_bytes(environ):
session_idle_timeout=100,
session_cost_calculator=dynamic_session_cost_calculator,
)


def _strict_resolve_session_request_cost(
    self,
    *,
    method: str,
    path: str,
    request_body_bytes: bytes,
    response_body_bytes: bytes,
    payment_payload: object,
    payment_requirements: object,
    status_code: int | None,
    output_object: object = None,
    is_streaming: bool = False,
) -> int:
    """Resolve the cost of one session request without swallowing errors.

    Intended as a drop-in replacement for
    ``PaymentMiddleware._resolve_session_request_cost``. Per the original
    author's note it mirrors the upstream implementation, except that
    exceptions raised by the dynamic cost calculator are NOT caught here:
    when the cost cannot be determined the error propagates (so the request
    fails) instead of the static default cost being charged silently.

    Returns:
        The static default cost when the response should not be charged or
        no callable calculator is configured; otherwise the calculator's
        result passed through ``self._coerce_non_negative_int``.

    Raises:
        ValueError: if the calculator returns ``None``.
        Exception: anything the calculator itself raises, unchanged.
    """
    from x402.http.middleware.flask import _parse_json_bytes as _x402_parse_json  # noqa: PLC0415

    def _container_or_none(value: object) -> object:
        # Only JSON containers (dict / list) are exposed under the *_json keys.
        return value if isinstance(value, (dict, list)) else None

    default_cost = self._get_session_cost(payment_requirements)

    # Both guards fall back to the static cost: the response is not billable,
    # or there is no calculator to consult.
    if not (
        self._should_charge_response(status_code)
        and callable(self._session_cost_calculator)
    ):
        return default_cost

    parsed_request = _x402_parse_json(request_body_bytes)
    if output_object is None:
        parsed_response = _x402_parse_json(response_body_bytes)
    else:
        # A caller-supplied output object takes precedence over the raw bytes.
        parsed_response = output_object

    callback_context = {
        "method": method,
        "path": path,
        "status_code": status_code,
        "is_streaming": is_streaming,
        "request_body_bytes": request_body_bytes,
        "response_body_bytes": response_body_bytes,
        "request_json": _container_or_none(parsed_request),
        "response_json": _container_or_none(parsed_response),
        "response_object": parsed_response,
        "payment_payload": payment_payload,
        "payment_requirements": payment_requirements,
        "default_cost": default_cost,
    }

    # Deliberately no try/except: calculator failures must surface to the
    # caller rather than degrade into charging the static fallback amount.
    dynamic_cost = self._session_cost_calculator(callback_context)
    if dynamic_cost is not None:
        return self._coerce_non_negative_int(dynamic_cost)

    raise ValueError(
        f"dynamic_session_cost_calculator returned None for {method} {path}; "
        "cannot determine request cost"
    )


# Bind the strict resolver to this one middleware instance (MethodType
# supplies `self`), so cost-calculator errors propagate instead of being
# replaced by the static default cost.
_payment_mw._resolve_session_request_cost = _types.MethodType( # type: ignore[method-assign]
_strict_resolve_session_request_cost, _payment_mw
)

logger.info("x402 payment middleware initialized")

if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tee_gateway/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@
# /v1/chat/completions — 0.01 OUSDC precheck (6 decimals: 10_000 = $0.01)
CHAT_COMPLETIONS_USDC_AMOUNT: str = "10000"

# /v1/chat/completions — 0.05 OPG precheck (18 decimals)
CHAT_COMPLETIONS_OPG_AMOUNT: str = "50000000000000000"
# /v1/chat/completions — 0.1 OPG precheck (18 decimals)
CHAT_COMPLETIONS_OPG_AMOUNT: str = "100000000000000000"

# /v1/completions — 0.01 USDC precheck (6 decimals: 10_000 = $0.01)
COMPLETIONS_USDC_AMOUNT: str = "10000"
51 changes: 13 additions & 38 deletions tee_gateway/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,6 @@ class SupportedModel(Enum):
input_price_usd=Decimal("0.000005"),
output_price_usd=Decimal("0.000025"),
)
CLAUDE_3_7_SONNET = ModelConfig(
provider="anthropic",
api_name="claude-3-7-sonnet-latest",
input_price_usd=Decimal("0.000003"),
output_price_usd=Decimal("0.000015"),
)
CLAUDE_3_5_HAIKU = ModelConfig(
provider="anthropic",
api_name="claude-3-5-haiku-latest",
input_price_usd=Decimal("0.000001"),
output_price_usd=Decimal("0.000005"),
)
CLAUDE_4_0_SONNET = ModelConfig(
provider="anthropic",
api_name="claude-sonnet-4-0",
input_price_usd=Decimal("0.000003"),
output_price_usd=Decimal("0.000015"),
)

# ── Google Gemini ───────────────────────────────────────────────────
GEMINI_2_5_FLASH = ModelConfig(
Expand All @@ -128,12 +110,6 @@ class SupportedModel(Enum):
output_price_usd=Decimal("0.0000004"),
thinking_budget=0,
)
GEMINI_3_PRO_PREVIEW = ModelConfig(
provider="google",
api_name="gemini-3-pro-preview",
input_price_usd=Decimal("0.000002"),
output_price_usd=Decimal("0.000012"),
)
GEMINI_3_FLASH_PREVIEW = ModelConfig(
provider="google",
api_name="gemini-3-flash-preview",
Expand Down Expand Up @@ -166,6 +142,14 @@ class SupportedModel(Enum):
input_price_usd=Decimal("0.0000002"),
output_price_usd=Decimal("0.0000005"),
)

# ── Legacy models (not in current SDK — retained for older SDK versions) ──
CLAUDE_4_0_SONNET = ModelConfig(
provider="anthropic",
api_name="claude-sonnet-4-0",
input_price_usd=Decimal("0.000003"),
output_price_usd=Decimal("0.000015"),
)
GROK_3_MINI = ModelConfig(
provider="x-ai",
api_name="grok-3-mini",
Expand All @@ -178,12 +162,6 @@ class SupportedModel(Enum):
input_price_usd=Decimal("0.000003"),
output_price_usd=Decimal("0.000015"),
)
GROK_2 = ModelConfig(
provider="x-ai",
api_name="grok-2-latest",
input_price_usd=Decimal("0.000002"),
output_price_usd=Decimal("0.00001"),
)


# Canonical lookup: user-facing model name → SupportedModel
Expand All @@ -202,27 +180,24 @@ class SupportedModel(Enum):
"claude-haiku-4-5": SupportedModel.CLAUDE_HAIKU_4_5,
"claude-opus-4-5": SupportedModel.CLAUDE_OPUS_4_5,
"claude-opus-4-6": SupportedModel.CLAUDE_OPUS_4_6,
"claude-3.7-sonnet": SupportedModel.CLAUDE_3_7_SONNET,
"claude-3.5-haiku": SupportedModel.CLAUDE_3_5_HAIKU,
"claude-4.0-sonnet": SupportedModel.CLAUDE_4_0_SONNET,
# Google
"gemini-2.5-flash": SupportedModel.GEMINI_2_5_FLASH,
"gemini-2.5-pro": SupportedModel.GEMINI_2_5_PRO,
"gemini-2.5-flash-lite": SupportedModel.GEMINI_2_5_FLASH_LITE,
"gemini-3-pro-preview": SupportedModel.GEMINI_3_PRO_PREVIEW,
"gemini-3-flash-preview": SupportedModel.GEMINI_3_FLASH_PREVIEW,
# xAI
"grok-4": SupportedModel.GROK_4,
"grok-4-fast": SupportedModel.GROK_4_FAST,
"grok-4-1-fast": SupportedModel.GROK_4_1_FAST,
"grok-4.1-fast": SupportedModel.GROK_4_1_FAST,
"grok-4-1-fast-non-reasoning": SupportedModel.GROK_4_1_FAST_NON_REASONING,
"grok-3-mini-beta": SupportedModel.GROK_3_MINI,
# Legacy — not in current SDK, retained for older SDK versions
"claude-sonnet-4-0": SupportedModel.CLAUDE_4_0_SONNET,
"claude-4.0-sonnet": SupportedModel.CLAUDE_4_0_SONNET, # alternate dot notation
"grok-3-mini-beta": SupportedModel.GROK_3_MINI, # old beta alias
"grok-3-mini": SupportedModel.GROK_3_MINI,
"grok-3-beta": SupportedModel.GROK_3,
"grok-3-beta": SupportedModel.GROK_3, # old beta alias
"grok-3": SupportedModel.GROK_3,
"grok-2-1212": SupportedModel.GROK_2,
"grok-2": SupportedModel.GROK_2,
}

# Build the rate card automatically from the enum (for backward compat with util.py)
Expand Down
66 changes: 43 additions & 23 deletions tee_gateway/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def _deserialize_dict(data, boxed_type):

from tee_gateway.definitions import ( # noqa: E402
ASSET_DECIMALS_BY_ADDRESS,
DEFAULT_ASSET_DECIMALS,
)
from tee_gateway.model_registry import get_model_config # noqa: E402

Expand Down Expand Up @@ -238,31 +237,49 @@ def _normalize_model_name(model: str | None) -> str | None:

def _extract_usage_tokens(
response_json: dict[str, Any] | None,
) -> tuple[int, int] | None:
) -> tuple[int, int]:
"""Extract (input_tokens, output_tokens) from response JSON.

Raises ValueError if usage data is missing or malformed — no silent fallback.
"""
if not isinstance(response_json, dict):
return None
raise ValueError("response_json is not a dict; cannot extract usage tokens")
usage = response_json.get("usage")
if not isinstance(usage, dict):
return None
raise ValueError(
"response_json has no 'usage' dict; cannot extract usage tokens"
)

prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens"))
completion_tokens = usage.get("completion_tokens", usage.get("output_tokens"))
if prompt_tokens is None or completion_tokens is None:
return None
raise ValueError(f"usage dict is missing token counts: {usage!r}")

try:
return max(0, int(prompt_tokens)), max(0, int(completion_tokens))
except (TypeError, ValueError):
return None
except (TypeError, ValueError) as exc:
raise ValueError(f"Could not parse token counts from usage: {usage!r}") from exc


def _extract_model_from_context(
request_json: dict[str, Any] | None,
response_json: dict[str, Any] | None,
) -> str | None:
req_model = request_json.get("model") if isinstance(request_json, dict) else None
resp_model = response_json.get("model") if isinstance(response_json, dict) else None
return _normalize_model_name(req_model or resp_model)
) -> str:
"""Extract and normalize model name from request JSON.

Uses only the request model name — the response model field is ignored
because providers may return a versioned alias that differs from the
user-facing name. Raises ValueError if the model name is absent.
"""
if not isinstance(request_json, dict):
raise ValueError("request_json is not a dict; cannot extract model name")
req_model = request_json.get("model")
if not req_model:
raise ValueError("request_json has no 'model' field")
normalized = _normalize_model_name(req_model)
if not normalized:
raise ValueError(f"model name normalizes to empty string: {req_model!r}")
return normalized


def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int:
Expand All @@ -272,15 +289,25 @@ def _extract_asset_decimals_from_requirements(payment_requirements: Any) -> int:
if not asset and isinstance(req.get("price"), dict):
asset = req["price"].get("asset")

if isinstance(asset, str):
return ASSET_DECIMALS_BY_ADDRESS.get(asset.lower(), DEFAULT_ASSET_DECIMALS)
return DEFAULT_ASSET_DECIMALS
if not isinstance(asset, str) or not asset:
raise ValueError(
f"payment_requirements has no recognizable asset address; "
f"cannot determine token decimals: {req!r}"
)

asset_lower = asset.lower()
if asset_lower not in ASSET_DECIMALS_BY_ADDRESS:
raise ValueError(
f"Unknown asset address {asset!r}; not in ASSET_DECIMALS_BY_ADDRESS. "
f"Add it to definitions.py before accepting payments with this token."
)
return ASSET_DECIMALS_BY_ADDRESS[asset_lower]


def dynamic_session_cost_calculator(context: dict[str, Any]) -> int:
"""Compute UPTO per-request cost in token smallest units from actual usage.

Raises ValueError if the model is not in the registry (no silent fallback).
Raises ValueError on any missing or unrecognised input — no silent fallback.
"""
request_json = context.get("request_json")
response_json = context.get("response_json")
Expand All @@ -291,18 +318,11 @@ def dynamic_session_cost_calculator(context: dict[str, Any]) -> int:
)

model = _extract_model_from_context(request_json, response_json)
if not model:
raise ValueError("Could not extract model name from request/response")

# get_model_config raises ValueError for unknown models — no fallback
cfg = get_model_config(model)

usage_tokens = _extract_usage_tokens(response_json)
if not usage_tokens:
logger.warning("No usage tokens in response for model=%s; charging zero", model)
return 0

input_tokens, output_tokens = usage_tokens
input_tokens, output_tokens = _extract_usage_tokens(response_json)

input_rate = cfg.input_price_usd
output_rate = cfg.output_price_usd
Expand Down
Loading
Loading