Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/opengradient/client/tee_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,10 @@ def _connect(self) -> ActiveTEE:
"""Resolve TEE from registry and create a secure HTTP client."""
tee = self._resolve_tee()

ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) if tee.tls_cert_der else None
tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True

ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der)
return ActiveTEE(
endpoint=tee.endpoint,
http_client=x402HttpxClient(self._x402_client, verify=tls_verify),
http_client=x402HttpxClient(self._x402_client, verify=ssl_ctx),
tee_id=tee.tee_id,
payment_address=tee.payment_address,
)
Expand Down
10 changes: 8 additions & 2 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@
@pytest.fixture
def mock_tee_registry():
"""Mock the TEE registry so LLM.__init__ doesn't need a live registry."""
with patch("src.opengradient.client.llm.TEERegistry") as mock_tee_registry:
with (
patch("src.opengradient.client.llm.TEERegistry") as mock_tee_registry,
patch(
"src.opengradient.client.tee_connection.build_ssl_context_from_der",
return_value=MagicMock(),
),
):
mock_tee = MagicMock()
mock_tee.endpoint = "https://test.tee.server"
mock_tee.tls_cert_der = None
mock_tee.tls_cert_der = b"fake-der"
mock_tee.tee_id = "test-tee-id"
mock_tee.payment_address = "0xTestPaymentAddress"
mock_tee_registry.return_value.get_llm_tee.return_value = mock_tee
Expand Down
167 changes: 158 additions & 9 deletions tests/tee_connection_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
"""Tests for RegistryTEEConnection and ActiveTEE."""

import asyncio
import datetime
import os
import ssl
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from x402 import x402Client

from src.opengradient.client.tee_connection import (
ActiveTEE,
RegistryTEEConnection,
)
from src.opengradient.client.tee_registry import build_ssl_context_from_der


# ── Helpers ──────────────────────────────────────────────────────────
Expand All @@ -32,17 +42,23 @@ def _mock_x402_client():
def _make_registry_connection(*, registry=None, http_factory=None):
"""Build a RegistryTEEConnection with patched externals."""
factory = http_factory or FakeHTTPClient
with patch(
"src.opengradient.client.tee_connection.x402HttpxClient",
side_effect=factory,
with (
patch(
"src.opengradient.client.tee_connection.x402HttpxClient",
side_effect=factory,
),
patch(
"src.opengradient.client.tee_connection.build_ssl_context_from_der",
return_value=MagicMock(spec=ssl.SSLContext),
),
):
return RegistryTEEConnection(
x402_client=_mock_x402_client(),
registry=registry,
)


def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=None, tee_id="tee-1", payment_address="0xPay"):
def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=b"fake-der", tee_id="tee-1", payment_address="0xPay"):
mock_reg = MagicMock()
mock_tee = MagicMock()
mock_tee.endpoint = endpoint
Expand Down Expand Up @@ -169,9 +185,15 @@ def make_client(*args, **kwargs):

mock_reg = _mock_registry_with_tee()

with patch(
"src.opengradient.client.tee_connection.x402HttpxClient",
side_effect=make_client,
with (
patch(
"src.opengradient.client.tee_connection.x402HttpxClient",
side_effect=make_client,
),
patch(
"src.opengradient.client.tee_connection.build_ssl_context_from_der",
return_value=MagicMock(spec=ssl.SSLContext),
),
):
conn = RegistryTEEConnection(
x402_client=_mock_x402_client(),
Expand All @@ -183,7 +205,6 @@ def make_client(*args, **kwargs):
assert conn.get().http_client is not old_client
assert len(clients_created) == 2


async def test_reconnect_swallows_close_failure(self):
mock_reg = _mock_registry_with_tee()
conn = _make_registry_connection(registry=mock_reg)
Expand All @@ -208,7 +229,13 @@ def slow_connect(self):
mock_reg = _mock_registry_with_tee()
conn = _make_registry_connection(registry=mock_reg)

with patch.object(RegistryTEEConnection, "_connect", slow_connect):
with patch.object(RegistryTEEConnection, "_connect", slow_connect), patch(
"src.opengradient.client.tee_connection.build_ssl_context_from_der",
return_value=MagicMock(spec=ssl.SSLContext),
), patch(
"src.opengradient.client.tee_connection.x402HttpxClient",
side_effect=FakeHTTPClient,
):
await asyncio.gather(conn.reconnect(), conn.reconnect())

assert call_order == ["start", "end", "start", "end"]
Expand Down Expand Up @@ -327,3 +354,125 @@ async def test_close_without_refresh_task(self):
conn = _make_registry_connection(registry=mock_reg)

await conn.close() # should not raise


# ── TLS certificate verification (real handshake) ────────────────────


def _make_self_signed_cert():
"""Generate a self-signed cert. Returns (der_bytes, pem_cert_bytes, pem_key_bytes)."""
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")])
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.UTC))
.not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1))
.sign(key, hashes.SHA256())
)
return (
cert.public_bytes(serialization.Encoding.DER),
cert.public_bytes(serialization.Encoding.PEM),
key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()),
)


@pytest.fixture
async def tls_server():
"""Spin up a local TLS server with a self-signed cert."""
der, pem_cert, pem_key = _make_self_signed_cert()

cert_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False)
key_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False)
try:
cert_file.write(pem_cert)
cert_file.close()
key_file.write(pem_key)
key_file.close()

server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_ctx.load_cert_chain(cert_file.name, key_file.name)

async def handler(reader, writer):
await reader.read(4096)
writer.write(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok")
await writer.drain()
writer.close()

server = await asyncio.start_server(handler, "127.0.0.1", 0, ssl=server_ctx)
port = server.sockets[0].getsockname()[1]

yield {"port": port, "der": der}

server.close()
await server.wait_closed()
finally:
os.unlink(cert_file.name)
os.unlink(key_file.name)


def _registry_with_real_cert(tls_server):
"""Return a mock registry that serves the local TLS server's real DER cert."""
return _mock_registry_with_tee(
endpoint=f"https://127.0.0.1:{tls_server['port']}",
tls_cert_der=tls_server["der"],
tee_id="tee-real",
payment_address="0xRealPay",
)


@pytest.mark.asyncio
class TestTlsCertVerification:
"""End-to-end TLS handshake tests through RegistryTEEConnection.

A real local TLS server is started with a self-signed cert. The registry
mock returns that cert's DER bytes. RegistryTEEConnection._connect() runs
its real code (build_ssl_context_from_der → x402HttpxClient(verify=ctx))
so the full cert-pinning path is exercised with an actual TLS handshake.
"""

async def test_connect_succeeds_with_matching_cert(self, tls_server):
mock_reg = _registry_with_real_cert(tls_server)
conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg)

resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/")
assert resp.status_code == 200
assert conn.get().tee_id == "tee-real"
assert conn.get().payment_address == "0xRealPay"
await conn.close()

async def test_connect_fails_with_wrong_cert(self, tls_server):
wrong_der, _, _ = _make_self_signed_cert() # different key pair
mock_reg = _mock_registry_with_tee(
endpoint=f"https://127.0.0.1:{tls_server['port']}",
tls_cert_der=wrong_der,
)
conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg)

with pytest.raises(httpx.ConnectError):
await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/")
await conn.close()

async def test_connect_fails_with_no_cert_pinning(self, tls_server):
"""Without a pinned cert (tls_cert_der=None), build_ssl_context_from_der
rejects the None value and connection construction fails."""
mock_reg = _mock_registry_with_tee(
endpoint=f"https://127.0.0.1:{tls_server['port']}",
tls_cert_der=None,
)
with pytest.raises(TypeError):
RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg)

async def test_reconnect_picks_up_new_cert(self, tls_server):
"""After reconnect, the connection uses the freshly-resolved cert."""
mock_reg = _registry_with_real_cert(tls_server)
conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg)

await conn.reconnect()

resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/")
assert resp.status_code == 200
await conn.close()
Loading