diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 20e32fb..610c44b 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -18,7 +18,12 @@ from ..types import TEE_LLM, ResponseFormat, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode from .opg_token import Permit2ApprovalResult, ensure_opg_approval -from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface +from .tee_connection import ( + ActiveTEE, + RegistryTEEConnection, + StaticTEEConnection, + TEEConnectionInterface, +) from .tee_registry import TEERegistry logger = logging.getLogger(__name__) @@ -140,6 +145,15 @@ async def close(self) -> None: """Cancel the background refresh loop and close the HTTP client.""" await self._tee.close() + def resolve_tee_connection(self, tee_id: Optional[str] = None) -> ActiveTEE: + """Resolve the current TEE or a specific active registry TEE. + + This is primarily for backend relays that need SDK-managed TEE routing, + TLS pinning, and x402 clients without using the chat/completion helpers + directly, for example when forwarding OHTTP ciphertext. + """ + return self._tee.resolve(tee_id) + # ── Request helpers ───────────────────────────────────────────────── def _headers(self, settlement_mode: x402SettlementMode) -> Dict[str, str]: diff --git a/src/opengradient/client/tee_connection.py b/src/opengradient/client/tee_connection.py index 2807b88..23bb97b 100644 --- a/src/opengradient/client/tee_connection.py +++ b/src/opengradient/client/tee_connection.py @@ -9,7 +9,12 @@ from x402 import x402Client from x402.http.clients import x402HttpxClient -from .tee_registry import TEE_TYPE_LLM_PROXY, TEERegistry, build_ssl_context_from_der +from .tee_registry import ( + TEE_TYPE_LLM_PROXY, + TEEEndpoint, + TEERegistry, + build_ssl_context_from_der, +) logger = logging.getLogger(__name__) @@ -38,11 +43,23 @@ class TEEConnectionInterface(Protocol): """Interface for TEE connection implementations.""" def get(self) -> ActiveTEE: ... + def resolve(self, tee_id: Optional[str] = None) -> ActiveTEE: ... def ensure_refresh_loop(self) -> None: ... async def reconnect(self) -> None: ... async def close(self) -> None: ... +def _normalize_tee_id(tee_id: Optional[str]) -> Optional[str]: + if not tee_id: + return None + normalized = tee_id.strip().lower() + if not normalized: + return None + if not normalized.startswith("0x"): + normalized = f"0x{normalized}" + return normalized + + class StaticTEEConnection: """TEE connection with a hardcoded endpoint URL. @@ -63,6 +80,14 @@ def get(self) -> ActiveTEE: """Return a snapshot of the current TEE connection.""" return self._active + def resolve(self, tee_id: Optional[str] = None) -> ActiveTEE: + """Return the static connection. + + Static/dev connections do not have a registry to validate selected + TEE ids against, so they always resolve to the configured endpoint. + """ + return self._active + def _connect(self) -> ActiveTEE: return ActiveTEE( endpoint=self._endpoint, @@ -106,6 +131,7 @@ def __init__(self, x402_client: x402Client, registry: TEERegistry): self._refresh_lock = asyncio.Lock() self._refresh_task: Optional[asyncio.Task] = None + self._active_by_tee_id: dict[str, ActiveTEE] = {} self._active: ActiveTEE = self._connect() @@ -115,9 +141,47 @@ def get(self) -> ActiveTEE: """Return a snapshot of the current TEE connection.""" return self._active + def resolve(self, tee_id: Optional[str] = None) -> ActiveTEE: + """Resolve a TEE connection, optionally pinned to an active TEE id. + + Backend OHTTP relays can use this when the browser encrypted to a + specific on-chain TEE config, while the backend still owns x402 payment. + """ + normalized_tee_id = _normalize_tee_id(tee_id) + if normalized_tee_id is None: + return self._active + + active_tee_id = _normalize_tee_id(self._active.tee_id) + if normalized_tee_id == active_tee_id: + return self._active + + for tee in self._registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY): + if _normalize_tee_id(tee.tee_id) != normalized_tee_id: + continue + + cached = self._active_by_tee_id.get(normalized_tee_id) + if ( + cached is not None + and cached.endpoint.rstrip("/") == tee.endpoint.rstrip("/") + ): + return cached + + resolved = self._connect_to_tee(tee) + self._active_by_tee_id[normalized_tee_id] = resolved + logger.info( + "Resolved selected TEE endpoint from registry: %s (teeId=%s)", + resolved.endpoint, + normalized_tee_id, + ) + return resolved + + raise ValueError( + f"Selected TEE is not active in the registry: {normalized_tee_id}" + ) + # ── Connection management ─────────────────────────────────────────── - def _resolve_tee(self): + def _resolve_tee(self) -> TEEEndpoint: """Resolve TEE endpoint and metadata from the on-chain registry. Returns: @@ -141,7 +205,10 @@ def _resolve_tee(self): def _connect(self) -> ActiveTEE: """Resolve TEE from registry and create a secure HTTP client.""" tee = self._resolve_tee() + return self._connect_to_tee(tee) + def _connect_to_tee(self, tee: TEEEndpoint) -> ActiveTEE: + """Create a pinned x402 HTTP client for a resolved registry TEE.""" ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) return ActiveTEE( endpoint=tee.endpoint,