diff --git a/src/opengradient/client/tee_connection.py b/src/opengradient/client/tee_connection.py index 2a7682d..2807b88 100644 --- a/src/opengradient/client/tee_connection.py +++ b/src/opengradient/client/tee_connection.py @@ -142,12 +142,10 @@ def _connect(self) -> ActiveTEE: """Resolve TEE from registry and create a secure HTTP client.""" tee = self._resolve_tee() - ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) if tee.tls_cert_der else None - tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True - + ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) return ActiveTEE( endpoint=tee.endpoint, - http_client=x402HttpxClient(self._x402_client, verify=tls_verify), + http_client=x402HttpxClient(self._x402_client, verify=ssl_ctx), tee_id=tee.tee_id, payment_address=tee.payment_address, ) diff --git a/tests/client_test.py b/tests/client_test.py index 2e2bf14..d7d70f8 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -22,10 +22,16 @@ @pytest.fixture def mock_tee_registry(): """Mock the TEE registry so LLM.__init__ doesn't need a live registry.""" - with patch("src.opengradient.client.llm.TEERegistry") as mock_tee_registry: + with ( + patch("src.opengradient.client.llm.TEERegistry") as mock_tee_registry, + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(), + ), + ): mock_tee = MagicMock() mock_tee.endpoint = "https://test.tee.server" - mock_tee.tls_cert_der = None + mock_tee.tls_cert_der = b"fake-der" mock_tee.tee_id = "test-tee-id" mock_tee.payment_address = "0xTestPaymentAddress" mock_tee_registry.return_value.get_llm_tee.return_value = mock_tee diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py index 3f01912..2d54c77 100644 --- a/tests/tee_connection_test.py +++ b/tests/tee_connection_test.py @@ -1,15 +1,25 @@ """Tests for RegistryTEEConnection and ActiveTEE.""" import asyncio +import datetime +import os import ssl +import tempfile from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +from x402 import x402Client from src.opengradient.client.tee_connection import ( ActiveTEE, RegistryTEEConnection, ) +from src.opengradient.client.tee_registry import build_ssl_context_from_der # ── Helpers ────────────────────────────────────────────────────────── @@ -32,9 +42,15 @@ def _mock_x402_client(): def _make_registry_connection(*, registry=None, http_factory=None): """Build a RegistryTEEConnection with patched externals.""" factory = http_factory or FakeHTTPClient - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=factory, + with ( + patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=factory, + ), + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), + ), ): return RegistryTEEConnection( x402_client=_mock_x402_client(), @@ -42,7 +58,7 @@ def _make_registry_connection(*, registry=None, http_factory=None): ) -def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=None, tee_id="tee-1", payment_address="0xPay"): +def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=b"fake-der", tee_id="tee-1", payment_address="0xPay"): mock_reg = MagicMock() mock_tee = MagicMock() mock_tee.endpoint = endpoint @@ -169,9 +185,15 @@ def make_client(*args, **kwargs): mock_reg = _mock_registry_with_tee() - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=make_client, + with ( + patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=make_client, + ), + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), + ), ): conn = RegistryTEEConnection( x402_client=_mock_x402_client(), @@ -183,7 +205,6 @@ def make_client(*args, **kwargs): assert conn.get().http_client is not old_client assert len(clients_created) == 2 - async def test_reconnect_swallows_close_failure(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) @@ -208,7 +229,13 @@ def slow_connect(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) - with patch.object(RegistryTEEConnection, "_connect", slow_connect): + with patch.object(RegistryTEEConnection, "_connect", slow_connect), patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), + ), patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): await asyncio.gather(conn.reconnect(), conn.reconnect()) assert call_order == ["start", "end", "start", "end"] @@ -327,3 +354,125 @@ async def test_close_without_refresh_task(self): conn = _make_registry_connection(registry=mock_reg) await conn.close() # should not raise + + +# ── TLS certificate verification (real handshake) ──────────────────── + + +def _make_self_signed_cert(): + """Generate a self-signed cert. Returns (der_bytes, pem_cert_bytes, pem_key_bytes).""" + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + return ( + cert.public_bytes(serialization.Encoding.DER), + cert.public_bytes(serialization.Encoding.PEM), + key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()), + ) + + +@pytest.fixture +async def tls_server(): + """Spin up a local TLS server with a self-signed cert.""" + der, pem_cert, pem_key = _make_self_signed_cert() + + cert_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) + key_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) + try: + cert_file.write(pem_cert) + cert_file.close() + key_file.write(pem_key) + key_file.close() + + server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_ctx.load_cert_chain(cert_file.name, key_file.name) + + async def handler(reader, writer): + await reader.read(4096) + writer.write(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok") + await writer.drain() + writer.close() + + server = await asyncio.start_server(handler, "127.0.0.1", 0, ssl=server_ctx) + port = server.sockets[0].getsockname()[1] + + yield {"port": port, "der": der} + + server.close() + await server.wait_closed() + finally: + os.unlink(cert_file.name) + os.unlink(key_file.name) + + +def _registry_with_real_cert(tls_server): + """Return a mock registry that serves the local TLS server's real DER cert.""" + return _mock_registry_with_tee( + endpoint=f"https://127.0.0.1:{tls_server['port']}", + tls_cert_der=tls_server["der"], + tee_id="tee-real", + payment_address="0xRealPay", + ) + + +@pytest.mark.asyncio +class TestTlsCertVerification: + """End-to-end TLS handshake tests through RegistryTEEConnection. + + A real local TLS server is started with a self-signed cert. The registry + mock returns that cert's DER bytes. RegistryTEEConnection._connect() runs + its real code (build_ssl_context_from_der → x402HttpxClient(verify=ctx)) + so the full cert-pinning path is exercised with an actual TLS handshake. + """ + + async def test_connect_succeeds_with_matching_cert(self, tls_server): + mock_reg = _registry_with_real_cert(tls_server) + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) + + resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") + assert resp.status_code == 200 + assert conn.get().tee_id == "tee-real" + assert conn.get().payment_address == "0xRealPay" + await conn.close() + + async def test_connect_fails_with_wrong_cert(self, tls_server): + wrong_der, _, _ = _make_self_signed_cert() # different key pair + mock_reg = _mock_registry_with_tee( + endpoint=f"https://127.0.0.1:{tls_server['port']}", + tls_cert_der=wrong_der, + ) + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) + + with pytest.raises(httpx.ConnectError): + await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") + await conn.close() + + async def test_connect_fails_with_no_cert_pinning(self, tls_server): + """Without a pinned cert (tls_cert_der=None), build_ssl_context_from_der + rejects the None value and connection construction fails.""" + mock_reg = _mock_registry_with_tee( + endpoint=f"https://127.0.0.1:{tls_server['port']}", + tls_cert_der=None, + ) + with pytest.raises(TypeError): + RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) + + async def test_reconnect_picks_up_new_cert(self, tls_server): + """After reconnect, the connection uses the freshly-resolved cert.""" + mock_reg = _registry_with_real_cert(tls_server) + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) + + await conn.reconnect() + + resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") + assert resp.status_code == 200 + await conn.close()