Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 75 additions & 110 deletions src/opengradient/client/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@

import json
import logging
import ssl
from dataclasses import dataclass
from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union
import httpx
import asyncio

from eth_account import Account
from eth_account.account import LocalAccount
from x402 import x402Client
from x402.http.clients import x402HttpxClient
from x402.mechanisms.evm import EthAccountSigner
from x402.mechanisms.evm.exact.register import register_exact_evm_client
from x402.mechanisms.evm.upto.register import register_upto_evm_client

from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode
from .opg_token import Permit2ApprovalResult, ensure_opg_approval
from .tee_registry import TEERegistry, build_ssl_context_from_der
from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface
from .tee_registry import TEERegistry

logger = logging.getLogger(__name__)
T = TypeVar("T")
Expand All @@ -34,7 +34,7 @@
_REQUEST_TIMEOUT = 60


@dataclass
@dataclass(frozen=True)
class _ChatParams:
"""Bundles the common parameters for chat/completion requests."""

Expand Down Expand Up @@ -63,112 +63,73 @@ class LLM:
below the requested amount.

Usage:
# Via on-chain registry (default)
llm = og.LLM(private_key="0x...")

# Via hardcoded URL (development / self-hosted)
llm = og.LLM.from_url(private_key="0x...", llm_server_url="https://1.2.3.4")

# One-time approval (idempotent — skips if allowance is already sufficient)
llm.ensure_opg_approval(opg_amount=5)

result = await llm.chat(model=TEE_LLM.CLAUDE_HAIKU_4_5, messages=[...])
result = await llm.completion(model=TEE_LLM.CLAUDE_HAIKU_4_5, prompt="Hello")

Args:
private_key (str): Ethereum private key for signing x402 payments.
rpc_url (str): RPC URL for the OpenGradient network. Used to fetch the
active TEE endpoint from the on-chain registry when ``llm_server_url``
is not provided.
tee_registry_address (str): Address of the on-chain TEE registry contract.
llm_server_url (str, optional): Bypass the registry and connect directly
to this TEE endpoint URL (e.g. ``"https://1.2.3.4"``). When set,
TLS certificate verification is disabled automatically because
self-hosted TEE servers typically use self-signed certificates.

.. warning::
Using ``llm_server_url`` disables TLS certificate verification,
which removes protection against man-in-the-middle attacks.
Only connect to servers you trust and over secure network paths.
"""

def __init__(
self,
private_key: str,
rpc_url: str = DEFAULT_RPC_URL,
tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS,
llm_server_url: Optional[str] = None,
):
if not private_key:
raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.")
self._wallet_account: LocalAccount = Account.from_key(private_key)
self._rpc_url = rpc_url
self._tee_registry_address = tee_registry_address
self._llm_server_url = llm_server_url

# x402 payment stack (created once, reused across TEE refreshes)
signer = EthAccountSigner(self._wallet_account)
self._x402_client = x402Client()
register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])

self._connect_tee()

# ── TEE resolution and connection ───────────────────────────────────────────

def _connect_tee(self) -> None:
"""Resolve TEE from registry and create a secure HTTP client for it."""
endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee(
self._llm_server_url,
self._rpc_url,
self._tee_registry_address,
)
self._tee_id = tee_id
self._tee_endpoint = endpoint
self._tee_payment_address = tee_payment_address

ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None
self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None)
self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify)

async def _refresh_tee(self) -> None:
"""Re-resolve TEE from the registry and rebuild the HTTP client."""
old_http_client = self._http_client
self._connect_tee()
try:
await old_http_client.aclose()
except Exception:
logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True)


@staticmethod
def _resolve_tee(
tee_endpoint_override: Optional[str],
og_rpc_url: Optional[str],
tee_registry_address: Optional[str],
) -> tuple:
"""Resolve TEE endpoint and metadata from the on-chain registry or explicit URL.
x402_client = LLM._build_x402_client(private_key)
onchain_registry = TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address)
self._tee: TEEConnectionInterface = RegistryTEEConnection(x402_client=x402_client, registry=onchain_registry)

Returns:
(endpoint, tls_cert_der, tee_id, payment_address)
"""
if tee_endpoint_override is not None:
return tee_endpoint_override, None, None, None

if og_rpc_url is None or tee_registry_address is None:
raise ValueError("Either llm_server_url or both rpc_url and tee_registry_address must be provided.")
@classmethod
def from_url(
cls,
private_key: str,
llm_server_url: str,
) -> "LLM":
"""**[Dev]** Create an LLM client with a hardcoded TEE endpoint URL.

try:
registry = TEERegistry(rpc_url=og_rpc_url, registry_address=tee_registry_address)
tee = registry.get_llm_tee()
except Exception as e:
raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address} on {og_rpc_url}): {e}. ") from e
Intended for development and self-hosted TEE servers. TLS certificate
verification is disabled because these servers typically use self-signed
certificates. For production use, prefer the default constructor which
resolves TEEs from the on-chain registry.

if tee is None:
raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.")
Args:
private_key: Ethereum private key for signing x402 payments.
llm_server_url: The TEE endpoint URL (e.g. ``"https://1.2.3.4"``).
"""
instance = cls.__new__(cls)
if not private_key:
raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.")
instance._wallet_account = Account.from_key(private_key)
x402_client = cls._build_x402_client(private_key)
instance._tee = StaticTEEConnection(x402_client=x402_client, endpoint=llm_server_url)
return instance

logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id)
return tee.endpoint, tee.tls_cert_der, tee.tee_id, tee.payment_address
@staticmethod
def _build_x402_client(private_key: str) -> x402Client:
    """Construct an x402 payment client for the given wallet key.

    Derives an EVM signer from *private_key* and registers both the
    "exact" and "upto" EVM payment mechanisms on the Base testnet
    network, so the returned client can settle either settlement mode.

    Args:
        private_key: Ethereum private key used to sign x402 payments.

    Returns:
        A fully wired ``x402Client`` ready for use by the HTTP layer.
    """
    evm_signer = EthAccountSigner(Account.from_key(private_key))
    payment_client = x402Client()
    # Register both supported settlement mechanisms against the same network.
    for register in (register_exact_evm_client, register_upto_evm_client):
        register(payment_client, evm_signer, networks=[BASE_TESTNET_NETWORK])
    return payment_client

# ── Lifecycle ───────────────────────────────────────────────────────

async def close(self) -> None:
"""Close the underlying HTTP client."""
await self._http_client.aclose()
"""Cancel the background refresh loop and close the HTTP client."""
await self._tee.close()

# ── Request helpers ─────────────────────────────────────────────────

Expand All @@ -195,13 +156,6 @@ def _chat_payload(self, params: _ChatParams, messages: List[Dict], stream: bool
payload["tool_choice"] = params.tool_choice or "auto"
return payload

def _tee_metadata(self) -> Dict:
return dict(
tee_id=self._tee_id,
tee_endpoint=self._tee_endpoint,
tee_payment_address=self._tee_payment_address,
)

async def _call_with_tee_retry(
self,
operation_name: str,
Expand All @@ -212,17 +166,20 @@ async def _call_with_tee_retry(
Only retries when the request never reached the server (no HTTP response).
Server-side errors (4xx/5xx) are not retried.
"""
self._tee.ensure_refresh_loop()
try:
return await call()
except httpx.HTTPStatusError:
raise
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning(
"Connection failure during %s; refreshing TEE and retrying once: %s",
operation_name,
exc,
)
await self._refresh_tee()
await self._tee.reconnect()
return await call()

# ── Public API ──────────────────────────────────────────────────────
Expand Down Expand Up @@ -295,8 +252,9 @@ async def completion(
payload["stop"] = stop_sequence

async def _request() -> TextGenerationOutput:
response = await self._http_client.post(
self._tee_endpoint + _COMPLETION_ENDPOINT,
tee = self._tee.get()
response = await tee.http_client.post(
tee.endpoint + _COMPLETION_ENDPOINT,
json=payload,
headers=self._headers(x402_settlement_mode),
timeout=_REQUEST_TIMEOUT,
Expand All @@ -309,7 +267,7 @@ async def _request() -> TextGenerationOutput:
completion_output=result.get("completion"),
tee_signature=result.get("tee_signature"),
tee_timestamp=result.get("tee_timestamp"),
**self._tee_metadata(),
**tee.metadata(),
)

try:
Expand Down Expand Up @@ -384,8 +342,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text
payload = self._chat_payload(params, messages)

async def _request() -> TextGenerationOutput:
response = await self._http_client.post(
self._tee_endpoint + _CHAT_ENDPOINT,
tee = self._tee.get()
response = await tee.http_client.post(
tee.endpoint + _CHAT_ENDPOINT,
json=payload,
headers=self._headers(params.x402_settlement_mode),
timeout=_REQUEST_TIMEOUT,
Expand All @@ -411,7 +370,7 @@ async def _request() -> TextGenerationOutput:
chat_output=message,
tee_signature=result.get("tee_signature"),
tee_timestamp=result.get("tee_timestamp"),
**self._tee_metadata(),
**tee.metadata(),
)

try:
Expand Down Expand Up @@ -448,24 +407,28 @@ async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict])

async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]:
"""Async SSE streaming implementation."""
self._tee.ensure_refresh_loop()
headers = self._headers(params.x402_settlement_mode)
payload = self._chat_payload(params, messages, stream=True)

chunks_yielded = False
try:
async with self._http_client.stream(
tee = self._tee.get()
async with tee.http_client.stream(
"POST",
self._tee_endpoint + _CHAT_ENDPOINT,
tee.endpoint + _CHAT_ENDPOINT,
json=payload,
headers=headers,
timeout=_REQUEST_TIMEOUT,
) as response:
async for chunk in self._parse_sse_response(response):
async for chunk in self._parse_sse_response(response, tee):
chunks_yielded = True
yield chunk
return
except httpx.HTTPStatusError:
raise
except asyncio.CancelledError:
raise
except Exception as exc:
if chunks_yielded:
raise
Expand All @@ -476,19 +439,21 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async

# Only reached if the first attempt failed before yielding any chunks.
# Re-resolve the TEE endpoint from the registry and retry once.
await self._refresh_tee()
await self._tee.reconnect()
tee = self._tee.get()

headers = self._headers(params.x402_settlement_mode)
async with self._http_client.stream(
async with tee.http_client.stream(
"POST",
self._tee_endpoint + _CHAT_ENDPOINT,
tee.endpoint + _CHAT_ENDPOINT,
json=payload,
headers=headers,
timeout=_REQUEST_TIMEOUT,
) as response:
async for chunk in self._parse_sse_response(response):
async for chunk in self._parse_sse_response(response, tee):
yield chunk

async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, None]:
async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk, None]:
"""Parse an SSE response stream into StreamChunk objects."""
status_code = getattr(response, "status_code", None)
if status_code is not None and status_code >= 400:
Expand Down Expand Up @@ -526,7 +491,7 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non

chunk = StreamChunk.from_sse_data(data)
if chunk.is_final:
chunk.tee_id = self._tee_id
chunk.tee_endpoint = self._tee_endpoint
chunk.tee_payment_address = self._tee_payment_address
chunk.tee_id = tee.tee_id
chunk.tee_endpoint = tee.endpoint
chunk.tee_payment_address = tee.payment_address
yield chunk
Loading
Loading