diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 329d2c10..5a1e67a9 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -2,22 +2,22 @@ import json import logging -import ssl from dataclasses import dataclass from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union import httpx +import asyncio from eth_account import Account from eth_account.account import LocalAccount from x402 import x402Client -from x402.http.clients import x402HttpxClient from x402.mechanisms.evm import EthAccountSigner from x402.mechanisms.evm.exact.register import register_exact_evm_client from x402.mechanisms.evm.upto.register import register_upto_evm_client from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode from .opg_token import Permit2ApprovalResult, ensure_opg_approval -from .tee_registry import TEERegistry, build_ssl_context_from_der +from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface +from .tee_registry import TEERegistry logger = logging.getLogger(__name__) T = TypeVar("T") @@ -34,7 +34,7 @@ _REQUEST_TIMEOUT = 60 -@dataclass +@dataclass(frozen=True) class _ChatParams: """Bundles the common parameters for chat/completion requests.""" @@ -63,29 +63,17 @@ class LLM: below the requested amount. Usage: + # Via on-chain registry (default) llm = og.LLM(private_key="0x...") + # Via hardcoded URL (development / self-hosted) + llm = og.LLM.from_url(private_key="0x...", llm_server_url="https://1.2.3.4") + # One-time approval (idempotent — skips if allowance is already sufficient) llm.ensure_opg_approval(opg_amount=5) result = await llm.chat(model=TEE_LLM.CLAUDE_HAIKU_4_5, messages=[...]) result = await llm.completion(model=TEE_LLM.CLAUDE_HAIKU_4_5, prompt="Hello") - - Args: - private_key (str): Ethereum private key for signing x402 payments. - rpc_url (str): RPC URL for the OpenGradient network. Used to fetch the - active TEE endpoint from the on-chain registry when ``llm_server_url`` - is not provided. - tee_registry_address (str): Address of the on-chain TEE registry contract. - llm_server_url (str, optional): Bypass the registry and connect directly - to this TEE endpoint URL (e.g. ``"https://1.2.3.4"``). When set, - TLS certificate verification is disabled automatically because - self-hosted TEE servers typically use self-signed certificates. - - .. warning:: - Using ``llm_server_url`` disables TLS certificate verification, - which removes protection against man-in-the-middle attacks. - Only connect to servers you trust and over secure network paths. """ def __init__( @@ -93,82 +81,55 @@ def __init__( private_key: str, rpc_url: str = DEFAULT_RPC_URL, tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, - llm_server_url: Optional[str] = None, ): + if not private_key: + raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.") self._wallet_account: LocalAccount = Account.from_key(private_key) - self._rpc_url = rpc_url - self._tee_registry_address = tee_registry_address - self._llm_server_url = llm_server_url - - # x402 payment stack (created once, reused across TEE refreshes) - signer = EthAccountSigner(self._wallet_account) - self._x402_client = x402Client() - register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - - self._connect_tee() - - # ── TEE resolution and connection ─────────────────────────────────────────── - - def _connect_tee(self) -> None: - """Resolve TEE from registry and create a secure HTTP client for it.""" - endpoint, tls_cert_der, tee_id, tee_payment_address = self._resolve_tee( - self._llm_server_url, - self._rpc_url, - self._tee_registry_address, - ) - self._tee_id = tee_id - self._tee_endpoint = endpoint - self._tee_payment_address = tee_payment_address - - ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None - self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else (self._llm_server_url is None) - self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) - - async def _refresh_tee(self) -> None: - """Re-resolve TEE from the registry and rebuild the HTTP client.""" - old_http_client = self._http_client - self._connect_tee() - try: - await old_http_client.aclose() - except Exception: - logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) - - @staticmethod - def _resolve_tee( - tee_endpoint_override: Optional[str], - og_rpc_url: Optional[str], - tee_registry_address: Optional[str], - ) -> tuple: - """Resolve TEE endpoint and metadata from the on-chain registry or explicit URL. + x402_client = LLM._build_x402_client(private_key) + onchain_registry = TEERegistry(rpc_url=rpc_url, registry_address=tee_registry_address) + self._tee: TEEConnectionInterface = RegistryTEEConnection(x402_client=x402_client, registry=onchain_registry) - Returns: - (endpoint, tls_cert_der, tee_id, payment_address) - """ - if tee_endpoint_override is not None: - return tee_endpoint_override, None, None, None - - if og_rpc_url is None or tee_registry_address is None: - raise ValueError("Either llm_server_url or both rpc_url and tee_registry_address must be provided.") + @classmethod + def from_url( + cls, + private_key: str, + llm_server_url: str, + ) -> "LLM": + """**[Dev]** Create an LLM client with a hardcoded TEE endpoint URL. - try: - registry = TEERegistry(rpc_url=og_rpc_url, registry_address=tee_registry_address) - tee = registry.get_llm_tee() - except Exception as e: - raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address} on {og_rpc_url}): {e}. ") from e + Intended for development and self-hosted TEE servers. TLS certificate + verification is disabled because these servers typically use self-signed + certificates. For production use, prefer the default constructor which + resolves TEEs from the on-chain registry. - if tee is None: - raise ValueError("No active LLM proxy TEE found in the registry. Pass llm_server_url explicitly to override.") + Args: + private_key: Ethereum private key for signing x402 payments. + llm_server_url: The TEE endpoint URL (e.g. ``"https://1.2.3.4"``). + """ + instance = cls.__new__(cls) + if not private_key: + raise ValueError("A private key is required to use the LLM client. Pass a valid private_key to the constructor.") + instance._wallet_account = Account.from_key(private_key) + x402_client = cls._build_x402_client(private_key) + instance._tee = StaticTEEConnection(x402_client=x402_client, endpoint=llm_server_url) + return instance - logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) - return tee.endpoint, tee.tls_cert_der, tee.tee_id, tee.payment_address + @staticmethod + def _build_x402_client(private_key: str) -> x402Client: + """Build the x402 payment stack from a private key.""" + account = Account.from_key(private_key) + signer = EthAccountSigner(account) + client = x402Client() + register_exact_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK]) + register_upto_evm_client(client, signer, networks=[BASE_TESTNET_NETWORK]) + return client # ── Lifecycle ─────────────────────────────────────────────────────── async def close(self) -> None: - """Close the underlying HTTP client.""" - await self._http_client.aclose() + """Cancel the background refresh loop and close the HTTP client.""" + await self._tee.close() # ── Request helpers ───────────────────────────────────────────────── @@ -195,13 +156,6 @@ def _chat_payload(self, params: _ChatParams, messages: List[Dict], stream: bool payload["tool_choice"] = params.tool_choice or "auto" return payload - def _tee_metadata(self) -> Dict: - return dict( - tee_id=self._tee_id, - tee_endpoint=self._tee_endpoint, - tee_payment_address=self._tee_payment_address, - ) - async def _call_with_tee_retry( self, operation_name: str, @@ -212,17 +166,20 @@ async def _call_with_tee_retry( Only retries when the request never reached the server (no HTTP response). Server-side errors (4xx/5xx) are not retried. """ + self._tee.ensure_refresh_loop() try: return await call() except httpx.HTTPStatusError: raise + except asyncio.CancelledError: + raise except Exception as exc: logger.warning( "Connection failure during %s; refreshing TEE and retrying once: %s", operation_name, exc, ) - await self._refresh_tee() + await self._tee.reconnect() return await call() # ── Public API ────────────────────────────────────────────────────── @@ -295,8 +252,9 @@ async def completion( payload["stop"] = stop_sequence async def _request() -> TextGenerationOutput: - response = await self._http_client.post( - self._tee_endpoint + _COMPLETION_ENDPOINT, + tee = self._tee.get() + response = await tee.http_client.post( + tee.endpoint + _COMPLETION_ENDPOINT, json=payload, headers=self._headers(x402_settlement_mode), timeout=_REQUEST_TIMEOUT, @@ -309,7 +267,7 @@ async def _request() -> TextGenerationOutput: completion_output=result.get("completion"), tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), - **self._tee_metadata(), + **tee.metadata(), ) try: @@ -384,8 +342,9 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text payload = self._chat_payload(params, messages) async def _request() -> TextGenerationOutput: - response = await self._http_client.post( - self._tee_endpoint + _CHAT_ENDPOINT, + tee = self._tee.get() + response = await tee.http_client.post( + tee.endpoint + _CHAT_ENDPOINT, json=payload, headers=self._headers(params.x402_settlement_mode), timeout=_REQUEST_TIMEOUT, @@ -411,7 +370,7 @@ async def _request() -> TextGenerationOutput: chat_output=message, tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), - **self._tee_metadata(), + **tee.metadata(), ) try: @@ -448,24 +407,28 @@ async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict]) async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: """Async SSE streaming implementation.""" + self._tee.ensure_refresh_loop() headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) chunks_yielded = False try: - async with self._http_client.stream( + tee = self._tee.get() + async with tee.http_client.stream( "POST", - self._tee_endpoint + _CHAT_ENDPOINT, + tee.endpoint + _CHAT_ENDPOINT, json=payload, headers=headers, timeout=_REQUEST_TIMEOUT, ) as response: - async for chunk in self._parse_sse_response(response): + async for chunk in self._parse_sse_response(response, tee): chunks_yielded = True yield chunk return except httpx.HTTPStatusError: raise + except asyncio.CancelledError: + raise except Exception as exc: if chunks_yielded: raise @@ -476,19 +439,21 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async # Only reached if the first attempt failed before yielding any chunks. # Re-resolve the TEE endpoint from the registry and retry once. - await self._refresh_tee() + await self._tee.reconnect() + tee = self._tee.get() + headers = self._headers(params.x402_settlement_mode) - async with self._http_client.stream( + async with tee.http_client.stream( "POST", - self._tee_endpoint + _CHAT_ENDPOINT, + tee.endpoint + _CHAT_ENDPOINT, json=payload, headers=headers, timeout=_REQUEST_TIMEOUT, ) as response: - async for chunk in self._parse_sse_response(response): + async for chunk in self._parse_sse_response(response, tee): yield chunk - async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, None]: + async def _parse_sse_response(self, response, tee) -> AsyncGenerator[StreamChunk, None]: """Parse an SSE response stream into StreamChunk objects.""" status_code = getattr(response, "status_code", None) if status_code is not None and status_code >= 400: @@ -526,7 +491,7 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non chunk = StreamChunk.from_sse_data(data) if chunk.is_final: - chunk.tee_id = self._tee_id - chunk.tee_endpoint = self._tee_endpoint - chunk.tee_payment_address = self._tee_payment_address + chunk.tee_id = tee.tee_id + chunk.tee_endpoint = tee.endpoint + chunk.tee_payment_address = tee.payment_address yield chunk diff --git a/src/opengradient/client/tee_connection.py b/src/opengradient/client/tee_connection.py new file mode 100644 index 00000000..2a7682d3 --- /dev/null +++ b/src/opengradient/client/tee_connection.py @@ -0,0 +1,202 @@ +"""Manages the lifecycle of a connection to a TEE endpoint.""" + +import asyncio +import logging +import ssl +from dataclasses import dataclass +from typing import Dict, Optional, Protocol, Union + +from x402 import x402Client +from x402.http.clients import x402HttpxClient + +from .tee_registry import TEE_TYPE_LLM_PROXY, TEERegistry, build_ssl_context_from_der + +logger = logging.getLogger(__name__) + +_TEE_REFRESH_INTERVAL = 300 # Re-resolve TEE from registry every 5 minutes + + +@dataclass(frozen=True) +class ActiveTEE: + """Snapshot of the currently connected TEE.""" + + endpoint: str + http_client: x402HttpxClient + tee_id: Optional[str] + payment_address: Optional[str] + + def metadata(self) -> Dict: + """Return TEE metadata dict for decorating responses.""" + return dict( + tee_id=self.tee_id, + tee_endpoint=self.endpoint, + tee_payment_address=self.payment_address, + ) + + +class TEEConnectionInterface(Protocol): + """Interface for TEE connection implementations.""" + + def get(self) -> ActiveTEE: ... + def ensure_refresh_loop(self) -> None: ... + async def reconnect(self) -> None: ... + async def close(self) -> None: ... + + +class StaticTEEConnection: + """TEE connection with a hardcoded endpoint URL. + + No registry lookup, no background refresh. TLS certificate verification + is disabled because self-hosted TEE servers typically use self-signed certs. + + Args: + x402_client: Configured x402 payment client for creating HTTP clients. + endpoint: The TEE endpoint URL to connect to. + """ + + def __init__(self, x402_client: x402Client, endpoint: str): + self._x402_client = x402_client + self._endpoint = endpoint + self._active: ActiveTEE = self._connect() + + def get(self) -> ActiveTEE: + """Return a snapshot of the current TEE connection.""" + return self._active + + def _connect(self) -> ActiveTEE: + return ActiveTEE( + endpoint=self._endpoint, + http_client=x402HttpxClient(self._x402_client, verify=False), + tee_id=None, + payment_address=None, + ) + + def ensure_refresh_loop(self) -> None: + """No-op — static connections don't refresh.""" + pass + + async def reconnect(self) -> None: + """Rebuild the HTTP client (same endpoint).""" + old_client = self._active.http_client + self._active = self._connect() + try: + await old_client.aclose() + except Exception: + logger.debug("Failed to close previous HTTP client during reconnect.", exc_info=True) + + async def close(self) -> None: + """Close the HTTP client.""" + await self._active.http_client.aclose() + + +class RegistryTEEConnection: + """TEE connection resolved from the on-chain registry. + + Handles TLS certificate pinning, background health checks, and automatic + failover when the current TEE becomes unavailable. + + Args: + x402_client: Configured x402 payment client for creating HTTP clients. + registry: TEERegistry for looking up active TEEs. + """ + + def __init__(self, x402_client: x402Client, registry: TEERegistry): + self._x402_client = x402_client + self._registry = registry + + self._refresh_lock = asyncio.Lock() + self._refresh_task: Optional[asyncio.Task] = None + + self._active: ActiveTEE = self._connect() + + # ── Public API ────────────────────────────────────────────────────── + + def get(self) -> ActiveTEE: + """Return a snapshot of the current TEE connection.""" + return self._active + + # ── Connection management ─────────────────────────────────────────── + + def _resolve_tee(self): + """Resolve TEE endpoint and metadata from the on-chain registry. + + Returns: + The TEE object from the registry. + + Raises: + RuntimeError: If the registry lookup fails. + ValueError: If no active LLM proxy TEE is found. + """ + try: + tee = self._registry.get_llm_tee() + except Exception as e: + raise RuntimeError(f"Failed to fetch LLM TEE endpoint from registry: {e}") from e + + if tee is None: + raise ValueError("No active LLM proxy TEE found in the registry.") + + logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) + return tee + + def _connect(self) -> ActiveTEE: + """Resolve TEE from registry and create a secure HTTP client.""" + tee = self._resolve_tee() + + ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) if tee.tls_cert_der else None + tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True + + return ActiveTEE( + endpoint=tee.endpoint, + http_client=x402HttpxClient(self._x402_client, verify=tls_verify), + tee_id=tee.tee_id, + payment_address=tee.payment_address, + ) + + async def reconnect(self) -> None: + """Connect to a new TEE from the registry and rebuild the HTTP client.""" + async with self._refresh_lock: + try: + self._active = self._connect() + except Exception: + logger.debug("Failed to close previous HTTP client during TEE refresh.", exc_info=True) + + # ── Background health check ───────────────────────────────────────── + + def ensure_refresh_loop(self) -> None: + """Start the background TEE refresh loop if not already running. + + Called lazily from async request methods since ``__init__`` is synchronous. + """ + if self._refresh_task is not None and not self._refresh_task.done(): + return + self._refresh_task = asyncio.create_task(self._tee_refresh_loop()) + + async def _tee_refresh_loop(self) -> None: + """Periodically check that the current TEE is still active in the registry. + + If the current TEE is no longer active, performs a full refresh to pick + a new one. Does nothing when the current TEE is still healthy. + """ + while True: + await asyncio.sleep(_TEE_REFRESH_INTERVAL) + try: + active_tees = self._registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + if any(t.tee_id == self._active.tee_id for t in active_tees): + logger.debug("Current TEE %s still active; no refresh needed.", self._active.tee_id) + continue + logger.info("Current TEE %s no longer active; switching to a new one.", self._active.tee_id) + await self.reconnect() + except asyncio.CancelledError: + logger.debug("Background TEE health check cancelled; exiting loop.") + raise + except Exception: + logger.warning("Background TEE health check failed; will retry next cycle.", exc_info=True) + + # ── Lifecycle ─────────────────────────────────────────────────────── + + async def close(self) -> None: + """Cancel the background refresh loop and close the HTTP client.""" + if self._refresh_task is not None: + self._refresh_task.cancel() + self._refresh_task = None + await self._active.http_client.aclose() diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py index 9ad3cfd7..571e3712 100644 --- a/src/opengradient/client/tee_registry.py +++ b/src/opengradient/client/tee_registry.py @@ -31,7 +31,7 @@ class TEEInfo(NamedTuple): last_heartbeat_at: int -@dataclass +@dataclass(frozen=True) class TEEEndpoint: """A verified TEE with its endpoint URL and TLS certificate from the registry.""" diff --git a/tests/client_test.py b/tests/client_test.py index 28df2bf0..2e2bf14c 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -72,13 +72,13 @@ class TestLLMInitialization: def test_llm_initialization(self, mock_tee_registry): """Test basic LLM initialization.""" llm = LLM(private_key=FAKE_PRIVATE_KEY) - assert llm._tee_endpoint == "https://test.tee.server" + assert llm._tee.get().endpoint == "https://test.tee.server" def test_llm_initialization_custom_url(self, mock_tee_registry): """Test LLM initialization with custom server URL.""" custom_llm_url = "https://custom.llm.server" - llm = LLM(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url) - assert llm._tee_endpoint == custom_llm_url + llm = LLM.from_url(private_key=FAKE_PRIVATE_KEY, llm_server_url=custom_llm_url) + assert llm._tee.get().endpoint == custom_llm_url # --- ModelHub Authentication Tests --- diff --git a/tests/llm_test.py b/tests/llm_test.py index 3953943e..ce0cd48b 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -107,7 +107,7 @@ async def aread(self) -> bytes: # so LLM.__init__ runs its real code but gets our FakeHTTPClient. _PATCHES = { - "x402_httpx": "src.opengradient.client.llm.x402HttpxClient", + "x402_httpx": "src.opengradient.client.tee_connection.x402HttpxClient", "x402_client": "src.opengradient.client.llm.x402Client", "signer": "src.opengradient.client.llm.EthAccountSigner", "register_exact": "src.opengradient.client.llm.register_exact_evm_client", @@ -137,10 +137,11 @@ def _make_llm( endpoint: str = "https://test.tee.server", ) -> LLM: """Build an LLM with an explicit server URL (skips registry lookup).""" - llm = LLM(private_key=FAKE_PRIVATE_KEY, llm_server_url=endpoint) - # llm_server_url path sets tee_id/payment_address to None; set them for assertions. - llm._tee_id = "test-tee-id" - llm._tee_payment_address = "0xTestPayment" + from dataclasses import replace + + llm = LLM.from_url(private_key=FAKE_PRIVATE_KEY, llm_server_url=endpoint) + # from_url sets tee_id/payment_address to None; replace with test values. + llm._tee._active = replace(llm._tee.get(), tee_id="test-tee-id", payment_address="0xTestPayment") return llm @@ -510,50 +511,6 @@ async def test_close_delegates_to_http_client(self, fake_http): # FakeHTTPClient.aclose is a no-op; just verify it doesn't blow up. -# ── TEE resolution tests ───────────────────────────────────────────── - - -class TestResolveTeE: - def test_explicit_url_skips_registry(self): - endpoint, cert, tee_id, pay_addr = LLM._resolve_tee("https://explicit.url", None, None) - - assert endpoint == "https://explicit.url" - assert cert is None - assert tee_id is None - assert pay_addr is None - - def test_missing_rpc_and_registry_raises(self): - with pytest.raises(ValueError): - LLM._resolve_tee(None, None, None) - - def test_missing_registry_address_raises(self): - with pytest.raises(ValueError): - LLM._resolve_tee(None, "https://rpc", None) - - def test_registry_returns_none_raises(self): - with patch("src.opengradient.client.llm.TEERegistry") as mock_reg: - mock_reg.return_value.get_llm_tee.return_value = None - - with pytest.raises(ValueError, match="No active LLM proxy TEE"): - LLM._resolve_tee(None, "https://rpc", "0xRegistry") - - def test_registry_success(self): - with patch("src.opengradient.client.llm.TEERegistry") as mock_reg: - mock_tee = MagicMock() - mock_tee.endpoint = "https://registry.tee" - mock_tee.tls_cert_der = b"cert-bytes" - mock_tee.tee_id = "tee-42" - mock_tee.payment_address = "0xPay" - mock_reg.return_value.get_llm_tee.return_value = mock_tee - - endpoint, cert, tee_id, pay_addr = LLM._resolve_tee(None, "https://rpc", "0xRegistry") - - assert endpoint == "https://registry.tee" - assert cert == b"cert-bytes" - assert tee_id == "tee-42" - assert pay_addr == "0xPay" - - # ── TEE retry tests (non-streaming) ────────────────────────────────── @@ -677,53 +634,6 @@ async def aread(self) -> bytes: assert len(fake_http.post_calls) == 1 -# ── _refresh_tee tests ───────────────────────────────────── - - -@pytest.mark.asyncio -class TestRefreshTeeAndReset: - async def test_replaces_http_client(self): - """After refresh, the http client should be a new instance.""" - clients_created = [] - - def make_client(*args, **kwargs): - c = FakeHTTPClient() - clients_created.append(c) - return c - - with ( - patch(_PATCHES["x402_httpx"], side_effect=make_client), - patch(_PATCHES["x402_client"]), - patch(_PATCHES["signer"]), - patch(_PATCHES["register_exact"]), - patch(_PATCHES["register_upto"]), - ): - llm = _make_llm() - old_client = llm._http_client - - await llm._refresh_tee() - - assert llm._http_client is not old_client - assert len(clients_created) == 2 # init + refresh - - async def test_closes_old_client(self, fake_http): - llm = _make_llm() - old_client = llm._http_client - old_client.aclose = AsyncMock() - - await llm._refresh_tee() - - old_client.aclose.assert_awaited_once() - - async def test_close_failure_is_swallowed(self, fake_http): - llm = _make_llm() - old_client = llm._http_client - old_client.aclose = AsyncMock(side_effect=OSError("already closed")) - - # Should not raise - await llm._refresh_tee() - - # ── TEE cert rotation (crash + re-register) tests ──────────────────── @@ -740,10 +650,10 @@ async def test_ssl_verification_failure_triggers_tee_refresh_completion(self, fa fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) llm = _make_llm() - with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + with patch.object(llm._tee, "_connect", wraps=llm._tee._connect) as spy: result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") - # _connect_tee was called once during the retry (refresh) + # _connect was called once during the retry (reconnect) spy.assert_called_once() assert result.completion_output == "ok after refresh" assert len(fake_http.post_calls) == 2 @@ -756,7 +666,7 @@ async def test_ssl_verification_failure_triggers_tee_refresh_chat(self, fake_htt fake_http.fail_next_post(ssl.SSLCertVerificationError("certificate verify failed")) llm = _make_llm() - with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + with patch.object(llm._tee, "_connect", wraps=llm._tee._connect) as spy: result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) spy.assert_called_once() @@ -774,7 +684,7 @@ async def test_ssl_verification_failure_triggers_tee_refresh_streaming(self, fak fake_http.fail_next_stream(ssl.SSLCertVerificationError("certificate verify failed")) llm = _make_llm() - with patch.object(llm, "_connect_tee", wraps=llm._connect_tee) as spy: + with patch.object(llm._tee, "_connect", wraps=llm._tee._connect) as spy: gen = await llm.chat( model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}], diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py new file mode 100644 index 00000000..3f019123 --- /dev/null +++ b/tests/tee_connection_test.py @@ -0,0 +1,329 @@ +"""Tests for RegistryTEEConnection and ActiveTEE.""" + +import asyncio +import ssl +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from src.opengradient.client.tee_connection import ( + ActiveTEE, + RegistryTEEConnection, +) + + +# ── Helpers ────────────────────────────────────────────────────────── + + +class FakeHTTPClient: + """Minimal stand-in for x402HttpxClient.""" + + def __init__(self, *_a, **_kw): + self.closed = False + + async def aclose(self): + self.closed = True + + +def _mock_x402_client(): + return MagicMock() + + +def _make_registry_connection(*, registry=None, http_factory=None): + """Build a RegistryTEEConnection with patched externals.""" + factory = http_factory or FakeHTTPClient + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=factory, + ): + return RegistryTEEConnection( + x402_client=_mock_x402_client(), + registry=registry, + ) + + +def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=None, tee_id="tee-1", payment_address="0xPay"): + mock_reg = MagicMock() + mock_tee = MagicMock() + mock_tee.endpoint = endpoint + mock_tee.tls_cert_der = tls_cert_der + mock_tee.tee_id = tee_id + mock_tee.payment_address = payment_address + mock_reg.get_llm_tee.return_value = mock_tee + return mock_reg + + +class TestActiveTEE: + def test_metadata_returns_dict(self): + tee = ActiveTEE( + endpoint="https://ep", + http_client=MagicMock(), + tee_id="tee-1", + payment_address="0xPay", + ) + assert tee.metadata() == { + "tee_id": "tee-1", + "tee_endpoint": "https://ep", + "tee_payment_address": "0xPay", + } + + def test_metadata_with_none_values(self): + tee = ActiveTEE( + endpoint="https://ep", + http_client=MagicMock(), + tee_id=None, + payment_address=None, + ) + meta = tee.metadata() + assert meta["tee_id"] is None + assert meta["tee_payment_address"] is None + + def test_frozen_dataclass(self): + tee = ActiveTEE( + endpoint="https://ep", + http_client=MagicMock(), + tee_id="tee-1", + payment_address="0xPay", + ) + with pytest.raises(AttributeError): + tee.endpoint = "https://other" + + +@pytest.mark.asyncio +class TestRegistryTEEConnection: + # ── init / resolve ─────────────────────────────────────────── + + async def test_get_returns_active_tee(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + active = conn.get() + + assert isinstance(active, ActiveTEE) + assert active.endpoint == "https://tee.endpoint" + + async def test_resolve_none_raises(self): + mock_reg = MagicMock() + mock_reg.get_llm_tee.return_value = None + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + with pytest.raises(ValueError, match="No active LLM proxy TEE"): + RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) + + async def test_resolve_exception_wraps_in_runtime_error(self): + mock_reg = MagicMock() + mock_reg.get_llm_tee.side_effect = Exception("rpc down") + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + with pytest.raises(RuntimeError, match="Failed to fetch LLM TEE"): + RegistryTEEConnection(x402_client=_mock_x402_client(), registry=mock_reg) + + async def test_resolve_success(self): + mock_reg = _mock_registry_with_tee( + endpoint="https://registry.tee", + tee_id="tee-42", + payment_address="0xPay", + ) + conn = _make_registry_connection(registry=mock_reg) + + assert conn.get().endpoint == "https://registry.tee" + assert conn.get().tee_id == "tee-42" + assert conn.get().payment_address == "0xPay" + + async def test_builds_ssl_context_from_der(self): + mock_reg = _mock_registry_with_tee(tls_cert_der=b"fake-der") + mock_ssl_ctx = MagicMock(spec=ssl.SSLContext) + + with ( + patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ), + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=mock_ssl_ctx, + ) as mock_build, + ): + conn = RegistryTEEConnection( + x402_client=_mock_x402_client(), + registry=mock_reg, + ) + + mock_build.assert_called_once_with(b"fake-der") + assert conn.get().tee_id == "tee-1" + + # ── reconnect ──────────────────────────────────────────────── + + async def test_reconnect_replaces_active_tee(self): + clients_created = [] + + def make_client(*args, **kwargs): + c = FakeHTTPClient() + clients_created.append(c) + return c + + mock_reg = _mock_registry_with_tee() + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=make_client, + ): + conn = RegistryTEEConnection( + x402_client=_mock_x402_client(), + registry=mock_reg, + ) + old_client = conn.get().http_client + await conn.reconnect() + + assert conn.get().http_client is not old_client + assert len(clients_created) == 2 + + + async def test_reconnect_swallows_close_failure(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + conn.get().http_client.aclose = AsyncMock(side_effect=OSError("already closed")) + + with patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): + await conn.reconnect() # should not raise + + async def test_reconnect_is_serialized(self): + call_order = [] + original_connect = RegistryTEEConnection._connect + + def slow_connect(self): + call_order.append("start") + result = original_connect(self) + call_order.append("end") + return result + + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + + with patch.object(RegistryTEEConnection, "_connect", slow_connect): + await asyncio.gather(conn.reconnect(), conn.reconnect()) + + assert call_order == ["start", "end", "start", "end"] + + # ── refresh loop ───────────────────────────────────────────── + + async def test_ensure_refresh_loop_starts_task(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + + conn.ensure_refresh_loop() + + assert conn._refresh_task is not None + assert not conn._refresh_task.done() + + conn._refresh_task.cancel() + try: + await conn._refresh_task + except asyncio.CancelledError: + pass + + async def test_ensure_refresh_loop_is_idempotent(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + + conn.ensure_refresh_loop() + first_task = conn._refresh_task + conn.ensure_refresh_loop() + + assert conn._refresh_task is first_task + + conn._refresh_task.cancel() + try: + await conn._refresh_task + except asyncio.CancelledError: + pass + + async def test_refresh_loop_skips_when_tee_still_active(self): + mock_reg = _mock_registry_with_tee(tee_id="tee-1") + active_tee = MagicMock() + active_tee.tee_id = "tee-1" + mock_reg.get_active_tees_by_type.return_value = [active_tee] + + conn = _make_registry_connection(registry=mock_reg) + + with patch.object(conn, "reconnect", new_callable=AsyncMock) as mock_reconnect: + with patch( + "src.opengradient.client.tee_connection.asyncio.sleep", + side_effect=[None, asyncio.CancelledError], + ): + with pytest.raises(asyncio.CancelledError): + await conn._tee_refresh_loop() + + mock_reconnect.assert_not_called() + + async def test_refresh_loop_reconnects_when_tee_gone(self): + mock_reg = _mock_registry_with_tee(tee_id="tee-1") + other_tee = MagicMock() + other_tee.tee_id = "tee-99" + mock_reg.get_active_tees_by_type.return_value = [other_tee] + + conn = _make_registry_connection(registry=mock_reg) + + with patch.object(conn, "reconnect", new_callable=AsyncMock) as mock_reconnect: + with patch( + "src.opengradient.client.tee_connection.asyncio.sleep", + side_effect=[None, asyncio.CancelledError], + ): + with pytest.raises(asyncio.CancelledError): + await conn._tee_refresh_loop() + + mock_reconnect.assert_awaited_once() + + async def test_refresh_loop_survives_registry_error(self): + mock_reg = _mock_registry_with_tee(tee_id="tee-1") + mock_reg.get_active_tees_by_type.side_effect = [ + RuntimeError("rpc timeout"), + asyncio.CancelledError, + ] + + conn = _make_registry_connection(registry=mock_reg) + + with patch( + "src.opengradient.client.tee_connection.asyncio.sleep", + side_effect=[None, None], + ): + with pytest.raises(asyncio.CancelledError): + await conn._tee_refresh_loop() + + assert mock_reg.get_active_tees_by_type.call_count == 2 + + # ── close ──────────────────────────────────────────────────── + + async def test_close_closes_http_client(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + conn.get().http_client.aclose = AsyncMock() + + await conn.close() + + conn.get().http_client.aclose.assert_awaited_once() + + async def test_close_cancels_refresh_task(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + mock_task = MagicMock() + conn._refresh_task = mock_task + + await conn.close() + + mock_task.cancel.assert_called_once() + assert conn._refresh_task is None + + async def test_close_without_refresh_task(self): + mock_reg = _mock_registry_with_tee() + conn = _make_registry_connection(registry=mock_reg) + + await conn.close() # should not raise