def create_llm(model: str = "", **kwargs) -> LLM:
    """Build an LLM, choosing the concrete class from the model name.

    A model name that starts with ``"databricks/"`` produces a
    ``DatabricksLLM``; any other name (including the empty default) produces
    the base ``LLM``. ``model`` and every extra keyword argument are
    forwarded unchanged to the selected constructor.

    ``DatabricksLLM`` is imported inside the function, and only when the
    prefix matches, to avoid a circular import and to keep the Databricks
    dependencies optional (install via: pip install openhands-sdk[databricks]).

    Example:
        llm = create_llm("databricks/databricks-meta-llama-3-3-70b-instruct",
                         databricks_host="https://adb-xxx.azuredatabricks.net",
                         api_key=SecretStr("dapi..."))
        llm = create_llm("claude-sonnet-4-20250514", api_key=SecretStr("sk-ant-..."))
    """
    # Guard clause: everything without the Databricks prefix uses the base class.
    if not model.startswith("databricks/"):
        return LLM(model=model, **kwargs)

    # Deferred import — only reached for Databricks models.
    from openhands.sdk.llm.providers.databricks.llm import DatabricksLLM

    return DatabricksLLM(model=model, **kwargs)
AgentBase(DiscriminatedUnionMixin, ABC): arbitrary_types_allowed=True, ) - llm: LLM = Field( + llm: Annotated[LLM, SerializeAsAny()] = Field( ..., description="LLM configuration for the agent.", examples=[ diff --git a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py index aae387be9b..1544843054 100644 --- a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py +++ b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py @@ -1,8 +1,9 @@ import os from collections.abc import Sequence from enum import Enum +from typing import Annotated -from pydantic import Field, model_validator +from pydantic import Field, SerializeAsAny, model_validator from openhands.sdk.context.condenser.base import ( CondensationRequirement, @@ -43,7 +44,7 @@ class LLMSummarizingCondenser(RollingCondenser): it is the same as the one defined in this condenser. """ - llm: LLM + llm: Annotated[LLM, SerializeAsAny()] max_size: int = Field(default=240, gt=0) max_tokens: int | None = None diff --git a/openhands-sdk/openhands/sdk/event/conversation_error.py b/openhands-sdk/openhands/sdk/event/conversation_error.py index 499d727e98..08a5806d22 100644 --- a/openhands-sdk/openhands/sdk/event/conversation_error.py +++ b/openhands-sdk/openhands/sdk/event/conversation_error.py @@ -1,9 +1,78 @@ +import re + from pydantic import Field from rich.text import Text from openhands.sdk.event.base import Event +# --------------------------------------------------------------------------- +# Hint rules: list of (pattern, hint_text) pairs. The first matching pattern +# wins. Patterns are matched case-insensitively against ``detail``. +# --------------------------------------------------------------------------- +_HINT_RULES: list[tuple[re.Pattern[str], str]] = [ + # Databricks AI Gateway: endpoint not found in workspace (404). 
+ # Typically means cross-geography routing is disabled, or the endpoint + # has not been deployed in this workspace. + ( + re.compile( + r"\[404\]\s*AI\s+Gateway\s+endpoint\s+['\"]?(\S+?)['\"]?\s+does\s+not\s+exist", + re.IGNORECASE, + ), + ( + "This Databricks endpoint is not available in your workspace.\n" + "Possible reasons:\n" + " • The model requires cross-geography routing, which is not\n" + " enabled in your workspace (contact your admin).\n" + " • The endpoint name is misspelled or not yet deployed.\n" + "Tip: Open Settings → click 'Refresh Models' to see the endpoints\n" + "that are actually available in your workspace, then save a\n" + "different model." + ), + ), + # Databricks: org-level access denied (403 Invalid access to Org). + # Gemini and other cross-geography models route through a Databricks + # global GCP org. The 403 means that routing is not enabled for this + # workspace account. + ( + re.compile(r"\[403\].*Invalid\s+access\s+to\s+Org", re.IGNORECASE), + ( + "Your workspace does not have permission to access this model.\n" + "This error most commonly occurs with Gemini models, which require\n" + "cross-geography routing through a Databricks GCP organisation.\n" + "Action: ask your Databricks account admin to enable\n" + " 'Cross-geography model serving' for your account, or choose a\n" + " different model (Claude / Llama / DBRX) that runs within your\n" + " workspace region.\n" + "Tip: Open Settings → click '↻ Refresh Models' to see only the\n" + "endpoints available in your workspace, then pick a different model." + ), + ), + # Databricks: authentication failure (401 / token expired). + ( + re.compile(r"\[401\].*databricks|databricks.*\[401\]|UNAUTHENTICATED", re.IGNORECASE), + ( + "Databricks authentication failed.\n" + "Tip: Open Settings and re-authenticate (re-run the browser sign-in\n" + "for U2M, or verify your client credentials for M2M)." + ), + ), + # Generic LiteLLM / provider rate-limit. 
+ ( + re.compile(r"\[429\]|rate.?limit|too many requests", re.IGNORECASE), + "The model endpoint returned a rate-limit error. Wait a moment and retry.", + ), +] + + +def _get_hint(detail: str) -> str | None: + """Return the first matching hint for the given error detail, or None.""" + for pattern, hint in _HINT_RULES: + if pattern.search(detail): + return hint + return None + + class ConversationErrorEvent(Event): """ Conversation-level failure that is NOT sent back to the LLM. @@ -34,4 +103,10 @@ def visualize(self) -> Text: content.append(self.code) content.append("\n\nDetail:\n", style="bold") content.append(self.detail) + + hint = _get_hint(self.detail) + if hint: + content.append("\n\nHint:\n", style="bold yellow") + content.append(hint, style="yellow") + return content diff --git a/openhands-sdk/openhands/sdk/llm/__init__.py b/openhands-sdk/openhands/sdk/llm/__init__.py index 1a823532a9..23f049e95a 100644 --- a/openhands-sdk/openhands/sdk/llm/__init__.py +++ b/openhands-sdk/openhands/sdk/llm/__init__.py @@ -32,6 +32,16 @@ ) from openhands.sdk.llm.utils.verified_models import VERIFIED_MODELS +# Eagerly import DatabricksLLM so it registers with LLM.__subclasses__() at +# module-load time. This ensures that LLM._dispatch_to_provider_subclass can +# reconstruct a DatabricksLLM when deserializing persisted agent JSON that +# carries provider="databricks" — even in processes (e.g. the agent server) +# that never explicitly import the Databricks provider. 
@model_validator(mode="wrap")
@classmethod
def _dispatch_to_provider_subclass(
    cls,
    data: Any,
    handler: ModelWrapValidatorHandler["LLM"],
    info: ValidationInfo,
) -> "LLM":
    """Route persisted agent JSON to the matching LLM subclass.

    When a saved agent is rehydrated the ``llm`` dict may carry a
    ``provider`` key written by a subclass (e.g. ``"provider":
    "databricks"``).  Without this, ``LLM.model_validate(data)`` would
    build a base ``LLM`` and silently drop the subclass-only fields.

    Only fires when called directly on the base ``LLM`` class with a
    plain dict — subclass validators and non-dict inputs pass through
    unchanged.  Candidate classes are every descendant of ``LLM``
    (walked breadth-first, so direct subclasses are tried before deeper
    ones), matched on the ``__args__`` of their ``provider`` field
    annotation, so no provider names are hardcoded here.

    Args:
        data: Raw validation input.
        handler: Default validation for ``cls``; used whenever no dispatch
            takes place.
        info: Validation info; its ``context`` is forwarded to the
            subclass validation.

    Returns:
        An instance of the matching subclass, or whatever ``handler``
        produces when there is no ``provider`` value or no subclass
        declares it.
    """
    from collections import deque

    if cls is not LLM or not isinstance(data, dict):
        return handler(data)
    provider = data.get("provider")
    if not provider:
        return handler(data)

    # ``__subclasses__()`` lists immediate children only, so walk the whole
    # hierarchy; ``visited`` guards against diamond inheritance.
    pending = deque(LLM.__subclasses__())
    visited: set[type] = set()
    while pending:
        sub = pending.popleft()
        if sub in visited:
            continue
        visited.add(sub)
        f = sub.model_fields.get("provider")
        if f and provider in getattr(f.annotation, "__args__", ()):
            return sub.model_validate(data, context=info.context)
        pending.extend(sub.__subclasses__())
    return handler(data)
+ +## TL;DR — construct an LLM and use it + +```python +from pydantic import SecretStr +from openhands.sdk import create_llm +from openhands.sdk.llm.message import Message, TextContent + +llm = create_llm( + model="databricks/databricks-claude-sonnet-4-5", + databricks_host="https://adb-xxx.cloud.databricks.com", + api_key=SecretStr("dapi..."), + usage_id="my-agent", +) + +resp = llm.completion(messages=[ + Message(role="user", content=[TextContent(text="Hello!")]), +]) +print(resp.message.content[0].text) +``` + +No `databricks/` prefix + no explicit provider keyword means you get the +base `LLM` (default LiteLLM transport). Adding the `databricks/` prefix is +the only signal `create_llm` needs to route to this provider. + +## Supported native APIs + +| Family | AI Gateway path | When selected | +|---|---|---| +| `ProviderFamily.OPENAI` | `POST /serving-endpoints/{endpoint}/invocations` | Default — every `llm/v1/chat` endpoint | +| `ProviderFamily.OPENAI_RESPONSES` | `POST /serving-endpoints/v1/responses` | GPT-5 series (`databricks-gpt-5*`) | +| `ProviderFamily.ANTHROPIC` | `POST /serving-endpoints/anthropic/v1/messages` | Claude models (`*claude*`) | +| `ProviderFamily.GEMINI` | `POST /serving-endpoints/gemini/v1beta/models/{endpoint}:generateContent` | Gemini models (`*gemini*`) | + +Routing is metadata-first +(`GET /api/2.0/serving-endpoints/{name}` → `foundation_model.api_types` / +`external_model.provider`) with a name-pattern fallback (see `models.py`). +Results are cached in-process with a 5-minute TTL. + +See `models.py` for the authoritative routing table and `native.py` for +the per-family request/response adapters. + +## Authentication + +Five PWAF-compliant strategies, resolved in this priority order by +`resolve_credentials()`: + +1. **U2M** — OAuth browser PKCE, tokens passed in via `stored_u2m_tokens`. +2. **M2M** — `databricks_client_id` + `databricks_client_secret` + (service principal / client_credentials grant). +3. 
**PAT** — `api_key=SecretStr("dapi...")`. +4. **PROFILE** — `databricks_profile="DEFAULT"` (reads `~/.databrickscfg`; + requires the `databricks` extra: `pip install openhands-sdk[databricks]`). +5. **UNIFIED** — fallback to the `databricks-sdk` unified auth chain + (env vars, Azure MSI/Entra ID, etc.). + +Defer to the **Databricks Partner PWAF skills** *(URL TBD)* for end-to-end +auth details (token caching, refresh policies, CLI profile selection). + +Use `llm.auth_method` to see which strategy resolved. + +The interactive **U2M browser login** (Authorization Code + PKCE) is built from +the dependency-light helpers in `pkce.py` — `generate_pkce()`, +`build_authorize_url()`, and `exchange_code_for_tokens()` / +`async_exchange_code_for_tokens()` — all re-exported from this package's +`__init__`. These are the single source of truth consumed by both the web +backend and the OpenHands-CLI, so the three front-ends can't drift apart. Each +caller supplies its own local redirect/callback handling and passes the +resulting tokens back in via `stored_u2m_tokens`. + +## Alignment with Databricks `ucode` + +This connector follows the same credential model as +[Databricks `ucode`](https://github.com/databricks/ucode) — the *Unity AI +Gateway Coding CLI* — which routes coding agents through the Databricks AI +Gateway using workspace credentials, **no API keys required**. The `PROFILE` +and `UNIFIED` strategies read the workspace login a developer has already +established (`databricks auth login` / `~/.databrickscfg`), and `U2M` provides +interactive browser OAuth. An OpenHands agent can therefore reach AI Gateway the +same key-free, governed way `ucode` does — reusing the existing workspace +session rather than minting a separate token — over one consistent path to the +gateway (and the Unity Catalog–governed resources behind it). 
+ +## Discovery (picker UIs) + +Listing AI-Gateway-shaped chat endpoints: + +```python +from openhands.sdk.llm.providers.databricks import ( + DatabricksCredentials, list_chat_endpoints, +) + +creds = DatabricksCredentials(host="...", get_token=lambda: "dapi...", auth_method="pat") +for ep in list_chat_endpoints(creds): + print(ep.qualified_name, ep.endpoint_type) +``` + +`list_chat_endpoints` includes both `FOUNDATION_MODEL_API` and +`EXTERNAL_MODEL` endpoints (customer-configured gpt-5 / gemini / +claude proxies). The lighter `list_foundation_models` returns flat +`databricks/<endpoint-name>` strings and is what `list_models_from_env` uses +with a 5-minute TTL cache. + +## PWAF surfaces on `DatabricksLLM` + +| Attribute | What it tells you | +|---|---| +| `llm.auth_method` | Resolved strategy: `pat` / `m2m` / `u2m` / `profile` / `unified` / `env` | +| `llm.predicted_family` | Family by name pattern only (pure compute, no HTTP) | +| `llm.resolve_family()` | Authoritative family (metadata probe, cached; falls back to predicted) | +| `llm.max_input_tokens` | Context window from the model-capability table | +| `llm.max_output_tokens` | Output budget (generous on reasoning models — gpt-5, gemini 2.5, gpt-oss) | + +## Module layout + +| Module | Role | +|---|---| +| `__init__.py` | Public API (`DatabricksLLM`, `ProviderFamily`, discovery, auth types, PKCE helpers) | +| `llm.py` | `DatabricksLLM` — Pydantic subclass of `LLM`; transport override | +| `client.py` | `DatabricksFMAPIClient` — `httpx` transport, family dispatch, retry | +| `native.py` | Per-family `to_native` / `from_native` adapters (request/response shaping) | +| `models.py` | `ProviderFamily`, `AIGatewayPaths`, routing functions, token containers | +| `auth.py` | Five credential strategies, `resolve_credentials()`, token providers | +| `pkce.py` | Shared U2M browser-login helpers (`generate_pkce`, `build_authorize_url`, sync/async `exchange_code_for_tokens`) — single source of truth for web + CLI | +|
`settings_bridge.py` | `kwargs_from_settings()` — the one path that turns user settings (env / DB / TUI) into `create_llm(...)` kwargs, shared by backend and CLI | +| `discovery.py` | `list_chat_endpoints` / `list_foundation_models` + TTL cache | +| `utils.py` | `USER_AGENT`, `DatabricksTimeouts`, retry/backoff, error mapping | + +## Relationship to LiteLLM + +The connector owns its own HTTP path — all wire traffic to the Databricks AI +Gateway goes through `client.py` (`httpx`), not through `litellm.completion`. +It does, however, interoperate with LiteLLM at the type boundary to stay +compatible with the OpenHands base `LLM` class: + +- `client.py` returns a `litellm.types.utils.ModelResponse` so callers of the + base `LLM.completion` API get the expected return shape. +- `utils.py` and `auth.py` raise `litellm.exceptions.*` for HTTP error mapping + so retry/backoff behaves consistently with other providers. + +Removing this type-level coupling would require decoupling the OpenHands base +`LLM` class itself from LiteLLM and is tracked as a separate investigation — +it is deliberately out of scope for this provider. + +### Parked: true LiteLLM decoupling (Scope B) + +The work needed to drop the residual LiteLLM type dependency is intentionally +parked and recorded here so it isn't forgotten: + +- **Goal.** Remove both `litellm.types.utils.ModelResponse` and + `litellm.exceptions` from the Databricks provider's public surface so a + future OpenHands deployment could run this connector without LiteLLM in the + dependency graph at all. +- **Blocker.** The OpenHands base `LLM` class (in the shared SDK) still + accepts and returns `ModelResponse`, and its retry/backoff wiring catches + `litellm.exceptions.*`. Changing only the Databricks connector would break + that contract and ripple into every other provider plus the callers. +- **Shape of the fix (for the future investigation).** + 1. 
Introduce a small SDK-level `LLMResponse` dataclass that `LLM.completion` + returns regardless of provider, with a thin adapter that today wraps + `ModelResponse` for LiteLLM-backed providers. + 2. Replace `litellm.exceptions.*` with a set of SDK-owned error classes + (`LLMAuthError`, `LLMRateLimitError`, …) and translate at the LiteLLM + boundary only. + 3. Switch this provider to the new types, delete the two LiteLLM imports + from `client.py` / `utils.py` / `auth.py`, and remove the dependency + marker from the Databricks provider's optional extras. +- **Test plan.** Re-run the full `tests/sdk/llm/providers/databricks/` suite + plus the shared `LLM` contract tests without `litellm` installed in a + separate venv to prove the connector is truly standalone. +- **Why not now.** Would touch every other provider and the base `LLM` + class; out of scope for the native-Databricks provider milestone. + +## Testing + +Every routing / adapter / auth path has unit coverage under +`tests/sdk/llm/providers/databricks/`. End-to-end calls against +`e2e-demo-field-eng` have been run live across all three native families +(Llama → OpenAI Chat, Claude → Anthropic, Gemini → `generateContent`) and +all five auth strategies. diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py new file mode 100644 index 0000000000..1883589abd --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py @@ -0,0 +1,105 @@ +"""Databricks AI Gateway native provider for the OpenHands V1 SDK. + +PWAF-compliant. 
Uses a direct synchronous ``httpx`` transport against the +Databricks AI Gateway rather than routing HTTP through ``litellm.completion``, +and dispatches by provider family to the correct native surface: + +* **OpenAI Chat** → ``/serving-endpoints/{endpoint}/invocations`` (universal) +* **OpenAI Responses** → ``/serving-endpoints/v1/responses`` (GPT-5 series) +* **Anthropic Messages** → ``/serving-endpoints/anthropic/v1/messages`` (Claude) +* **Gemini generateContent** → ``/serving-endpoints/gemini/v1beta/models/{endpoint}:generateContent`` + +Routing is metadata-first (``GET /api/2.0/serving-endpoints/{name}`` → +``foundation_model.api_types`` / ``external_model.provider``) with a +name-pattern fallback. + +Auth: PAT, OAuth M2M (client_credentials), OAuth U2M (browser PKCE), +``~/.databrickscfg`` profile, and unified Databricks SDK chain — see +:mod:`.auth` for details. The PWAF Partner AI Dev Kit skills are the +authoritative reference for credential handling. + +See the companion skill ``databricks-ai-gateway-fm-apis`` (in ``_local/skills``) +for the routing table, worked examples, and a runnable ``probe.py`` that +self-verifies every native path. + +Typical usage (via :func:`openhands.sdk.create_llm` factory): + +.. 
code-block:: python + + from openhands.sdk import create_llm + from openhands.sdk.llm.message import Message, TextContent + from pydantic import SecretStr + + llm = create_llm( + model="databricks/databricks-claude-sonnet-4-5", + databricks_host="https://adb-xxx.cloud.databricks.com", + api_key=SecretStr("dapi..."), # or pass databricks_profile="DEFAULT" + usage_id="my-agent", + ) + print(llm.predicted_family) # ProviderFamily.ANTHROPIC (no HTTP) + print(llm.resolve_family()) # ANTHROPIC (metadata-confirmed) + + resp = llm.completion(messages=[ + Message(role="user", content=[TextContent(text="Hello!")]), + ]) +""" + +from openhands.sdk.llm.providers.databricks.auth import ( + AuthStrategy, + DatabricksCredentials, +) +from openhands.sdk.llm.providers.databricks.discovery import ( + CURATED_DATABRICKS_MODELS, + DiscoveredEndpoint, + ModelPickerEntry, + get_picker_entries, + list_chat_endpoints, + list_foundation_models, + list_models_from_env, +) +from openhands.sdk.llm.providers.databricks.llm import DatabricksLLM +from openhands.sdk.llm.providers.databricks.models import ( + AIGatewayPaths, + ProviderFamily, + StoredU2MTokens, + detect_family, + pick_family_from_api_types, +) +from openhands.sdk.llm.providers.databricks.pkce import ( + async_exchange_code_for_tokens, + build_authorize_url, + exchange_code_for_tokens, + generate_pkce, +) +from openhands.sdk.llm.providers.databricks.settings_bridge import kwargs_from_settings + + +__all__ = [ + # LLM + "DatabricksLLM", + # Routing primitives + "ProviderFamily", + "AIGatewayPaths", + "detect_family", + "pick_family_from_api_types", + # Auth + "AuthStrategy", + "DatabricksCredentials", + "StoredU2MTokens", + # U2M browser-login PKCE primitives (shared by web + CLI) + "generate_pkce", + "build_authorize_url", + "exchange_code_for_tokens", + "async_exchange_code_for_tokens", + # Discovery + "DiscoveredEndpoint", + "list_chat_endpoints", + "list_foundation_models", + "list_models_from_env", + # Two-tier model picker + 
"ModelPickerEntry", + "CURATED_DATABRICKS_MODELS", + "get_picker_entries", + # Settings → create_llm bridge (drift-guarded) + "kwargs_from_settings", +] diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py new file mode 100644 index 0000000000..5af69b0c5c --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py @@ -0,0 +1,362 @@ +"""Databricks FMAPI authentication strategies. + +Supports all 5 PWAF-required auth paths: + U2M — OAuth browser PKCE flow (PWAF primary interactive auth). + Provider receives StoredU2MTokens from app layer; manages refresh only. + M2M — OAuth client credentials (PWAF primary service auth). + M2MTokenProvider fetches/refreshes tokens with threading.Lock. + PAT — Personal Access Token (additional option only per PWAF). + PROFILE — Databricks CLI profile (~/.databrickscfg). Requires databricks-sdk. + UNIFIED — databricks-sdk unified auth chain (workload identity, Azure AD, etc.). + Requires databricks-sdk. + +Auth priority (PWAF compliant): U2M > M2M > PAT > PROFILE > UNIFIED. + +Client ID distinction (critical): + DATABRICKS_CLIENT_ID — M2M service principal (grant_type=client_credentials) + DATABRICKS_U2M_CLIENT_ID — Custom OAuth app for browser login (PKCE flow) + Using M2M client_id for U2M → "OAuth application not available" from Databricks. 
+""" + +from __future__ import annotations + +import logging +import threading +import time +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Callable + +import httpx +from litellm.exceptions import AuthenticationError + +from openhands.sdk.llm.providers.databricks.models import StoredU2MTokens +from openhands.sdk.llm.providers.databricks.utils import ( + USER_AGENT, + normalize_host, + validate_databricks_config, +) + +if TYPE_CHECKING: + from openhands.sdk.llm.providers.databricks.llm import DatabricksLLM + +logger = logging.getLogger(__name__) + + +class AuthStrategy(str, Enum): + """Auth strategy discriminator. Used in resolve_credentials() priority chain.""" + + U2M = "u2m" # OAuth browser PKCE — PWAF primary interactive auth + M2M = "m2m" # OAuth client credentials — PWAF primary service auth + PAT = "pat" # Personal Access Token — additional option only per PWAF + PROFILE = "profile" # Databricks CLI profile (~/.databrickscfg) + UNIFIED = "unified" # databricks-sdk unified auth chain (fallback) + + +@dataclass +class DatabricksCredentials: + """Resolved Databricks credentials ready for use in API calls. + + get_token is always synchronous — threading.Lock used internally for M2M/U2M. + auth_method is a plain string logged for observability (never the token value). + """ + + host: str + get_token: Callable[[], str] + auth_method: str = "unknown" # "u2m" / "m2m" / "pat" / "profile" / "unified" + + +# --------------------------------------------------------------------------- +# M2M token provider — thread-safe via threading.Lock (not asyncio) +# --------------------------------------------------------------------------- + +class M2MTokenProvider: + """Thread-safe OAuth client credentials token provider. + + P0-3: constructor accepts host, client_id, client_secret (was missing in P3 plan). + Uses double-checked locking to ensure exactly one _fetch_new_token() call under + concurrent pressure. 
+ """ + + def __init__(self, host: str, client_id: str, client_secret: str) -> None: + self._host = host + self._client_id = client_id + self._client_secret = client_secret + self._token: str | None = None + self._expires_at: float = 0.0 + self._lock = threading.Lock() # threading, NOT asyncio + + def get_token(self) -> str: + """Return a valid access token, refreshing proactively if nearing expiry.""" + # Fast path: check without acquiring lock + if self._token and time.time() < self._expires_at - 300: + return self._token + with self._lock: + # Double-check inside lock to prevent thundering herd + if self._token and time.time() < self._expires_at - 300: + return self._token + self._token, self._expires_at = self._fetch_new_token() + return self._token + + def _fetch_new_token(self) -> tuple[str, float]: + """Fetch a new M2M access token via client credentials grant.""" + resp = httpx.post( + f"{self._host}/oidc/v1/token", + data={ + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + "scope": "all-apis", # required by Databricks OIDC + }, + headers={"User-Agent": USER_AGENT}, # PWAF: UA on ALL Databricks HTTP + timeout=15.0, + ) + resp.raise_for_status() + data = resp.json() + expires_at = time.time() + data.get("expires_in", 3600) + return data["access_token"], expires_at + + +# --------------------------------------------------------------------------- +# Main credential resolver +# --------------------------------------------------------------------------- + +def resolve_credentials(llm: "DatabricksLLM") -> DatabricksCredentials: + """Resolve auth strategy and return DatabricksCredentials. + + Priority (PWAF: OAuth primary, PAT additional only): + 1. U2M stored_u2m_tokens is set (user completed browser login) + 2. M2M databricks_client_id + databricks_client_secret both set + 3. PAT api_key is set + 4. PROFILE databricks_profile named + 5. 
UNIFIED databricks-sdk auth chain (fallback) + + Host resolution: ``databricks_host`` (workspace) is required for U2M / + M2M / PROFILE / UNIFIED (workspace OAuth) and when + ``databricks_metadata_probe=True``. PAT only needs *some* host for the + FM client to route to — either ``databricks_host`` (canonical) or + ``databricks_ai_gateway_host`` (override). The validator in + ``DatabricksLLM`` enforces that at least one is set. + """ + stored = llm.stored_u2m_tokens + host = ( + llm.databricks_host + or llm.base_url + or (stored.host if stored else None) + ) + if host: + host = normalize_host(host) + + # Path 1: U2M — highest priority (user already browser-logged in) + if stored: + if not host: + raise ValueError("databricks_host is required for U2M auth.") + validate_databricks_config(host, AuthStrategy.U2M, stored_tokens=stored) + # Forward the OAuth client secret (confidential app) if present. + u2m_secret = llm.databricks_u2m_client_secret + u2m_secret_str: str | None = ( + u2m_secret.get_secret_value() if u2m_secret is not None else None + ) + return _resolve_u2m(host, stored, client_secret=u2m_secret_str) + + # Path 2: M2M — service principal client credentials + if llm.databricks_client_id and llm.databricks_client_secret: + if not host: + raise ValueError("databricks_host is required for M2M auth.") + validate_databricks_config( + host, + AuthStrategy.M2M, + client_id=llm.databricks_client_id, + client_secret=llm.databricks_client_secret.get_secret_value(), + ) + return _resolve_m2m(host, llm) + + # Path 3: PAT — token is sent directly to the AI Gateway, no workspace + # host needed. credentials.host is left empty so any accidental use + # surfaces clearly. 
def _resolve_u2m(
    host: str,
    stored: StoredU2MTokens,
    client_secret: str | None = None,
) -> DatabricksCredentials:
    """U2M: return current access token, refreshing silently via refresh_token.

    The returned credentials' ``get_token`` hands back the cached access
    token while it has more than 300 s left; otherwise it performs a
    ``refresh_token`` grant against ``{host}/oidc/v1/token`` under a
    ``threading.Lock`` (checked before and after acquiring the lock) and
    caches the new token and expiry.

    If the token response carries a new ``refresh_token`` (refresh-token
    rotation), that value is used for subsequent refreshes instead of the
    one in ``stored``.  The rotated token is kept in memory only; this
    function does not write it back to ``stored``.

    Args:
        host: Workspace host used to build the token endpoint URL.
        stored: Tokens from the completed browser login; supplies the
            initial access token, its expiry, the refresh token and the
            OAuth client id.
        client_secret: Must be supplied for confidential OAuth apps (apps
            that have a client secret configured in Databricks App
            Connections).  Public PKCE apps leave it ``None``.

    Returns:
        ``DatabricksCredentials`` with ``auth_method="u2m"``.

    Raises:
        AuthenticationError: From ``get_token`` when the refresh request
            returns a non-success status.
    """
    lock = threading.Lock()
    state: dict[str, object] = {
        "token": stored.access_token,
        "expires_at": stored.expires_at,
    }
    # Latest refresh token; replaced when the server rotates it.
    refresh_token = stored.refresh_token

    def get_token() -> str:
        nonlocal refresh_token
        # Fast path — no lock needed
        if time.time() < float(state["expires_at"]) - 300:
            return str(state["token"])
        with lock:
            # Another thread may have refreshed while we waited for the lock.
            if time.time() < float(state["expires_at"]) - 300:
                return str(state["token"])
            remaining = float(state["expires_at"]) - time.time()
            logger.info(
                "databricks_u2m_token_refresh",
                extra={"remaining_s": round(remaining, 1)},
            )
            token_data: dict[str, str] = {
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": stored.client_id,
            }
            if client_secret:
                # Confidential OAuth apps require the client secret on refresh.
                token_data["client_secret"] = client_secret
            resp = httpx.post(
                f"{host}/oidc/v1/token",
                data=token_data,
                headers={"User-Agent": USER_AGENT},  # PWAF: UA on token endpoint
                timeout=15.0,
            )
            if not resp.is_success:
                hint = (
                    "Re-authenticate via browser sign-in."
                    if client_secret
                    else "Re-authenticate at /auth/databricks/initiate."
                )
                raise AuthenticationError(
                    f"U2M token refresh failed [{resp.status_code}]. {hint}",
                    model="",
                    llm_provider="databricks",
                )
            data = resp.json()
            state["token"] = data["access_token"]
            state["expires_at"] = time.time() + data.get("expires_in", 3600)
            # RFC 6749 §6: when a new refresh token is issued the old one
            # must be discarded, otherwise the next refresh would fail.
            rotated = data.get("refresh_token")
            if rotated:
                refresh_token = rotated
            logger.info("databricks_u2m_token_refreshed", extra={"method": "u2m"})
            return str(state["token"])

    logger.info("databricks_auth_resolved", extra={"method": "u2m"})
    return DatabricksCredentials(host=host, get_token=get_token, auth_method="u2m")
+ """ + client_holder: dict[str, object] = {} + lock = threading.Lock() + + def get_token() -> str: + try: + from databricks.sdk import WorkspaceClient as _WC + except ImportError: + raise ImportError( + "PROFILE auth requires the 'databricks-sdk' package.\n" + "Install it: pip install databricks-sdk\n" + f"Then verify: databricks auth profiles " + f"(profile '{profile}' must appear)" + ) from None + + client = client_holder.get("client") + if client is None: + with lock: + client = client_holder.get("client") + if client is None: + client = _WC(host=host, profile=profile) + client_holder["client"] = client + auth_header = client.config.authenticate()["Authorization"] # type: ignore[attr-defined] + return auth_header.split(" ", 1)[1] if " " in auth_header else auth_header + + logger.info("databricks_auth_resolved", extra={"method": "profile", "profile": profile}) + return DatabricksCredentials(host=host, get_token=get_token, auth_method="profile") + + +def _resolve_sdk_auth(host: str) -> DatabricksCredentials: + """UNIFIED: databricks-sdk auth chain (workload identity, Azure AD, ~/.databrickscfg). + + The import check is deferred into ``get_token()`` so that saving settings + succeeds even if ``databricks-sdk`` is not yet installed; the clear error + with install + pre-login instructions surfaces only when the agent first + makes an API call. + + Pre-requisites (outside the agent): + 1. pip install databricks-sdk + 2. databricks auth login --host + """ + client_holder: dict[str, object] = {} + lock = threading.Lock() + + def get_token() -> str: + try: + from databricks.sdk import WorkspaceClient as _WC + except ImportError: + raise ImportError( + "Browser-SSO / unified auth requires the 'databricks-sdk' package.\n" + " Step 1 — install: pip install databricks-sdk\n" + f" Step 2 — login: databricks auth login --host {host}\n" + "Re-run the agent after completing both steps." 
+ ) from None + + client = client_holder.get("client") + if client is None: + with lock: + client = client_holder.get("client") + if client is None: + client = _WC(host=host) + client_holder["client"] = client + auth_header = client.config.authenticate()["Authorization"] # type: ignore[attr-defined] + return auth_header.split(" ", 1)[1] if " " in auth_header else auth_header + + logger.info("databricks_auth_resolved", extra={"method": "unified"}) + return DatabricksCredentials(host=host, get_token=get_token, auth_method="unified") diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/client.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/client.py new file mode 100644 index 0000000000..56e0deeab9 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/client.py @@ -0,0 +1,404 @@ +"""Databricks AI Gateway synchronous HTTP client. + +Two distinct Databricks surfaces, kept strictly separate: + +* **AI Gateway** (FM invocations) → ``ai_gateway_host`` (optional override). + Defaults to ``credentials.host``. All ``POST /ai-gateway/...`` traffic + goes here, exclusively. +* **Workspace** (auth, discovery, metadata probes) → ``credentials.host``. + Used by auth/token resolution and the opt-in metadata probe; never used + for FM invocations. + +Path templates by provider family (relative to the AI Gateway base URL, +which :meth:`AIGatewayPaths.normalize_base` produces from the configured +host — see ``models.py``): + +* :attr:`ProviderFamily.OPENAI` → ``POST {base}/mlflow/v1/chat/completions`` +* :attr:`ProviderFamily.OPENAI_RESPONSES` → ``POST {base}/openai/v1/responses`` +* :attr:`ProviderFamily.ANTHROPIC` → ``POST {base}/anthropic/v1/messages`` +* :attr:`ProviderFamily.GEMINI` → ``POST {base}/gemini/v1beta/models/{endpoint}:generateContent`` + +Family routing: + +* Default — ``detect_family(model)`` resolves the family from the model + name; no network call. 
def _bare_endpoint(model: str) -> str:
    """Return the serving-endpoint name for ``model``.

    Strips surrounding whitespace and one leading ``databricks/`` provider
    prefix. A ``databricks-`` at the start of the remaining name is part of
    the endpoint name and is kept.
    """
    return model.strip().removeprefix("databricks/")


class DatabricksFMAPIClient:
    """Synchronous Foundation Model client for Databricks AI Gateway.

    Public API is unchanged from the previous single-route client —
    ``chat_completion(model=..., messages=..., stream=..., tools=..., **kwargs)`` —
    but internally the request is routed to the correct native surface by
    provider family.

    Thread-safety: the singleton ``self._http`` is used for all non-streaming
    calls (including metadata lookups). Streaming opens a fresh
    context-managed client per request. Metadata probes are serialized by
    ``self._probe_lock`` so the probe outcome flag is read by the thread
    that produced it.
    """

    # Outcome of the most recent ``_probe_metadata`` call. Written by
    # ``_probe_metadata``; read by ``resolve_family`` under ``_probe_lock``.
    _probe_succeeded: bool = False

    def __init__(
        self,
        credentials: DatabricksCredentials,
        timeouts: DatabricksTimeouts,
        ai_gateway_host: str | None = None,
        max_retries: int = 3,
        ssl_verify: bool = True,
        paths: AIGatewayPaths | None = None,
        metadata_probe: bool = False,
    ) -> None:
        """Create the client and its pooled ``httpx.Client``.

        Raises:
            ValueError: neither ``ai_gateway_host`` nor ``credentials.host``
                is set.
        """
        # AI Gateway host is an optional override; for the common
        # single-URL Databricks deployment the workspace host doubles as
        # the gateway base.
        gateway_host = (ai_gateway_host or credentials.host or "").rstrip("/")
        if not gateway_host:
            raise ValueError(
                "Either ai_gateway_host or credentials.host must be provided; "
                "the FM client needs at least one URL to route invocations to."
            )
        self._credentials = credentials
        self._timeouts = timeouts
        self._max_retries = max_retries
        self._ssl_verify = ssl_verify
        self._paths = paths or AIGatewayPaths()
        self._ai_gateway_host = gateway_host
        self._metadata_probe = metadata_probe
        self._http = httpx.Client(
            verify=ssl_verify,
            timeout=httpx.Timeout(
                connect=timeouts.connect_s,
                read=timeouts.read_s,
                write=10.0,
                pool=timeouts.pool_s,
            ),
        )
        self._metadata_cache: dict[str, tuple[ProviderFamily, float]] = {}
        self._metadata_lock = threading.Lock()
        # Held for the duration of one probe + cache write (see resolve_family).
        self._probe_lock = threading.Lock()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def __del__(self) -> None:  # pragma: no cover — best-effort cleanup
        # ``_http`` is absent when ``__init__`` raised before creating it.
        http = getattr(self, "_http", None)
        if http is None:
            return
        try:
            http.close()
        except Exception:
            # Deliberately swallowed: finalizers must never raise.
            pass

    def close(self) -> None:
        """Explicitly close the singleton HTTP client."""
        self._http.close()

    # ------------------------------------------------------------------
    # Headers / auth
    # ------------------------------------------------------------------

    def _make_headers(self, family: ProviderFamily) -> dict[str, str]:
        """Build request headers for ``family``.

        A fresh token is requested from the credentials on every call; the
        ``anthropic-version`` header is added for the Anthropic family only.
        """
        h = {
            "Authorization": f"Bearer {self._credentials.get_token()}",
            "Content-Type": "application/json",
            "User-Agent": USER_AGENT,
        }
        if family is ProviderFamily.ANTHROPIC:
            h["anthropic-version"] = "2023-06-01"
        return h

    # ------------------------------------------------------------------
    # Routing
    # ------------------------------------------------------------------

    def _cached_family(self, endpoint: str) -> ProviderFamily | None:
        """Return the unexpired cached family for ``endpoint``, if any."""
        with self._metadata_lock:
            hit = self._metadata_cache.get(endpoint)
        if hit is not None and hit[1] > time.time():
            return hit[0]
        return None

    def resolve_family(self, model: str) -> ProviderFamily:
        """Resolve the AI Gateway path family for ``model``.

        Default: name-pattern only via ``detect_family(model)``.

        Opt-in (``metadata_probe=True``): metadata-first with name-pattern
        fallback; issues ``GET /api/2.0/serving-endpoints/{name}`` against
        the workspace. Results are cached per endpoint for
        ``_METADATA_TTL_S`` seconds after a successful probe and
        ``_METADATA_NEGATIVE_TTL_S`` seconds after a failed one.
        """
        if not self._metadata_probe:
            return detect_family(model)

        endpoint = _bare_endpoint(model)
        cached = self._cached_family(endpoint)
        if cached is not None:
            return cached

        # Probe, outcome flag and cache write happen under one lock so the
        # flag read here always belongs to the probe this thread just ran.
        with self._probe_lock:
            cached = self._cached_family(endpoint)  # filled while we waited?
            if cached is not None:
                return cached
            probed = self._probe_metadata(endpoint)
            probe_ok = self._probe_succeeded
            family = probed or detect_family(model)
            ttl = _METADATA_TTL_S if probe_ok else _METADATA_NEGATIVE_TTL_S
            with self._metadata_lock:
                self._metadata_cache[endpoint] = (family, time.time() + ttl)
        return family

    def _probe_metadata(self, endpoint: str) -> ProviderFamily | None:
        """GET /api/2.0/serving-endpoints/{endpoint} → family (or None).

        Sets ``self._probe_succeeded`` to ``True`` only when a 200 response
        with a JSON object body was received; transport errors, non-200
        statuses and unusable bodies leave it ``False`` and return ``None``.

        Raises:
            ValueError: ``credentials.host`` is not set.
        """
        from urllib.parse import quote

        host = self._credentials.host
        if not host:
            raise ValueError(
                "databricks_host is required when "
                "databricks_metadata_probe=True; the metadata probe targets "
                "the workspace control plane."
            )
        # Pessimistic default: only the success path below flips it to True.
        self._probe_succeeded = False
        # Escape the name so characters such as "/" or "?" cannot alter the path.
        url = (
            f"{host.rstrip('/')}/api/2.0/serving-endpoints/"
            f"{quote(endpoint, safe='')}"
        )
        try:
            resp = self._http.get(
                url,
                headers={
                    "Authorization": f"Bearer {self._credentials.get_token()}",
                    "User-Agent": USER_AGENT,
                },
                timeout=10.0,
            )
        except httpx.HTTPError as exc:
            logger.debug(
                "databricks_metadata_probe_failed",
                extra={"endpoint": endpoint, "error": str(exc)},
            )
            return None
        if resp.status_code != 200:
            logger.debug(
                "databricks_metadata_probe_nonok",
                extra={"endpoint": endpoint, "status": resp.status_code},
            )
            return None
        try:
            meta = resp.json()
        except ValueError:
            return None
        if not isinstance(meta, dict):
            return None

        config = meta.get("config")
        entities = config.get("served_entities") if isinstance(config, dict) else None
        first = entities[0] if isinstance(entities, list) and entities else None
        if not isinstance(first, dict):
            first = {}
        fm = first.get("foundation_model")
        em = first.get("external_model")
        family = pick_family_from_api_types(
            api_types=fm.get("api_types") if isinstance(fm, dict) else None,
            external_provider=em.get("provider") if isinstance(em, dict) else None,
        )
        self._probe_succeeded = True
        return family

    def invalidate_metadata(self, model: str) -> None:
        """Drop the cached family for an endpoint (e.g. on 404/410)."""
        endpoint = _bare_endpoint(model)
        with self._metadata_lock:
            self._metadata_cache.pop(endpoint, None)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def chat_completion(
        self,
        model: str,
        messages: list[dict],
        stream: bool = False,
        tools: list[dict] | None = None,
        on_token: TokenCallbackType | None = None,
        **kwargs: Any,
    ) -> ModelResponse:
        """Dispatch a chat-completion call across AI Gateway surfaces.

        Streaming requests always use the ``ProviderFamily.OPENAI`` path;
        non-streaming requests use the family from ``resolve_family`` and go
        through ``fetch_with_retry`` on the pooled client.
        """
        endpoint = _bare_endpoint(model)

        if stream:
            family = ProviderFamily.OPENAI
            url = self._paths.url(self._ai_gateway_host, family, endpoint)
            payload = to_native(
                family, endpoint, messages,
                tools=tools, stream=True, **kwargs,
            )
            return self._handle_stream(
                url, self._make_headers(family), payload, endpoint, on_token
            )

        family = self.resolve_family(model)
        url = self._paths.url(self._ai_gateway_host, family, endpoint)
        payload = to_native(family, endpoint, messages, tools=tools, **kwargs)
        headers = self._make_headers(family)

        response = fetch_with_retry(
            client=self._http,
            url=url,
            headers=headers,
            json=payload,
            max_retries=self._max_retries,
        )
        logger.debug(
            "databricks_ai_gateway_response",
            extra={
                "status": response.status_code,
                "request_id": response.headers.get("x-request-id"),
                "endpoint": endpoint,
                "family": family.value,
                "auth_method": self._credentials.auth_method,
                # Intentionally NOT logging: Authorization header, token values,
                # request/response bodies (may contain user prompts).
            },
        )
        return self._parse_response(response, endpoint, family)

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    def _parse_response(
        self,
        response: httpx.Response,
        model: str,
        family: ProviderFamily,
    ) -> ModelResponse:
        """Convert a native response to ``litellm.ModelResponse``.

        Flow: native JSON → ``from_native`` → OpenAI ChatCompletion dict →
        ``ModelResponse``. If the body is not JSON an empty-choices response
        is returned; if conversion fails a minimal response is built from
        whatever top-level ``id`` / ``choices`` / ``model`` / ``usage`` keys
        the body has.
        """
        try:
            data = response.json()
        except ValueError:
            logger.warning(
                "databricks_parse_non_json",
                extra={"family": family.value, "endpoint": model},
            )
            return ModelResponse(id="databricks-response", choices=[], model=model)

        try:
            chat = from_native(family, model, data)
            return ModelResponse(**chat)
        except Exception:
            # Best-effort fallback; keep the traceback so the shape mismatch
            # can be diagnosed.
            logger.warning(
                "databricks_parse_fallback",
                extra={"family": family.value, "endpoint": model},
                exc_info=True,
            )
            if not isinstance(data, dict):
                # e.g. a JSON array or string — nothing to salvage.
                data = {}
            return ModelResponse(
                id=data.get("id", "databricks-response"),
                choices=data.get("choices", []),
                model=data.get("model", model),
                usage=data.get("usage"),
            )

    # ------------------------------------------------------------------
    # Streaming (OpenAI Chat only for V1)
    # ------------------------------------------------------------------

    @staticmethod
    def _accumulate_sse(
        lines: Any, on_token: TokenCallbackType | None
    ) -> dict[str, Any]:
        """Fold OpenAI-Chat SSE lines into one aggregated result.

        Args:
            lines: Iterable of decoded SSE lines.
            on_token: Optional callback invoked with each text delta.

        Returns:
            Dict with ``content`` (joined text deltas), ``tool_calls``
            (list ordered by stream index, argument fragments concatenated),
            ``finish_reason`` (last non-empty value seen, or ``None``),
            ``response_id``, ``usage`` (last non-empty value seen, or
            ``None``) and ``chunks`` (number of JSON chunks parsed).
            Lines that are not ``data:`` events, the ``[DONE]`` sentinel and
            unparsable payloads are skipped.
        """
        content_parts: list[str] = []
        tool_calls: dict[int, dict[str, Any]] = {}
        finish_reason: str | None = None
        response_id = ""
        usage: Any = None
        chunks = 0

        for line in lines:
            if not line.startswith("data:"):
                continue
            # The space after "data:" is optional in SSE.
            payload = line[5:].strip()
            if not payload or payload == "[DONE]":
                continue
            try:
                chunk = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if not isinstance(chunk, dict):
                continue
            chunks += 1
            response_id = chunk.get("id") or response_id
            usage = chunk.get("usage") or usage
            choices = chunk.get("choices") or []
            if not choices or not isinstance(choices[0], dict):
                continue
            choice = choices[0]
            finish_reason = choice.get("finish_reason") or finish_reason
            delta = choice.get("delta") or {}

            token = delta.get("content")
            if token and isinstance(token, str):
                content_parts.append(token)
                if on_token is not None:
                    on_token(token)

            for call in delta.get("tool_calls") or []:
                if not isinstance(call, dict):
                    continue
                slot = tool_calls.setdefault(
                    call.get("index") or 0,
                    {
                        "id": "",
                        "type": "function",
                        "function": {"name": "", "arguments": ""},
                    },
                )
                if call.get("id"):
                    slot["id"] = call["id"]
                if call.get("type"):
                    slot["type"] = call["type"]
                fn = call.get("function") or {}
                if fn.get("name"):
                    slot["function"]["name"] = fn["name"]
                if fn.get("arguments"):
                    # Arguments arrive as JSON text fragments.
                    slot["function"]["arguments"] += fn["arguments"]

        return {
            "content": "".join(content_parts),
            "tool_calls": [tool_calls[i] for i in sorted(tool_calls)],
            "finish_reason": finish_reason,
            "response_id": response_id,
            "usage": usage,
            "chunks": chunks,
        }

    def _handle_stream(
        self,
        url: str,
        headers: dict[str, str],
        payload: dict,
        model: str,
        on_token: TokenCallbackType | None,
    ) -> ModelResponse:
        """POST ``payload`` to ``url`` and aggregate the SSE stream.

        Uses a fresh context-managed ``httpx.Client`` whose read timeout is
        ``timeouts.chunk_s``. On a status >= 400 the body is read and
        ``_raise_non_retryable`` is called. Text deltas are forwarded to
        ``on_token``; tool-call deltas, the finish reason and usage are
        carried into the returned response.
        """
        with httpx.Client(
            verify=self._ssl_verify,
            timeout=httpx.Timeout(
                connect=self._timeouts.connect_s,
                read=self._timeouts.chunk_s,
                write=10.0,
                pool=self._timeouts.pool_s,
            ),
        ) as stream_client:
            with stream_client.stream(
                "POST", url, headers=headers, json=payload
            ) as resp:
                if resp.status_code >= 400:
                    resp.read()
                    _raise_non_retryable(resp)
                result = self._accumulate_sse(resp.iter_lines(), on_token)

        logger.debug(
            "databricks_stream_complete",
            extra={
                "chunks": result["chunks"],
                "endpoint": model,
                "auth_method": self._credentials.auth_method,
            },
        )
        return self._build_stream_response(
            result["content"],
            result["response_id"],
            model,
            tool_calls=result["tool_calls"],
            finish_reason=result["finish_reason"],
            usage=result["usage"],
        )

    def _build_stream_response(
        self,
        content: str,
        response_id: str,
        model: str,
        *,
        tool_calls: list[dict] | None = None,
        finish_reason: str | None = None,
        usage: Any = None,
    ) -> ModelResponse:
        """Build a ``ModelResponse`` from accumulated stream content.

        With only the three positional arguments the result is the same as
        before: a single assistant message with ``finish_reason="stop"``.
        ``tool_calls`` are attached to the message when non-empty;
        ``finish_reason`` defaults to ``"tool_calls"`` when tool calls are
        present and ``"stop"`` otherwise; ``usage`` is passed through when
        given.
        """
        message: dict[str, Any] = {"role": "assistant", "content": content}
        if tool_calls:
            message["tool_calls"] = tool_calls
        fields: dict[str, Any] = {
            "id": response_id or "databricks-stream",
            "choices": [
                {
                    "index": 0,
                    "message": message,
                    "finish_reason": finish_reason
                    or ("tool_calls" if tool_calls else "stop"),
                }
            ],
            "model": model,
            "object": "chat.completion",
        }
        if usage:
            fields["usage"] = usage
        return ModelResponse(**fields)
@dataclass(frozen=True)
class DiscoveredEndpoint:
    """Structured view of a serving endpoint from the list call.

    Only fields the list response reliably returns are captured here.
    Authoritative routing metadata (``foundation_model.api_types``,
    ``external_model.provider``) only comes from the per-endpoint describe
    call and is resolved lazily by ``DatabricksFMAPIClient._probe_metadata``.
    """

    name: str  # e.g. "databricks-claude-sonnet-4-5"
    qualified_name: str  # e.g. "databricks/databricks-claude-sonnet-4-5"
    endpoint_type: str | None  # FOUNDATION_MODEL_API | EXTERNAL_MODEL | None
    task: str  # "llm/v1/chat"
    ready: bool  # True when the endpoint's state.ready was "READY"
    creator: str | None = None


def list_chat_endpoints(
    credentials: DatabricksCredentials,
    *,
    include_not_ready: bool = False,
    allowed_endpoint_types: frozenset[str] = _GATEWAY_ENDPOINT_TYPES,
) -> list[DiscoveredEndpoint]:
    """Return all AI-Gateway chat endpoints visible in this workspace.

    Issues ``GET {credentials.host}/api/2.0/serving-endpoints`` (30 s
    timeout) and keeps entries that satisfy all of:

    * ``task == "llm/v1/chat"``
    * ``endpoint_type`` in ``allowed_endpoint_types`` (default:
      FOUNDATION_MODEL_API + EXTERNAL_MODEL)
    * ``state.ready == "READY"`` unless ``include_not_ready`` is True
    * a non-empty ``name``

    Entries that are not JSON objects, and ``null`` values for ``endpoints``
    or ``state``, are tolerated rather than raising.

    Raises:
        httpx.HTTPStatusError: the list call returned an error status
            (via ``raise_for_status``).

    PWAF: ``User-Agent`` header required on ALL Databricks HTTP calls.
    """
    # Tolerate a trailing slash on the configured host (avoids "//api/...").
    base = (credentials.host or "").rstrip("/")
    resp = httpx.get(
        f"{base}/api/2.0/serving-endpoints",
        headers={
            "Authorization": f"Bearer {credentials.get_token()}",
            "User-Agent": USER_AGENT,
        },
        timeout=30.0,
    )
    resp.raise_for_status()
    payload = resp.json()
    # ``"endpoints": null`` and non-object bodies both mean "nothing listed".
    endpoints = (payload.get("endpoints") if isinstance(payload, dict) else None) or []

    out: list[DiscoveredEndpoint] = []
    for ep in endpoints:
        if not isinstance(ep, dict):
            continue
        task = ep.get("task")
        if task != "llm/v1/chat":
            continue
        et = ep.get("endpoint_type")
        if et not in allowed_endpoint_types:
            continue
        # ``state`` may be present but null; ``.get("state", {})`` would
        # return None in that case and crash on the chained ``.get``.
        state = ep.get("state")
        ready = isinstance(state, dict) and state.get("ready") == "READY"
        if not ready and not include_not_ready:
            continue
        name = ep.get("name")
        if not name:
            continue
        out.append(
            DiscoveredEndpoint(
                name=name,
                qualified_name=f"databricks/{name}",
                endpoint_type=et,
                task=task,
                ready=ready,
                creator=ep.get("creator"),
            )
        )
    return out


def list_foundation_models(credentials: DatabricksCredentials) -> list[str]:
    """Return qualified names of gateway chat endpoints (back-compat API).

    Historical name — kept for back-compat. Delegates to
    ``list_chat_endpoints`` with its defaults and returns only the
    ``qualified_name`` strings (``databricks/<endpoint-name>``), i.e. both
    ``FOUNDATION_MODEL_API`` and ``EXTERNAL_MODEL`` endpoints, READY only.
    """
    return [e.qualified_name for e in list_chat_endpoints(credentials)]
@dataclass(frozen=True)
class ModelPickerEntry:
    """One row of the model picker, merged across curated + discovered tiers.

    The dataclass is frozen; the first four fields are required, the rest
    default to "no endpoint type, ready, not recommended".
    """

    # "databricks/<endpoint-name>" — use as the model id for create_llm.
    qualified_name: str
    # Bare endpoint name (no provider prefix).
    name: str
    # Predicted provider family (OPENAI / ANTHROPIC / ...).
    family: ProviderFamily
    # "curated" | "discovered" | "curated+discovered".
    source: str
    # FOUNDATION_MODEL_API | EXTERNAL_MODEL | None.
    endpoint_type: str | None = None
    # True if discovered and READY; curated entries are optimistically True.
    ready: bool = True
    # True for the curated "one per family" default picks.
    recommended: bool = False


def _curated_entry(
    name: str, family: ProviderFamily, *, recommended: bool = False
) -> ModelPickerEntry:
    """Build the curated-tier picker row for endpoint ``name``.

    Curated rows are always marked ``source="curated"``,
    ``endpoint_type="FOUNDATION_MODEL_API"`` and ``ready=True``.
    """
    qualified = f"databricks/{name}"
    entry = ModelPickerEntry(
        qualified,
        name,
        family,
        "curated",
        endpoint_type="FOUNDATION_MODEL_API",
        ready=True,
        recommended=recommended,
    )
    return entry
CURATED_DATABRICKS_MODELS: tuple[ModelPickerEntry, ...] = (
    # ------------------------------------------------------------------ #
    # Anthropic — Claude (native Anthropic Messages API)
    # All live-tested PASS except opus-4-7 (temporarily rate-limited).
    # ------------------------------------------------------------------ #
    _curated_entry(
        "databricks-claude-sonnet-4-6", ProviderFamily.ANTHROPIC, recommended=True
    ),
    _curated_entry("databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC),
    _curated_entry("databricks-claude-haiku-4-5", ProviderFamily.ANTHROPIC),
    _curated_entry("databricks-claude-opus-4-7", ProviderFamily.ANTHROPIC),
    _curated_entry("databricks-claude-opus-4-6", ProviderFamily.ANTHROPIC),
    _curated_entry("databricks-claude-opus-4-5", ProviderFamily.ANTHROPIC),
    _curated_entry("databricks-claude-opus-4-1", ProviderFamily.ANTHROPIC),
    # ------------------------------------------------------------------ #
    # OpenAI — GPT-5 series (Responses API) and gpt-oss (OpenAI Chat)
    # All live-tested PASS. gpt-5-5 / gpt-5-5-pro may be temporarily
    # rate-limited (403) on some workspaces.
    # ------------------------------------------------------------------ #
    _curated_entry(
        "databricks-gpt-5-mini", ProviderFamily.OPENAI_RESPONSES, recommended=True
    ),
    _curated_entry("databricks-gpt-5-5-pro", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-5", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-4", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-4-mini", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-4-nano", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-3-codex", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-2-codex", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-2", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-1", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5-nano", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-5", ProviderFamily.OPENAI_RESPONSES),
    _curated_entry("databricks-gpt-oss-120b", ProviderFamily.OPENAI),
    # ------------------------------------------------------------------ #
    # Google — Gemini (native generateContent)
    # ------------------------------------------------------------------ #
    # Live-tested PASS: gemini-3-5-flash, gemini-3-1-flash-lite,
    # gemini-2-5-flash, gemini-2-5-pro
    # gemini-3-flash / gemini-3-pro: NOT available in typical workspaces —
    # they require cross-geo routing on global endpoints. They surface via
    # live workspace discovery when the endpoint is actually available.
    # gemma-3-12b: excluded — 8,192-token context window is below the 16k
    # minimum required by OpenHands.
    _curated_entry(
        "databricks-gemini-3-5-flash", ProviderFamily.GEMINI, recommended=True
    ),
    _curated_entry("databricks-gemini-3-1-flash-lite", ProviderFamily.GEMINI),
    _curated_entry("databricks-gemini-2-5-flash", ProviderFamily.GEMINI),
    _curated_entry("databricks-gemini-2-5-pro", ProviderFamily.GEMINI),
)


def get_picker_entries(
    credentials: DatabricksCredentials | None = None,
    *,
    include_curated: bool = True,
    include_discovered: bool = True,
    include_not_ready: bool = False,
) -> list[ModelPickerEntry]:
    """Merged view of curated + discovered Databricks models for picker UIs.

    Dedup rule: if a qualified name is in both tiers, the curated entry wins
    on ``recommended`` / ``family`` but picks up the live ``endpoint_type``
    and ``ready`` fields from discovery, and its ``source`` becomes
    ``"curated+discovered"``. Order: recommended-first, then by family
    value, then by name.

    Network: calls ``list_chat_endpoints`` **only** if ``credentials`` is
    provided and ``include_discovered`` is True. Without credentials this
    only reads the static curated set.

    Errors during discovery are logged and swallowed — the curated tier is
    still returned even if the workspace is unreachable.
    """
    merged: dict[str, ModelPickerEntry] = {}

    if include_curated:
        for e in CURATED_DATABRICKS_MODELS:
            merged[e.qualified_name] = e

    if include_discovered and credentials is not None:
        try:
            discovered = list_chat_endpoints(
                credentials, include_not_ready=include_not_ready
            )
        except Exception as exc:
            # Deliberate best-effort: the picker must stay usable offline.
            logger.warning(
                "databricks_discovery_failed_in_picker", extra={"error": str(exc)}
            )
            discovered = []

        for d in discovered:
            existing = merged.get(d.qualified_name)
            if existing is not None:
                # Curated entry already present — upgrade with live signals,
                # keep the curated family + recommended flag.
                merged[d.qualified_name] = ModelPickerEntry(
                    qualified_name=existing.qualified_name,
                    name=existing.name,
                    family=existing.family,
                    source="curated+discovered",
                    endpoint_type=d.endpoint_type,
                    ready=d.ready,
                    recommended=existing.recommended,
                )
            else:
                merged[d.qualified_name] = ModelPickerEntry(
                    qualified_name=d.qualified_name,
                    name=d.name,
                    family=detect_family(d.name),
                    source="discovered",
                    endpoint_type=d.endpoint_type,
                    ready=d.ready,
                    recommended=False,
                )

    return sorted(
        merged.values(),
        key=lambda e: (not e.recommended, e.family.value, e.name),
    )


def list_models_from_env() -> list[str]:
    """Convenience wrapper that reads env vars and returns TTL-cached model list.

    Reads:
        DATABRICKS_HOST — required
        DATABRICKS_TOKEN or DATABRICKS_ACCESS_TOKEN — required

    Returns [] silently if env vars are not set or on any API error (errors
    are logged and not cached). Successful results are cached for
    ``_CACHE_TTL_S`` seconds. Cache writes are protected by ``_CACHE_LOCK``;
    reads use a lock-free fast path (last-writer-wins on miss).

    The returned list is always a fresh copy, so callers may mutate it
    without affecting the cache.
    """
    global _CACHED_MODELS, _CACHE_EXPIRES_AT

    if time.time() < _CACHE_EXPIRES_AT:
        # Copy: handing out the cached list itself would let one caller's
        # mutation leak into every later caller's result.
        return list(_CACHED_MODELS)

    host = os.environ.get("DATABRICKS_HOST", "").rstrip("/")
    token = os.environ.get("DATABRICKS_TOKEN") or os.environ.get(
        "DATABRICKS_ACCESS_TOKEN"
    )
    if not host or not token:
        return []

    credentials = DatabricksCredentials(
        host=normalize_host(host),
        get_token=lambda: token,  # type: ignore[return-value]
        auth_method="env",
    )
    try:
        models = list_foundation_models(credentials)
    except Exception as exc:
        # Deliberate best-effort: discovery failures must not break callers.
        logger.warning("databricks_discovery_failed", extra={"error": str(exc)})
        return []

    with _CACHE_LOCK:
        # Store a private copy; ``models`` itself is returned to the caller.
        _CACHED_MODELS = list(models)
        _CACHE_EXPIRES_AT = time.time() + _CACHE_TTL_S
    return models
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import PrivateAttr, SecretStr, field_serializer, field_validator, model_validator + +from openhands.sdk.llm.llm import LLM +from openhands.sdk.llm.providers.databricks.auth import ( + DatabricksCredentials, + resolve_credentials, +) +from openhands.sdk.llm.providers.databricks.client import DatabricksFMAPIClient +from openhands.sdk.llm.providers.databricks.models import ( + ProviderFamily, + StoredU2MTokens, + detect_family, +) +from openhands.sdk.llm.providers.databricks.utils import DatabricksTimeouts + +if TYPE_CHECKING: + from litellm.types.utils import ModelResponse + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Model capability tables +# --------------------------------------------------------------------------- + +# Context windows (input tokens) for known Databricks FMAPI models. +# Unknown models fall back to 128K. +# +# Values reflect what the AI Gateway endpoint accepts, not the raw upstream model +# limits — e.g. Anthropic Claude endpoints are gateway-capped at 200K even if +# the upstream contract supports more. Keep this table conservative. 
+DATABRICKS_CONTEXT_WINDOWS: dict[str, int] = { + # --- Databricks-hosted FM (OpenAI Chat / mlflow family) --- + "databricks/databricks-meta-llama-3-1-8b-instruct": 128_000, + "databricks/databricks-meta-llama-3-1-405b-instruct": 128_000, + "databricks/databricks-meta-llama-3-3-70b-instruct": 128_000, + "databricks/databricks-llama-4-maverick": 128_000, + "databricks/databricks-gpt-oss-20b": 128_000, + "databricks/databricks-gpt-oss-120b": 128_000, + "databricks/databricks-qwen35-122b-a10b": 128_000, + "databricks/databricks-qwen3-next-80b-a3b-instruct": 128_000, + "databricks/databricks-gemma-3-12b": 8_192, + # --- Anthropic native (Claude 4 series on gateway) --- + "databricks/databricks-claude-sonnet-4": 200_000, + "databricks/databricks-claude-sonnet-4-5": 200_000, + "databricks/databricks-claude-sonnet-4-6": 200_000, + "databricks/databricks-claude-haiku-4-5": 200_000, + "databricks/databricks-claude-opus-4-1": 200_000, + "databricks/databricks-claude-opus-4-5": 200_000, + "databricks/databricks-claude-opus-4-6": 200_000, + "databricks/databricks-claude-opus-4-7": 200_000, + # --- Google Gemini native --- + "databricks/databricks-gemini-2-5-flash": 1_048_576, + "databricks/databricks-gemini-2-5-pro": 1_048_576, + "databricks/databricks-gemini-3-1-flash-lite": 1_048_576, + "databricks/databricks-gemini-3-flash": 1_048_576, + "databricks/databricks-gemini-3-5-flash": 1_048_576, + "databricks/databricks-gemini-3-pro": 1_048_576, + "databricks/databricks-gemini-3-1-pro": 1_048_576, + # --- OpenAI Responses (GPT-5 series) --- + "databricks/databricks-gpt-5": 400_000, + "databricks/databricks-gpt-5-mini": 400_000, + "databricks/databricks-gpt-5-nano": 400_000, + "databricks/databricks-gpt-5-1": 400_000, + "databricks/databricks-gpt-5-1-codex-max": 400_000, + "databricks/databricks-gpt-5-1-codex-mini": 400_000, + "databricks/databricks-gpt-5-2": 400_000, + "databricks/databricks-gpt-5-2-codex": 400_000, + "databricks/databricks-gpt-5-3-codex": 400_000, + 
"databricks/databricks-gpt-5-4": 400_000, + "databricks/databricks-gpt-5-4-mini": 400_000, + "databricks/databricks-gpt-5-4-nano": 400_000, + "databricks/databricks-gpt-5-5": 400_000, + "databricks/databricks-gpt-5-5-pro": 400_000, +} + +# Maximum output tokens for known Databricks FMAPI models. +# Unknown models fall back to 16K. +# +# For reasoning-capable endpoints (gpt-5 series, gemini 2.5, gpt-oss), output +# tokens include internal thinking tokens — the budget must be generous enough +# that visible text actually fits. See ``databricks-ai-gateway-fm-apis`` skill. +DATABRICKS_MAX_OUTPUT: dict[str, int] = { + # --- OpenAI Chat / mlflow family --- + "databricks/databricks-meta-llama-3-1-8b-instruct": 4_096, + "databricks/databricks-meta-llama-3-1-405b-instruct": 4_096, + "databricks/databricks-meta-llama-3-3-70b-instruct": 4_096, + "databricks/databricks-llama-4-maverick": 8_192, + "databricks/databricks-gpt-oss-20b": 16_384, # reasoning + "databricks/databricks-gpt-oss-120b": 16_384, + "databricks/databricks-qwen35-122b-a10b": 8_192, + "databricks/databricks-qwen3-next-80b-a3b-instruct": 8_192, + "databricks/databricks-gemma-3-12b": 4_096, + # --- Anthropic (Claude 4 series) --- + "databricks/databricks-claude-sonnet-4": 8_192, + "databricks/databricks-claude-sonnet-4-5": 64_000, + "databricks/databricks-claude-sonnet-4-6": 64_000, + "databricks/databricks-claude-haiku-4-5": 8_192, + "databricks/databricks-claude-opus-4-1": 32_000, + "databricks/databricks-claude-opus-4-5": 32_000, + "databricks/databricks-claude-opus-4-6": 32_000, + "databricks/databricks-claude-opus-4-7": 32_000, + # --- Gemini (budget includes thinking tokens) --- + "databricks/databricks-gemini-2-5-flash": 65_536, + "databricks/databricks-gemini-2-5-pro": 65_536, + "databricks/databricks-gemini-3-1-flash-lite": 65_536, + "databricks/databricks-gemini-3-flash": 65_536, + "databricks/databricks-gemini-3-5-flash": 65_536, + "databricks/databricks-gemini-3-pro": 65_536, + 
"databricks/databricks-gemini-3-1-pro": 65_536, + # --- OpenAI Responses (GPT-5) — generous so reasoning tokens fit --- + "databricks/databricks-gpt-5": 16_384, + "databricks/databricks-gpt-5-mini": 16_384, + "databricks/databricks-gpt-5-nano": 16_384, + "databricks/databricks-gpt-5-1": 16_384, + "databricks/databricks-gpt-5-1-codex-max": 32_768, + "databricks/databricks-gpt-5-1-codex-mini": 16_384, + "databricks/databricks-gpt-5-2": 16_384, + "databricks/databricks-gpt-5-2-codex": 32_768, + "databricks/databricks-gpt-5-3-codex": 32_768, + "databricks/databricks-gpt-5-4": 16_384, + "databricks/databricks-gpt-5-4-mini": 16_384, + "databricks/databricks-gpt-5-4-nano": 16_384, + "databricks/databricks-gpt-5-5": 32_768, + "databricks/databricks-gpt-5-5-pro": 32_768, +} + + +# --------------------------------------------------------------------------- +# DatabricksLLM +# --------------------------------------------------------------------------- + +class DatabricksLLM(LLM): + """Native Databricks Foundation Model API provider. PWAF-compliant. + + Uses a direct httpx transport to the Databricks AI Gateway instead of + routing HTTP through litellm.completion. Supports OAuth U2M (browser + PKCE), OAuth M2M (client credentials), PAT, CLI profile, and the + databricks-sdk unified auth chain. + """ + + # Pydantic provider discriminator. Serialized by ``SerializeAsAny`` on + # ``AgentBase.llm`` / ``LLMSummarizingCondenser.llm``; read back by the + # ``_dispatch_to_provider_subclass`` wrap-validator on the base ``LLM`` + # class to route the payload to this subclass on load. + provider: Literal["databricks"] = "databricks" + + # --- Databricks-specific fields --- + + databricks_ai_gateway_host: str | None = None + """Optional AI Gateway override — host only, scheme + hostname[:port], no path. + + When set, all FM invocations route through this host instead of the + workspace URL. Use this for deployments with a dedicated gateway, e.g. 
+ ``https://.ai-gateway.cloud.databricks.com``. + + Leave unset for the common single-URL deployment — the SDK then routes + invocations through ``/ai-gateway/``. + + Must start with ``https://`` and contain no path.""" + + databricks_host: str | None = None + """Workspace URL — the canonical Databricks endpoint. + + Used for: + + * FM invocations (default base, becomes ``/ai-gateway/``) + unless ``databricks_ai_gateway_host`` is set. + * Auth / token resolution (OAuth flows mint tokens here). + * Discovery and the opt-in metadata probe + (``GET /api/2.0/serving-endpoints/...``). + + Required for OAuth-based auth (``profile`` / ``m2m`` / ``u2m`` / + unified). For PAT auth it's optional only when + ``databricks_ai_gateway_host`` is also set (the gateway then has its + own URL and the workspace isn't needed).""" + + databricks_metadata_probe: bool = False + """When True, ``resolve_family`` issues + ``GET /api/2.0/serving-endpoints/{name}`` against ``databricks_host`` + to authoritatively determine the AI Gateway path family from the + server-side ``api_types``. Results cached in-process for 5 minutes + per endpoint. Default False (name-pattern resolution only).""" + + databricks_client_id: str | None = None + """M2M service principal application ID (OAuth client_credentials grant). + NOT the same as DATABRICKS_U2M_CLIENT_ID (browser OAuth app). + Set DATABRICKS_CLIENT_SECRET alongside this.""" + + databricks_client_secret: SecretStr | None = None + """M2M service principal OAuth secret. Paired with databricks_client_id.""" + + databricks_profile: str | None = None + """Databricks CLI profile name from ~/.databrickscfg. Requires databricks-sdk.""" + + databricks_ssl_verify: bool = True + """SSL/TLS verification. Set to path string for custom CA bundle.""" + + stored_u2m_tokens: StoredU2MTokens | None = None + """U2M OAuth tokens from browser login flow. Passed from app layer. 
+ Highest-priority auth path (PWAF: OAuth primary).""" + + databricks_u2m_client_id: str | None = None + """Custom OAuth application client ID for the U2M browser PKCE flow. + When set, PKCE uses this client_id instead of the default Databricks CLI + OAuth app. Preserved across sessions so the user only enters it once.""" + + databricks_u2m_client_secret: SecretStr | None = None + """Client secret for confidential U2M OAuth apps (PKCE flow). + Required when the Databricks App Connection is configured as a confidential + app. Leave None for public apps. Persisted so re-authentication only needed + when the secret rotates.""" + + databricks_u2m_redirect_uri: str | None = None + """Redirect URI for the custom U2M OAuth app (PKCE flow). + Defaults to 'http://localhost:8080/callback' when not set.""" + + # --- Resilience knobs --- + + databricks_max_retries: int = 3 + databricks_connect_timeout_s: float = 10.0 + databricks_read_timeout_s: float = 120.0 + databricks_chunk_timeout_s: float = 30.0 + + # --- Private state (not serialized) --- + + _db_credentials: DatabricksCredentials = PrivateAttr() + _db_client: DatabricksFMAPIClient = PrivateAttr() + + # --------------------------------------------------------------------------- + # Validators + # --------------------------------------------------------------------------- + + @field_validator("databricks_host", mode="before") + @classmethod + def _validate_host(cls, v: str | None) -> str | None: + if v is not None and not v.startswith("https://"): + raise ValueError( + f"databricks_host must start with 'https://'. Got: {v!r}" + ) + return v + + @field_validator("databricks_ai_gateway_host", mode="before") + @classmethod + def _validate_ai_gateway_host(cls, v: str | None) -> str | None: + if v is None or v == "": + return None + if not v.startswith("https://"): + raise ValueError( + "databricks_ai_gateway_host must start with 'https://'. 
" + f"Got: {v!r}" + ) + from urllib.parse import urlsplit + parts = urlsplit(v) + # Allow trailing ``/ai-gateway`` so users who copy-paste the full + # gateway base URL aren't penalised; everything else must be host-only. + path = (parts.path or "").rstrip("/") + if path and path != "/ai-gateway": + raise ValueError( + "databricks_ai_gateway_host must be host-only (optionally " + "ending in '/ai-gateway'); the SDK appends the per-family " + f"route itself. Got path={parts.path!r} in {v!r}" + ) + if parts.query or parts.fragment: + raise ValueError( + "databricks_ai_gateway_host must not include a query string " + f"or fragment. Got: {v!r}" + ) + return v.rstrip("/") + + def _serialize_secret_field(self, v: SecretStr | None, info) -> str | None: + """Shared serializer body for all DatabricksLLM SecretStr fields. + + DatabricksLLM-specific secret fields are not in the base LLM_SECRET_FIELDS + tuple so they don't benefit from the base _serialize_secrets serializer. + This method mirrors the same logic so save/load round-trips work: + - expose_secrets=True → plaintext (AgentStore.save path) + - default → redacted string "**********" + Always returns str | None (never SecretStr) to avoid Pydantic warnings. 
+ """ + if v is None: + return None + from openhands.sdk.utils.pydantic_secrets import ( + REDACTED_SECRET_VALUE, + serialize_secret, + ) + result = serialize_secret(v, info) + if isinstance(result, SecretStr): + return REDACTED_SECRET_VALUE + return result + + @field_validator("databricks_u2m_client_secret", mode="before") + @classmethod + def _validate_u2m_secret( + cls, v: "str | SecretStr | None", info + ) -> "SecretStr | None": + """Coerce str → SecretStr and discard redacted placeholder values.""" + from openhands.sdk.utils.pydantic_secrets import validate_secret + + return validate_secret(v, info) + + @field_serializer("databricks_client_secret", when_used="always") + def _serialize_databricks_secret( + self, v: SecretStr | None, info + ) -> str | None: + return self._serialize_secret_field(v, info) + + @field_serializer("databricks_u2m_client_secret", when_used="always") + def _serialize_databricks_u2m_secret( + self, v: SecretStr | None, info + ) -> str | None: + return self._serialize_secret_field(v, info) + + @model_validator(mode="after") + def _init_databricks(self) -> "DatabricksLLM": + if not (self.databricks_ai_gateway_host or self.databricks_host): + raise ValueError( + "databricks_host is required (or databricks_ai_gateway_host " + "as an override). FM invocations route through " + "/ai-gateway/ by default." 
+ ) + self._db_credentials = resolve_credentials(self) + self._db_client = DatabricksFMAPIClient( + credentials=self._db_credentials, + ai_gateway_host=self.databricks_ai_gateway_host, + timeouts=DatabricksTimeouts( + connect_s=self.databricks_connect_timeout_s, + read_s=self.databricks_read_timeout_s, + chunk_s=self.databricks_chunk_timeout_s, + ), + max_retries=self.databricks_max_retries, + ssl_verify=self.databricks_ssl_verify, + metadata_probe=self.databricks_metadata_probe, + ) + return self + + # --------------------------------------------------------------------------- + # PWAF surfaces (observability, diagnostics, pickers) + # --------------------------------------------------------------------------- + + @property + def auth_method(self) -> str: + """Resolved auth strategy: ``pat`` | ``m2m`` | ``u2m`` | ``profile`` | ``unified`` | ``env``. + + Read-only — set by ``resolve_credentials()`` during construction. + Handy for log correlation and operator dashboards. + """ + return self._db_credentials.auth_method + + @property + def predicted_family(self) -> ProviderFamily: + """Provider family predicted by **name pattern only** (no HTTP call). + + Useful for picker UIs and validation — gives an immediate answer without + hitting the ``/api/2.0/serving-endpoints/{name}`` describe endpoint. For + the authoritative family used at request time, call + :meth:`resolve_family` (it performs a metadata probe with in-process + caching and falls back to this same prediction on error). + """ + return detect_family(self.model) + + def resolve_family(self) -> ProviderFamily: + """Provider family used at request time. + + Default (``databricks_metadata_probe=False``): pure name-pattern + resolution — same as :attr:`predicted_family`, no network call. + + Opt-in (``databricks_metadata_probe=True``): metadata-first with + name-pattern fallback. Triggers at most one + ``GET /api/2.0/serving-endpoints/{name}`` per endpoint per 5-minute + TTL window against the workspace URL. 
+ """ + endpoint = self.model.removeprefix("databricks/") + return self._db_client.resolve_family(endpoint) + + # --------------------------------------------------------------------------- + # LLM overrides + # --------------------------------------------------------------------------- + + def _init_model_info_and_caps(self) -> None: + """Override: set context windows from Databricks capability tables.""" + self.max_input_tokens = DATABRICKS_CONTEXT_WINDOWS.get(self.model, 128_000) + self.max_output_tokens = DATABRICKS_MAX_OUTPUT.get(self.model, 16_384) + self._validate_context_window_size() + + def _get_litellm_api_key_value(self) -> str | None: + """Override: return a fresh Databricks token rather than the static api_key.""" + return self._db_credentials.get_token() + + def close(self) -> None: + """Release the underlying HTTP connection pool. + + Call this when discarding a DatabricksLLM instance to avoid leaking + file descriptors. Safe to call multiple times. + """ + try: + self._db_client.close() + except Exception: + pass + + def _transport_call( + self, + *, + messages: list[dict[str, Any]], + enable_streaming: bool = False, + on_token=None, + **kwargs, + ) -> "ModelResponse": + """Override: call the Databricks FMAPI directly via httpx.""" + model_name = self.model.removeprefix("databricks/") + logger.debug( + "databricks_transport_call", + extra={ + "endpoint": model_name, + "auth_method": self._db_credentials.auth_method, + "predicted_family": detect_family(self.model).value, + # Authoritative family is resolved inside the client (with cache); + # we don't re-probe here to avoid a second call path. + "streaming": enable_streaming, + }, + ) + # Strip litellm-specific kwargs that must not appear in the JSON body + # forwarded to the Databricks AI Gateway. 
+ # - stream: controlled via enable_streaming to avoid duplicate kwarg + # - extra_headers: litellm convention; headers are set by _make_headers() + # - extra_body: litellm convention; unsupported by the gateway + for _k in ("stream", "extra_headers", "extra_body"): + kwargs.pop(_k, None) + return self._db_client.chat_completion( + model=model_name, + messages=messages, + stream=enable_streaming, + on_token=on_token, + **kwargs, + ) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py new file mode 100644 index 0000000000..3139c10985 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py @@ -0,0 +1,286 @@ +"""Databricks AI Gateway routing config + shared Pydantic models. + +All FM traffic is routed through **Databricks AI Gateway** under the +``/ai-gateway/`` URL prefix. The gateway is reachable two ways: + +* Reverse-proxied under the workspace host + (``https:///ai-gateway/...``) — this is what + :func:`AIGatewayPaths.normalize_base` produces when the user only + configures the workspace URL. +* Dedicated AI-Gateway hostname + (``https://.ai-gateway.cloud.databricks.com/...``) — used + when the customer has a separate gateway endpoint (PrivateLink, Front + Door). In that case the path templates are appended directly without + the ``/ai-gateway`` prefix. + +This module carries: + +1. :class:`StoredU2MTokens` — OAuth token container shared with the app layer. +2. :class:`ProviderFamily` — dispatch key for which provider-native contract + the target endpoint speaks (OpenAI Chat, Anthropic Messages, Google Gemini + ``generateContent``, OpenAI Responses). +3. :class:`AIGatewayPaths` — path templates for each family's AI Gateway route. +4. :func:`detect_family` — name-pattern router (fast path, no HTTP call). +5. 
class ProviderFamily(str, Enum):
    """Provider-native API contract spoken by an AI Gateway endpoint.

    The member selects both the payload format and the path template
    (``AIGatewayPaths`` has one template attribute per member value).

    How a member is chosen:

    ======================= ============================ ==============================
    Member                  Name rule (detect_family)    ``api_types`` entry
    ======================= ============================ ==============================
    :attr:`ANTHROPIC`       contains ``claude``          ``anthropic/v1/messages``
    :attr:`GEMINI`          contains ``gemini``          ``gemini/v1/generateContent``
    :attr:`OPENAI_RESPONSES` starts with ``gpt-<digit>`` ``openai/v1/responses``
    :attr:`OPENAI`          anything else                *(no native entry present)*
    ======================= ============================ ==============================

    Name rules are evaluated top to bottom on the model id with its
    ``databricks/`` / ``databricks-`` prefixes removed; the first hit wins.
    :attr:`OPENAI` is the fallback of both routers, so the native families are
    only selected on a positive name or metadata match.
    """

    # Fallback family: OpenAI Chat Completions.
    OPENAI = "openai"
    # OpenAI Responses API (numbered GPT generations).
    OPENAI_RESPONSES = "openai_responses"
    # Anthropic Messages API (Claude endpoints).
    ANTHROPIC = "anthropic"
    # Google Gemini ``generateContent``.
    GEMINI = "gemini"
The + endpoint name is part of the URL. + + Templates intentionally start at the AI Gateway base (without the + ``/ai-gateway`` prefix). :meth:`url` calls :meth:`normalize_base` first + to produce the right base URL given the configured host: + + * Workspace URL (``adb-*.cloud.databricks.com``) → + ``/ai-gateway`` (the gateway is reverse-proxied under the + workspace control plane). + * Dedicated gateway URL (``*.ai-gateway.*``) → host as-is, the gateway + hostname is itself the base. + + Each template can be overridden for deployments with non-standard path + layouts. ``{endpoint}`` is substituted with the bare endpoint name + (after stripping the ``databricks/`` prefix). + """ + + openai: str = Field( + default="/mlflow/v1/chat/completions", + description="OpenAI Chat Completions (mlflow flavor; universal default).", + ) + openai_responses: str = Field( + default="/openai/v1/responses", + description="OpenAI Responses API. Endpoint name is in the body.", + ) + anthropic: str = Field( + default="/anthropic/v1/messages", + description="Anthropic Messages API. Endpoint name is in the body.", + ) + gemini: str = Field( + default="/gemini/v1beta/models/{endpoint}:generateContent", + description="Google Gemini generateContent native path.", + ) + + @staticmethod + def normalize_base(host: str) -> str: + """Return the AI Gateway base URL for a configured host. + + - Hosts whose netloc matches ``*.ai-gateway.*`` are dedicated AI + Gateway endpoints; the host itself is the base — return as-is + (after stripping any trailing slash). + - Hosts that already end with ``/ai-gateway`` are returned as-is. + - Anything else is treated as a workspace URL with the gateway + reverse-proxied; ``/ai-gateway`` is appended. + """ + h = host.rstrip("/") + # Crude netloc check; ``http(s):///...`` and bare ``netloc`` + # both work here without pulling in urllib for one substring match. 
+ scheme_split = h.split("://", 1) + netloc = scheme_split[1].split("/", 1)[0] if len(scheme_split) == 2 else h + if ".ai-gateway." in netloc: + return h + if h.endswith("/ai-gateway"): + return h + return h + "/ai-gateway" + + def url(self, host: str, family: ProviderFamily, endpoint: str) -> str: + """Build the fully-qualified URL for a ``(family, endpoint)`` pair. + + ``host`` may be a workspace URL or a dedicated AI Gateway hostname; + :meth:`normalize_base` figures out which prefix to apply. + """ + tmpl = getattr(self, family.value) + return self.normalize_base(host) + tmpl.format(endpoint=endpoint) + + +# --------------------------------------------------------------------------- +# Routers +# --------------------------------------------------------------------------- + +# Strip these from the model id before matching. Keeps support for +# "databricks/databricks-claude-sonnet-4-5", "databricks-claude-...", etc. +_MODEL_PREFIXES = ("databricks/", "databricks-") + + +def _bare_name(model: str) -> str: + name = model.lower().strip() + for prefix in _MODEL_PREFIXES: + if name.startswith(prefix): + name = name[len(prefix):] + return name + + +def detect_family(model: str) -> ProviderFamily: + """Name-pattern router (fast path, no extra API call). + + Mirrors the ``databricks-ai-gateway-fm-apis`` skill's ``route_by_name``. + + Priority (first match wins): + + 1. ``*claude*`` → :attr:`ProviderFamily.ANTHROPIC` + 2. ``*gemini*`` → :attr:`ProviderFamily.GEMINI` + 3. ``gpt-*`` → :attr:`ProviderFamily.OPENAI_RESPONSES` + The leading digit requirement naturally excludes ``gpt-oss-*`` + (starts with ``gpt-o``) without needing an explicit blocklist. + Confirmed live against the full GPT-5 product line (April 2026): + ``gpt-5``, ``gpt-5-1``, ``gpt-5-1-codex-{max,mini}``, + ``gpt-5-2``, ``gpt-5-2-codex``, ``gpt-5-3-codex``, ``gpt-5-4``, + ``gpt-5-4-{mini,nano}``, ``gpt-5-mini``, ``gpt-5-nano``. 
+ Any future ``gpt-6``, ``gpt-7``, … variants automatically inherit + this rule — Databricks routes all numbered GPT generations through + the OpenAI Responses API (``/openai/v1/responses``). + 4. Everything else → :attr:`ProviderFamily.OPENAI` (universal + MLflow Chat Completions — safe default for ``gpt-oss``, Llama, …) + """ + name = _bare_name(model) + if "claude" in name: + return ProviderFamily.ANTHROPIC + if "gemini" in name: + return ProviderFamily.GEMINI + # ``_bare_name`` strips both ``databricks/`` and ``databricks-`` prefixes. + # ``re.match`` anchors at the start only — ``gpt-\d`` requires a digit + # immediately after the dash, so ``gpt-oss-*`` falls through cleanly. + if re.match(r"gpt-\d", name): + return ProviderFamily.OPENAI_RESPONSES + return ProviderFamily.OPENAI + + +# ``api_types`` strings exposed by ``GET /api/2.0/serving-endpoints/{name}`` +# (``config.served_entities[0].foundation_model.api_types``). +_API_TYPE_TO_FAMILY: dict[str, ProviderFamily] = { + "anthropic/v1/messages": ProviderFamily.ANTHROPIC, + "gemini/v1/generateContent": ProviderFamily.GEMINI, + "openai/v1/responses": ProviderFamily.OPENAI_RESPONSES, + # mlflow/v1/chat/completions is the universal fallback → OPENAI, + # handled by priority ordering below. +} + +# ``external_model.provider`` values (external endpoints don't have +# ``foundation_model``; they expose a single upstream provider instead). +_EXTERNAL_PROVIDER_TO_FAMILY: dict[str, ProviderFamily] = { + "anthropic": ProviderFamily.ANTHROPIC, + "bedrock-anthropic": ProviderFamily.ANTHROPIC, + "google": ProviderFamily.GEMINI, + "gemini": ProviderFamily.GEMINI, + # OpenAI & azure-openai endpoints speak Chat Completions on /invocations; + # map to OPENAI so we take the default path. + "openai": ProviderFamily.OPENAI, + "azure-openai": ProviderFamily.OPENAI, +} + +# Priority order when an endpoint exposes multiple ``api_types`` — prefer the +# most specific native API, fall back to OpenAI Chat. 
def pick_family_from_api_types(
    api_types: Iterable[str] | None,
    external_provider: str | None = None,
) -> ProviderFamily:
    """Choose the provider family from endpoint metadata (no name guessing).

    Args:
        api_types: The ``foundation_model.api_types`` values of a
            foundation-model endpoint. A single bare string is accepted and
            treated as a one-element collection.
        external_provider: The ``external_model.provider`` value of an
            external-model endpoint. Matched case-insensitively, ignoring
            surrounding whitespace.

    Returns:
        The family of the first ``_API_TYPE_PRIORITY`` entry found in
        ``api_types``; otherwise the family mapped to ``external_provider``;
        otherwise :attr:`ProviderFamily.OPENAI`. Passing both arguments is
        fine — ``api_types`` signals win.
    """
    if isinstance(api_types, str):
        # A str is itself an Iterable[str]: without this guard it would be
        # split into single characters and could never match an entry.
        api_types = (api_types,)
    present = set(api_types or ())

    for key in _API_TYPE_PRIORITY:
        if key in present:
            return _API_TYPE_TO_FAMILY[key]

    provider = (external_provider or "").strip().lower()
    if provider:
        return _EXTERNAL_PROVIDER_TO_FAMILY.get(provider, ProviderFamily.OPENAI)
    return ProviderFamily.OPENAI
+
+Streaming is **not** adapted here — streaming stays on the universal OpenAI
+Chat SSE path (``/invocations``) in ``client.py``. Native-API streaming
+(Anthropic / Gemini / Responses) is a follow-up, documented in the skill.
+
+Out-of-scope by design (documented, not silently stripped):
+
+* Anthropic prompt caching, tool-use block preservation beyond plain text.
+* Gemini multi-modal ``parts`` (inlineData, fileData).
+* Responses ``custom`` / ``apply_patch`` / ``mcp`` tool types.
+
+For these, use the native provider SDK directly against the Databricks
+gateway ``base_url`` — see the companion skill's "Using provider SDKs
+directly" section.
+"""
+
+from __future__ import annotations
+
+import time
+import uuid
+from typing import Any
+
+from openhands.sdk.llm.providers.databricks.models import ProviderFamily
+
+
+# ---------------------------------------------------------------------------
+# Public dispatch
+# ---------------------------------------------------------------------------
+
+
+def to_native(
+    family: ProviderFamily,
+    model: str,
+    messages: list[dict],
+    **kwargs: Any,
+) -> dict:
+    """Build the native request body for the given family.
+
+    ``model`` is the bare endpoint name (no ``databricks/`` prefix).
+    Unknown kwargs are passed through where the native API supports them
+    and dropped where it doesn't.
+    """
+    if family is ProviderFamily.ANTHROPIC:
+        return _to_anthropic(model, messages, **kwargs)
+    if family is ProviderFamily.GEMINI:
+        return _to_gemini(model, messages, **kwargs)
+    if family is ProviderFamily.OPENAI_RESPONSES:
+        return _to_responses(model, messages, **kwargs)
+    # Any other family falls back to the OpenAI Chat body.
+    return _to_openai_chat(model, messages, **kwargs)
+
+
+def from_native(
+    family: ProviderFamily,
+    model: str,
+    data: dict,
+) -> dict:
+    """Normalize a native response ``data`` to an OpenAI ChatCompletion dict."""
+    if family is ProviderFamily.ANTHROPIC:
+        return _from_anthropic(model, data)
+    if family is ProviderFamily.GEMINI:
+        return _from_gemini(model, data)
+    if family is ProviderFamily.OPENAI_RESPONSES:
+        return _from_responses(model, data)
+    return _from_openai_chat(model, data)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _chat_completion(
+    model: str,
+    content: str,
+    *,
+    finish_reason: str = "stop",
+    usage: dict | None = None,
+    tool_calls: list[dict] | None = None,
+    response_id: str | None = None,
+) -> dict:
+    """Build a minimal OpenAI-chat ``ChatCompletion`` dict."""
+    msg: dict[str, Any] = {"role": "assistant", "content": content}
+    if tool_calls:
+        msg["tool_calls"] = tool_calls
+    return {
+        "id": response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [{"index": 0, "message": msg, "finish_reason": finish_reason}],
+        "usage": usage or {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+    }
+
+
+def _flatten_content(content: Any) -> str:
+    """Flatten OpenAI-style content (str | list of {type,text}) to plain str.
+
+    Handles the reasoning-model shape where ``choices[0].message.content`` is
+    a list of blocks (``{"type":"reasoning",...}``, ``{"type":"text","text":...}``).
+    Skips ``reasoning`` blocks; concatenates ``text`` blocks.
+    """
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        parts: list[str] = []
+        for blk in content:
+            if not isinstance(blk, dict):
+                continue
+            t = blk.get("type")
+            if t == "text" and isinstance(blk.get("text"), str):
+                parts.append(blk["text"])
+            # "reasoning" blocks are intentionally skipped.
+        return "".join(parts)
+    return ""
+
+
+_GENERIC_KWARGS = {"temperature", "top_p", "stop"}
+
+
+# ---------------------------------------------------------------------------
+# OpenAI Chat (default)
+# ---------------------------------------------------------------------------
+
+
+def _to_openai_chat(model: str, messages: list[dict], **kwargs: Any) -> dict:
+    """Build an OpenAI Chat Completions body from messages + known kwargs.
+
+    Only ``max_tokens``, ``tools``, ``tool_choice``, ``stream`` and the keys
+    in ``_GENERIC_KWARGS`` are forwarded; every other kwarg is ignored.
+    """
+    # The AI Gateway ``/mlflow/v1/chat/completions`` endpoint reads the
+    # target endpoint name from the body, not the URL — every other family
+    # already does this, so the OpenAI Chat path now matches.
+    body: dict[str, Any] = {"model": model, "messages": messages}
+    if "max_tokens" in kwargs and kwargs["max_tokens"] is not None:
+        body["max_tokens"] = kwargs["max_tokens"]
+    if kwargs.get("tools"):
+        body["tools"] = kwargs["tools"]
+    if kwargs.get("tool_choice") is not None:
+        body["tool_choice"] = kwargs["tool_choice"]
+    for k in _GENERIC_KWARGS:
+        if kwargs.get(k) is not None:
+            body[k] = kwargs[k]
+    if kwargs.get("stream"):
+        body["stream"] = True
+    return body
+
+
+def _from_openai_chat(model: str, data: dict) -> dict:
+    """Return the gateway's ChatCompletion dict with list content flattened.
+
+    ``model`` is unused here; it is accepted so all ``_from_*`` adapters share
+    one signature. ``data`` is modified in place and returned.
+    """
+    # Gateway already returns ChatCompletion shape; only normalize reasoning
+    # models' list-of-blocks content back to a string so downstream consumers
+    # don't need to special-case it. Every choice is normalized (not just the
+    # first) so multi-choice responses come back with a uniform shape.
+    for choice in data.get("choices") or []:
+        if not isinstance(choice, dict):
+            continue
+        msg = choice.get("message") or {}
+        if isinstance(msg.get("content"), list):
+            msg["content"] = _flatten_content(msg["content"])
+    return data
+
+
+# ---------------------------------------------------------------------------
+# Anthropic Messages
+# ---------------------------------------------------------------------------
+
+
+def _to_anthropic(model: str, messages: list[dict], **kwargs: Any) -> dict:
+    """OpenAI messages → Anthropic Messages body.
+
+    System messages become the top-level ``system`` string (Anthropic doesn't
+    support ``role=system`` inside ``messages``). Tool calls beyond plain text
+    are not converted here — use the Anthropic SDK against the gateway
+    ``base_url`` if you need Anthropic-native tool_use blocks.
+    """
+    system_parts: list[str] = []
+    conv: list[dict] = []
+    for m in messages:
+        role = m.get("role")
+        content = _flatten_content(m.get("content"))
+        if role == "system":
+            if content:
+                system_parts.append(content)
+            continue
+        # Messages with any other role, or with empty text, are dropped.
+        if role in ("user", "assistant") and content:
+            conv.append({"role": role, "content": content})
+    body: dict[str, Any] = {
+        "model": model,
+        "messages": conv,
+        # max_tokens is REQUIRED by Anthropic Messages; provide a safe default.
+        "max_tokens": int(kwargs.get("max_tokens") or 1024),
+    }
+    if system_parts:
+        body["system"] = "\n\n".join(system_parts)
+    if kwargs.get("temperature") is not None:
+        body["temperature"] = kwargs["temperature"]
+    if kwargs.get("top_p") is not None:
+        body["top_p"] = kwargs["top_p"]
+    if kwargs.get("stop"):
+        stop = kwargs["stop"]
+        body["stop_sequences"] = [stop] if isinstance(stop, str) else list(stop)
+    return body
+
+
+_ANTHROPIC_STOP_MAP = {
+    "end_turn": "stop",
+    "stop_sequence": "stop",
+    "max_tokens": "length",
+    "tool_use": "tool_calls",
+}
+
+
+def _from_anthropic(model: str, data: dict) -> dict:
+    """Anthropic Messages response → minimal OpenAI ChatCompletion dict.
+
+    Only ``text`` content blocks are kept. Missing or ``null`` text and token
+    counts are treated as empty / zero instead of raising.
+    """
+    text_parts: list[str] = []
+    for blk in data.get("content") or []:
+        if isinstance(blk, dict) and blk.get("type") == "text":
+            text = blk.get("text")
+            # A JSON ``null`` text would otherwise break ``"".join`` below.
+            if isinstance(text, str):
+                text_parts.append(text)
+    u = data.get("usage") or {}
+    prompt_tokens = u.get("input_tokens") or 0
+    completion_tokens = u.get("output_tokens") or 0
+    usage = {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": prompt_tokens + completion_tokens,
+    }
+    return _chat_completion(
+        model,
+        "".join(text_parts),
+        finish_reason=_ANTHROPIC_STOP_MAP.get(data.get("stop_reason", ""), "stop"),
+        usage=usage,
+        response_id=data.get("id"),
+    )
+
+
+# ---------------------------------------------------------------------------
+# Google Gemini generateContent
+# ---------------------------------------------------------------------------
+
+
+def _to_gemini(model: str, messages: list[dict], **kwargs: Any) -> dict:
+    """OpenAI messages → Gemini ``generateContent`` body.
+
+    ``role=system`` maps to ``systemInstruction``. OpenAI ``assistant`` becomes
+    Gemini ``model``. ``maxOutputTokens`` defaults to 1024 — Gemini budgets
+    *thinking + output* against this, so anything below ~256 can return empty.
+    """
+    system_parts: list[str] = []
+    contents: list[dict] = []
+    for m in messages:
+        role = m.get("role")
+        text = _flatten_content(m.get("content"))
+        if role == "system":
+            if text:
+                system_parts.append(text)
+            continue
+        # Every non-assistant, non-system role is sent as Gemini ``user``.
+        g_role = "model" if role == "assistant" else "user"
+        if text:
+            contents.append({"role": g_role, "parts": [{"text": text}]})
+    gen_config: dict[str, Any] = {
+        "maxOutputTokens": int(kwargs.get("max_tokens") or 1024),
+    }
+    if kwargs.get("temperature") is not None:
+        gen_config["temperature"] = kwargs["temperature"]
+    if kwargs.get("top_p") is not None:
+        gen_config["topP"] = kwargs["top_p"]
+    if kwargs.get("stop"):
+        stop = kwargs["stop"]
+        gen_config["stopSequences"] = [stop] if isinstance(stop, str) else list(stop)
+    body: dict[str, Any] = {"contents": contents, "generationConfig": gen_config}
+    if system_parts:
+        body["systemInstruction"] = {"parts": [{"text": "\n\n".join(system_parts)}]}
+    return body
+
+
+_GEMINI_FINISH_MAP = {
+    "STOP": "stop",
+    "MAX_TOKENS": "length",
+    "SAFETY": "content_filter",
+    "RECITATION": "content_filter",
+    "OTHER": "stop",
+}
+
+
+def _from_gemini(model: str, data: dict) -> dict:
+    """Gemini ``generateContent`` response → minimal OpenAI ChatCompletion dict.
+
+    Only the first candidate is used. Parts without string ``text`` contribute
+    nothing; ``null`` token counts are reported as zero.
+    """
+    cands = data.get("candidates") or []
+    text = ""
+    finish = "stop"
+    if cands:
+        cand = cands[0]
+        parts = ((cand.get("content") or {}).get("parts")) or []
+        # ``or ""`` guards against an explicit JSON ``null`` text value.
+        text = "".join(p.get("text") or "" for p in parts if isinstance(p, dict))
+        finish = _GEMINI_FINISH_MAP.get(cand.get("finishReason", ""), "stop")
+    um = data.get("usageMetadata") or {}
+    usage = {
+        "prompt_tokens": um.get("promptTokenCount") or 0,
+        "completion_tokens": um.get("candidatesTokenCount") or 0,
+        "total_tokens": um.get("totalTokenCount") or 0,
+    }
+    return _chat_completion(
+        model,
+        text,
+        finish_reason=finish,
+        usage=usage,
+        response_id=data.get("responseId"),
+    )
+
+
+# ---------------------------------------------------------------------------
+# OpenAI Responses (GPT-5 series)
+# ---------------------------------------------------------------------------
+
+# 
Pay-per-token FM endpoints reject these — they're documented for the
+# hosted OpenAI Responses API but not supported via the gateway.
+_RESPONSES_DROP = {
+    "background",
+    "store",
+    "previous_response_id",
+    "service_tier",
+    # ``max_completion_tokens`` is the Chat-Completions name; the upstream
+    # LLM/litellm path emits it for OpenAI-flavoured calls. Responses uses
+    # ``max_output_tokens`` (set explicitly in ``_to_responses``), so drop the
+    # chat-style alias to avoid the API's ``unsupported_parameter`` 400.
+    "max_completion_tokens",
+    # GPT-5 reasoning models routed through ``/openai/v1/responses`` reject
+    # both ``temperature`` and ``top_p`` ("Unsupported parameter…"). The
+    # SDK still respects user intent by carrying the values into kwargs,
+    # but the adapter drops them before they hit the wire so a single
+    # ``temperature=0.0`` default doesn't break every GPT-5 call.
+    "temperature",
+    "top_p",
+}
+
+# Responses ``incomplete_details.reason`` → Chat Completions ``finish_reason``.
+_RESPONSES_INCOMPLETE_MAP = {
+    "max_output_tokens": "length",
+    "content_filter": "content_filter",
+}
+
+
+def _to_responses_input(messages: list[dict]) -> list[dict]:
+    """Translate Chat-Completions messages into Responses API ``input`` items.
+
+    The Responses API uses a flat item list instead of a ``messages`` array.
+    This function handles all message types that appear in a multi-turn
+    OpenHands conversation:
+
+        Chat Completions role     →  Responses input item(s)
+        ─────────────────────────────────────────────────────────────────────────
+        user / system             →  {"role":"user/system","content":[{"type":"input_text","text":...}]}
+        assistant (text only)     →  {"role":"assistant","content":[{"type":"output_text","text":...}]}
+        assistant (tool_calls)    →  one {"type":"function_call","call_id":...} item per tool call,
+                                     followed by an output_text item if there is also text content
+        tool (result)             →  {"type":"function_call_output","call_id":...,"output":...}
+        ─────────────────────────────────────────────────────────────────────────
+
+    Without the tool_calls and tool-result translations, multi-turn
+    conversations where GPT-5 calls a tool fail on the second turn because
+    the Responses API has no record of the function_call in the history.
+    """
+    role_to_part_type = {
+        "user": "input_text",
+        "system": "input_text",
+        "developer": "input_text",
+    }
+    out: list[dict] = []
+    for msg in messages:
+        if not isinstance(msg, dict):
+            out.append(msg)
+            continue
+        role = msg.get("role")
+        content = msg.get("content")
+
+        # ── tool result ─────────────────────────────────────────────────────
+        if role == "tool":
+            out.append({
+                "type": "function_call_output",
+                "call_id": msg.get("tool_call_id", ""),
+                "output": str(_flatten_content(content)) if content else "",
+            })
+            continue
+
+        # ── assistant with tool_calls (possibly also with text) ─────────────
+        if role == "assistant" and msg.get("tool_calls"):
+            for tc in msg.get("tool_calls") or []:
+                if not isinstance(tc, dict):
+                    continue
+                fn = tc.get("function") or {}
+                out.append({
+                    "type": "function_call",
+                    "call_id": tc.get("id", ""),
+                    "name": fn.get("name", ""),
+                    "arguments": fn.get("arguments", "{}"),
+                })
+            # If the assistant message also contained text, emit it too.
+            text = _flatten_content(content) if content else ""
+            if text:
+                out.append({
+                    "role": "assistant",
+                    "content": [{"type": "output_text", "text": text}],
+                })
+            continue
+
+        # ── text-only assistant messages ─────────────────────────────────────
+        if role == "assistant":
+            text = _flatten_content(content) if content else ""
+            if text:
+                out.append({
+                    "role": "assistant",
+                    "content": [{"type": "output_text", "text": text}],
+                })
+            continue
+
+        # ── user / system / developer ────────────────────────────────────────
+        part_type = role_to_part_type.get(role)
+        if part_type is None or content is None:
+            # Unknown role — pass through unchanged so the gateway surfaces
+            # a clear error rather than silently mangling the message.
+            out.append(msg)
+            continue
+
+        if isinstance(content, str):
+            translated_content: list[dict] = [
+                {"type": part_type, "text": content}
+            ]
+        elif isinstance(content, list):
+            translated_content = []
+            for c in content:
+                if isinstance(c, dict) and c.get("type") == "text":
+                    translated_content.append(
+                        {"type": part_type, "text": c.get("text", "")}
+                    )
+                else:
+                    translated_content.append(c)
+        else:
+            out.append(msg)
+            continue
+
+        new_msg = dict(msg)
+        new_msg["content"] = translated_content
+        out.append(new_msg)
+    return out
+
+
+def _chat_tools_to_responses(tools: list[dict]) -> list[dict]:
+    """Convert Chat Completions tool format to Responses API format.
+
+    Chat Completions:
+        {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}
+
+    Responses API:
+        {"type": "function", "name": "...", "description": "...", "parameters": {...}}
+
+    The Responses API requires ``name`` at the top level of each tool object.
+    Passing the nested Chat Completions format causes a 400 "Missing required
+    parameter: 'tools[0].name'" error from the gateway.
+    """
+    out: list[dict] = []
+    for tool in tools:
+        if not isinstance(tool, dict):
+            out.append(tool)
+            continue
+        fn = tool.get("function")
+        if tool.get("type") == "function" and isinstance(fn, dict):
+            # Unwrap the function wrapper — Responses API is flat.
+            converted: dict[str, Any] = {"type": "function"}
+            converted.update(fn)
+            out.append(converted)
+        else:
+            # Non-function tool or already in Responses format — pass through.
+            out.append(tool)
+    return out
+
+
+def _to_responses(model: str, messages: list[dict], **kwargs: Any) -> dict:
+    """OpenAI messages → OpenAI Responses body.
+
+    Responses uses ``input`` (not ``messages``), ``max_output_tokens`` (not
+    ``max_tokens``), and a different content-part vocabulary
+    (``input_text`` / ``output_text`` instead of ``text``). Default budget
+    is 1024 since GPT-5 spends tokens on reasoning before producing
+    visible output.
+    """
+    body: dict[str, Any] = {
+        "model": model,
+        "input": _to_responses_input(messages),
+        "max_output_tokens": int(kwargs.get("max_tokens") or 1024),
+    }
+    # NB: ``temperature`` / ``top_p`` are intentionally NOT forwarded —
+    # see ``_RESPONSES_DROP`` above. GPT-5 reasoning models reject them.
+    if kwargs.get("tools"):
+        body["tools"] = _chat_tools_to_responses(kwargs["tools"])
+    if kwargs.get("tool_choice") is not None:
+        body["tool_choice"] = kwargs["tool_choice"]
+    # Forward other allowed kwargs; silently drop known-unsupported ones.
+    for k, v in kwargs.items():
+        if k in _RESPONSES_DROP or v is None:
+            continue
+        if k in body or k in {"max_tokens", "tools", "tool_choice", "temperature",
+                              "top_p", "stream", "stop"}:
+            continue
+        body[k] = v
+    return body
+
+
+def _from_responses(model: str, data: dict) -> dict:
+    """Convert Responses API output to OpenAI Chat Completions shape.
+
+    Responses returns an array of output items:
+      * ``{"type":"message","content":[...]}``  → assistant text
+      * ``{"type":"function_call","name":...}`` → tool_calls
+      * ``{"type":"reasoning",...}``            → skip (internal thinking)
+
+    Tool calls are converted to the Chat Completions ``tool_calls`` array so
+    OpenHands can dispatch them unchanged. Without this conversion GPT-5's tool
+    invocations are silently dropped, causing the "response did not include a
+    function call" loop.
+
+    ``finish_reason`` is ``"tool_calls"`` when any tool call is present. A
+    response with ``status == "incomplete"`` maps its
+    ``incomplete_details.reason`` through ``_RESPONSES_INCOMPLETE_MAP`` (so a
+    truncated answer is reported as ``"length"``); everything else is
+    ``"stop"``. ``message.content`` is ``None`` when no text was produced.
+    """
+    text_parts: list[str] = []
+    tool_calls: list[dict] = []
+
+    for item in data.get("output") or []:
+        if not isinstance(item, dict):
+            continue
+
+        item_type = item.get("type")
+
+        if item_type == "message":
+            for c in item.get("content") or []:
+                if isinstance(c, dict) and c.get("type") in ("output_text", "text"):
+                    text_parts.append(c.get("text", ""))
+
+        elif item_type == "function_call":
+            # Responses API function_call shape:
+            #   {"type": "function_call", "id": "fc_...", "call_id": "call_...",
+            #    "name": "terminal", "arguments": "{\"command\":\"ls\"}"}
+            # → Chat Completions tool_calls shape:
+            #   {"id": "call_...", "type": "function",
+            #    "function": {"name": "terminal", "arguments": "..."}}
+            call_id = item.get("call_id") or item.get("id") or f"call_{uuid.uuid4().hex[:8]}"
+            tool_calls.append({
+                "id": call_id,
+                "type": "function",
+                "function": {
+                    "name": item.get("name", ""),
+                    "arguments": item.get("arguments", "{}"),
+                },
+            })
+        # "reasoning" items are intentionally skipped.
+
+    # Fallback: some responses expose aggregated text at `output_text`.
+    if not text_parts and not tool_calls and isinstance(data.get("output_text"), str):
+        text_parts.append(data["output_text"])
+
+    u = data.get("usage") or {}
+    usage = {
+        "prompt_tokens": u.get("input_tokens", 0),
+        "completion_tokens": u.get("output_tokens", 0),
+        "total_tokens": u.get("total_tokens",
+                              u.get("input_tokens", 0) + u.get("output_tokens", 0)),
+    }
+
+    if tool_calls:
+        finish = "tool_calls"
+    elif data.get("status") == "incomplete":
+        # Previously both branches yielded "stop", hiding truncation from callers.
+        details = data.get("incomplete_details")
+        reason = details.get("reason") if isinstance(details, dict) else None
+        finish = _RESPONSES_INCOMPLETE_MAP.get(reason, "stop")
+    else:
+        finish = "stop"
+
+    msg: dict[str, Any] = {"role": "assistant", "content": "".join(text_parts) or None}
+    if tool_calls:
+        msg["tool_calls"] = tool_calls
+
+    return {
+        # ``or`` (not a ``.get`` default) so an explicit ``null`` id is replaced too.
+        "id": data.get("id") or f"resp-{uuid.uuid4().hex[:8]}",
+        "object": "chat.completion",
+        # Same key ``_chat_completion`` emits; prefer the server timestamp.
+        "created": int(data.get("created_at") or time.time()),
+        "model": model,
+        "choices": [{"index": 0, "message": msg, "finish_reason": finish}],
+        "usage": usage,
+    }
diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/pkce.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/pkce.py
new file mode 100644
index 0000000000..68cd7dc8c6
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/pkce.py
@@ -0,0 +1,166 @@
+"""Databricks U2M OAuth PKCE primitives (shared by web + CLI front-ends).
+
+These are the dependency-light helpers for the interactive *browser login*
+(Authorization Code + PKCE) flow:
+
+* :func:`generate_pkce` — verifier / S256 challenge pair.
+* :func:`build_authorize_url` — Databricks OIDC ``/authorize`` URL.
+* :func:`exchange_code_for_tokens` — sync code → tokens exchange.
+* :func:`async_exchange_code_for_tokens` — async variant for event-loop callers.
+
+The provider's :mod:`.auth` module owns token *refresh*; this module owns the
+one-time *login*. Both the OpenHands web app and the OpenHands CLI consume these
+helpers so the PKCE logic lives in exactly one place.
+
+The returned token dict round-trips through
+:class:`~openhands.sdk.llm.providers.databricks.models.StoredU2MTokens`:
+``access_token``, ``refresh_token``, ``expires_at``, ``client_id``, ``host``.
+
+No direct ``litellm`` / FastAPI imports here — kept minimal so both front-ends
+(which may pin different framework versions) can import it cheaply. Note that
+``USER_AGENT`` is imported from ``.utils``, which itself imports
+``litellm.exceptions`` at module load.
+"""
+
+from __future__ import annotations
+
+import base64
+import hashlib
+import secrets
+import time
+from typing import Any
+from urllib.parse import urlencode
+
+import httpx
+
+from openhands.sdk.llm.providers.databricks.utils import USER_AGENT
+
+
+_TOKEN_TIMEOUT_S = 15.0
+_DEFAULT_EXPIRES_IN = 3600
+
+
+def generate_pkce() -> tuple[str, str]:
+    """Return ``(verifier, challenge)`` where challenge is S256 of verifier."""
+    # 32 random bytes → 43-char unpadded base64url verifier.
+    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
+    digest = hashlib.sha256(verifier.encode()).digest()
+    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
+    return verifier, challenge
+
+
+def build_authorize_url(
+    host: str,
+    client_id: str,
+    redirect_uri: str,
+    state: str,
+    challenge: str,
+) -> str:
+    """Build the Databricks OIDC authorize URL with PKCE (S256)."""
+    host = host.rstrip("/")
+    params = {
+        "response_type": "code",
+        "client_id": client_id,
+        "redirect_uri": redirect_uri,
+        "scope": "all-apis offline_access",
+        "state": state,
+        "code_challenge": challenge,
+        "code_challenge_method": "S256",
+    }
+    return f"{host}/oidc/v1/authorize?{urlencode(params)}"
+
+
+def _build_token_request(
+    host: str,
+    client_id: str,
+    redirect_uri: str,
+    code: str,
+    verifier: str,
+    client_secret: str | None,
+) -> tuple[str, dict[str, str]]:
+    """Return ``(token_url, form_data)`` for the code → token exchange.
+
+    ``client_secret`` is required for **confidential** OAuth apps (apps with a
+    secret registered in Databricks App connections). Public PKCE apps omit it;
+    omitting it for a confidential app returns ``{"error": "invalid_client"}``.
+    """
+    host = host.rstrip("/")
+    token_data: dict[str, str] = {
+        "grant_type": "authorization_code",
+        "code": code,
+        "redirect_uri": redirect_uri,
+        "client_id": client_id,
+        "code_verifier": verifier,
+    }
+    if client_secret:
+        token_data["client_secret"] = client_secret
+    return f"{host}/oidc/v1/token", token_data
+
+
+def _to_stored_payload(
+    data: dict[str, Any], client_id: str, host: str
+) -> dict[str, Any]:
+    """Shape a Databricks token response into a ``StoredU2MTokens``-compatible dict.
+
+    ``expires_in`` may be missing, ``null``, or a numeric string; anything that
+    cannot be read as a number falls back to ``_DEFAULT_EXPIRES_IN``.
+
+    Raises:
+        KeyError: if ``data`` has no ``access_token``.
+    """
+    try:
+        lifetime_s = float(data.get("expires_in"))
+    except (TypeError, ValueError):
+        # ``None`` (absent / null) or a non-numeric value.
+        lifetime_s = float(_DEFAULT_EXPIRES_IN)
+    return {
+        "access_token": data["access_token"],
+        "refresh_token": data.get("refresh_token", ""),
+        "expires_at": time.time() + lifetime_s,
+        "client_id": client_id,
+        "host": host.rstrip("/"),
+    }
+
+
+def exchange_code_for_tokens(
+    host: str,
+    client_id: str,
+    redirect_uri: str,
+    code: str,
+    verifier: str,
+    client_secret: str | None = None,
+) -> dict[str, Any]:
+    """Exchange an authorization code for tokens (synchronous).
+
+    Sends the PWAF ``User-Agent`` on the token request. Returns a dict
+    compatible with ``StoredU2MTokens.model_validate``.
+
+    Raises:
+        httpx.HTTPStatusError: if the token endpoint returns a non-2xx status.
+    """
+    token_url, token_data = _build_token_request(
+        host, client_id, redirect_uri, code, verifier, client_secret
+    )
+    resp = httpx.post(
+        token_url,
+        data=token_data,
+        headers={"User-Agent": USER_AGENT},
+        timeout=_TOKEN_TIMEOUT_S,
+    )
+    resp.raise_for_status()
+    return _to_stored_payload(resp.json(), client_id, host)
+
+
+async def async_exchange_code_for_tokens(
+    host: str,
+    client_id: str,
+    redirect_uri: str,
+    code: str,
+    verifier: str,
+    client_secret: str | None = None,
+) -> dict[str, Any]:
+    """Exchange an authorization code for tokens (asynchronous).
+
+    Identical to :func:`exchange_code_for_tokens` but uses
+    ``httpx.AsyncClient`` so it does not block the event loop when called from
+    an ``async`` request handler (e.g. the web app's OAuth callback route).
+
+    Raises:
+        httpx.HTTPStatusError: if the token endpoint returns a non-2xx status.
+    """
+    token_url, token_data = _build_token_request(
+        host, client_id, redirect_uri, code, verifier, client_secret
+    )
+    async with httpx.AsyncClient(timeout=_TOKEN_TIMEOUT_S) as client:
+        resp = await client.post(
+            token_url,
+            data=token_data,
+            headers={"User-Agent": USER_AGENT},
+        )
+    resp.raise_for_status()
+    return _to_stored_payload(resp.json(), client_id, host)
diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py
new file mode 100644
index 0000000000..57713e2aa9
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py
@@ -0,0 +1,188 @@
+"""Settings → ``create_llm(...)`` kwargs bridge for the Databricks provider.
+
+This is the single code path both the OpenHands backend and the OpenHands-CLI
+go through when turning user settings (env vars, DB rows, TUI form state) into
+kwargs for :func:`openhands.sdk.create_llm`.
+
+Keeping it in the SDK prevents silent drift: when a new field is added to
+``DatabricksLLM``, the contract test in ``test_settings_bridge.py`` fails until
+the bridge is extended — which forces a conscious decision about whether the
+new field should be exposed in settings UIs.
+
+Usage:
+
+    from openhands.sdk import create_llm
+    from openhands.sdk.llm.providers.databricks.settings_bridge import (
+        kwargs_from_settings,
+    )
+
+    kwargs = kwargs_from_settings(user, usage_id="agent")
+    llm = create_llm(**kwargs)
+
+The ``settings`` argument is deliberately duck-typed (Protocol, not a concrete
+class). Any object exposing a subset of the attribute names below works:
+pydantic models (``UserInfo``, ``CliSettings``, ``LLMEnvOverrides``), dataclasses,
+``SimpleNamespace``, or plain ``dict``-wrappers.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import SecretStr
+
+from openhands.sdk.llm.providers.databricks.models import StoredU2MTokens
+
+
+# Fields the bridge recognizes. Every public, user-settable field on
+# ``DatabricksLLM`` (and the subset of base ``LLM`` fields that UIs expose)
+# must appear here. Enforced by
+# ``test_bridge_covers_all_databricks_llm_public_fields``.
+_BRIDGE_FIELDS: tuple[str, ...] = (
+    # --- Base LLM fields commonly set from UI ---
+    "model",
+    "api_key",
+    "base_url",
+    "timeout",
+    "max_input_tokens",
+    # --- Databricks-specific ---
+    "databricks_host",
+    "databricks_ai_gateway_host",
+    "databricks_metadata_probe",
+    "databricks_client_id",
+    "databricks_client_secret",
+    "databricks_profile",
+    "databricks_ssl_verify",
+    "databricks_max_retries",
+    "databricks_connect_timeout_s",
+    "databricks_read_timeout_s",
+    "databricks_chunk_timeout_s",
+    "stored_u2m_tokens",
+    "databricks_u2m_client_id",
+    "databricks_u2m_client_secret",
+    "databricks_u2m_redirect_uri",
+)
+
+# Fields present on ``DatabricksLLM`` that are deliberately NOT bridged from
+# settings — either internal pydantic discriminators or private state.
+_NOT_BRIDGED: frozenset[str] = frozenset(
+    {
+        "provider",  # Literal discriminator, always "databricks"
+    }
+)
+
+_SECRET_FIELDS: frozenset[str] = frozenset(
+    {
+        "api_key",
+        "databricks_client_secret",
+        "databricks_u2m_client_secret",
+    }
+)
+
+
+#: Useful when the settings object uses a different attribute name for a
+#: bridged field, e.g. OpenHands' ``UserInfo`` uses ``llm_model`` /
+#: ``llm_api_key`` / ``llm_base_url``. The bridge first tries the canonical
+#: attribute name, then each alias in order.
+UserInfoAliases: dict[str, tuple[str, ...]] = {
+    "model": ("llm_model",),
+    "api_key": ("llm_api_key",),
+    "base_url": ("llm_base_url",),
+    # OpenHands web app stores the Databricks workspace URL in llm_base_url.
+    # Fall back to it when databricks_host is not set as a dedicated field.
+    "databricks_host": ("llm_base_url",),
+}
+
+
+def _is_blank(value: Any) -> bool:
+    """Return True for "unset" values: ``None``, ``""``, or an empty ``SecretStr``."""
+    if isinstance(value, SecretStr):
+        # ``SecretStr("") == ""`` is False, so unwrap before comparing.
+        return value.get_secret_value() == ""
+    return value is None or value == ""
+
+
+def kwargs_from_settings(
+    settings: Any,
+    *,
+    usage_id: str,
+    model_override: str | None = None,
+    base_url_fallback: str | None = None,
+    extras: dict[str, Any] | None = None,
+    aliases: dict[str, tuple[str, ...]] | None = None,
+) -> dict[str, Any]:
+    """Build a kwargs dict ready for ``openhands.sdk.create_llm(**kwargs)``.
+
+    Behavior:
+
+    * Attributes are read via ``getattr`` — missing attrs are skipped (so the
+      same bridge works for partial settings objects).
+    * ``None``, empty-string and empty-``SecretStr`` values are dropped so
+      pydantic defaults apply.
+    * Secret fields (``api_key``, ``databricks_client_secret``,
+      ``databricks_u2m_client_secret``) are coerced to
+      :class:`pydantic.SecretStr` if supplied as bare strings.
+    * ``stored_u2m_tokens`` accepts both :class:`StoredU2MTokens` instances
+      and plain dicts (validated with ``model_validate``; invalid dicts are
+      silently dropped).
+    * ``model_override`` wins over ``settings.model`` when both are set.
+    * ``base_url_fallback`` is only applied when *neither* ``base_url`` nor
+      ``databricks_host`` is present (preserves existing callers'
+      host-vs-base-url disambiguation).
+    * ``extras`` are merged last and win over everything else — use this for
+      per-request overrides like session U2M tokens. Only ``None`` values are
+      skipped there; secret fields are coerced to ``SecretStr``.
+    * ``usage_id`` is always set; it's the one field not read from settings.
+
+    Args:
+        settings: Any object exposing a subset of the bridged field names.
+        usage_id: Per-call usage id (``"agent"``, ``"condenser"``, ...).
+        model_override: Replaces ``settings.model`` when non-None.
+        base_url_fallback: Applied only when neither ``base_url`` nor
+            ``databricks_host`` is populated.
+        extras: Last-write-wins overrides.
+        aliases: Optional map of canonical field → fallback attribute names.
+            Tried in order after the canonical name itself. Convenient for
+            settings shapes like OpenHands' ``UserInfo`` that prefix fields
+            with ``llm_`` — pass :data:`UserInfoAliases`.
+
+    Returns:
+        A dict safe to splat into :func:`openhands.sdk.create_llm`.
+    """
+    kwargs: dict[str, Any] = {"usage_id": usage_id}
+    aliases = aliases or {}
+
+    for field in _BRIDGE_FIELDS:
+        val = getattr(settings, field, None)
+        if _is_blank(val):
+            # Canonical name is unset: take the first alias that has a value.
+            for alias in aliases.get(field, ()):
+                val = getattr(settings, alias, None)
+                if not _is_blank(val):
+                    break
+        if _is_blank(val):
+            continue
+        if field in _SECRET_FIELDS and not isinstance(val, SecretStr):
+            val = SecretStr(str(val))
+        if field == "stored_u2m_tokens" and isinstance(val, dict):
+            try:
+                val = StoredU2MTokens.model_validate(val)
+            except Exception:
+                # Deliberate best-effort: an invalid token dict is dropped.
+                continue
+        kwargs[field] = val
+
+    if model_override is not None:
+        kwargs["model"] = model_override
+
+    if (
+        "base_url" not in kwargs
+        and "databricks_host" not in kwargs
+        and base_url_fallback
+    ):
+        kwargs["base_url"] = base_url_fallback
+
+    if extras:
+        for k, v in extras.items():
+            if v is None:
+                continue
+            if k in _SECRET_FIELDS and not isinstance(v, SecretStr):
+                v = SecretStr(str(v))
+            kwargs[k] = v
+
+    return kwargs
+
+
+__all__ = [
+    "kwargs_from_settings",
+    "UserInfoAliases",
+    "_BRIDGE_FIELDS",
+    "_NOT_BRIDGED",
+]
diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py
new file mode 100644
index 0000000000..64445931f2
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py
@@ -0,0 +1,270 @@
+"""Databricks FMAPI resilience utilities.
+
+Provides:
+    USER_AGENT — PWAF-required constant; set once, applied to ALL Databricks HTTP calls.
+    fetch_with_retry — synchronous retry loop with exponential back-off + Retry-After.
+    Helper functions: _log_retry, _raise_non_retryable, _raise_mapped, compute_backoff,
+    normalize_host, map_databricks_error, validate_databricks_config.
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import logging
+import random
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import httpx
+from litellm.exceptions import (
+    APIConnectionError,
+    AuthenticationError,
+    BadRequestError,
+    RateLimitError,
+    ServiceUnavailableError,
+)
+
+if TYPE_CHECKING:
+    from openhands.sdk.llm.providers.databricks.auth import AuthStrategy
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# PWAF: User-Agent constant
+# Must be set once at module load time and applied to ALL Databricks HTTP calls.
+# Never re-imported per request. Never user-configurable.
+# ---------------------------------------------------------------------------
+def _get_version() -> str:
+    """Return the installed ``openhands-sdk`` version, or ``"unknown"``.
+
+    ``"unknown"`` is returned when no ``openhands-sdk`` distribution metadata
+    can be found.
+    """
+    try:
+        return importlib.metadata.version("openhands-sdk")
+    except importlib.metadata.PackageNotFoundError:
+        return "unknown"
+
+
+USER_AGENT: str = f"OpenHandsOSS/{_get_version()}"
+"""User-Agent for the OpenHands OSS Databricks connector.
+
+Format: OpenHandsOSS/<version>
+
+- Product: OpenHandsOSS (matches the runtime plugin and env vars;
+  one consistent identity across all Databricks calls)
+- Version: resolved from the installed `openhands-sdk` package metadata.
+
+Applied to every Databricks HTTP call (AI Gateway, OAuth token endpoint,
+serving-endpoints discovery). Never exposed as a user config knob.
+""" + + +# --------------------------------------------------------------------------- +# Timeout configuration +# --------------------------------------------------------------------------- +@dataclass +class DatabricksTimeouts: + connect_s: float = 10.0 # TCP + TLS; fail fast on unreachable host + read_s: float = 120.0 # Non-streaming: full response wait + chunk_s: float = 30.0 # Streaming: per-chunk idle timeout (resets per chunk) + pool_s: float = 5.0 # Wait for connection from pool + + +# --------------------------------------------------------------------------- +# Retry tables +# --------------------------------------------------------------------------- +RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({429, 500, 502, 503, 504}) +NON_RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({400, 401, 403, 404, 422}) +RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ( + httpx.ConnectError, + httpx.ReadTimeout, + httpx.RemoteProtocolError, +) + +STATUS_TO_LITELLM: dict[int, type[Exception]] = { + 429: RateLimitError, + 500: APIConnectionError, + 502: ServiceUnavailableError, + 503: ServiceUnavailableError, + 504: ServiceUnavailableError, + 400: BadRequestError, + 401: AuthenticationError, + 403: AuthenticationError, + 404: BadRequestError, + 422: BadRequestError, +} + +# Hard cap on Retry-After header value to prevent runaway sleeps from misbehaving proxies. +_RETRY_AFTER_MAX_S: float = 300.0 + + +# --------------------------------------------------------------------------- +# Retry helper functions (P0-2: were called but not defined in P3 plan) +# --------------------------------------------------------------------------- + +def _log_retry( + attempt: int, + max_retries: int, + status_code: int, + wait_s: float, + url: str, + response_headers: httpx.Headers, +) -> None: + """Log a retry event. 
def _raise_non_retryable(response: httpx.Response) -> None:
    """Raise the LiteLLM exception mapped to a non-retryable HTTP status.

    Called immediately (no sleep) for 400/401/403/404/422. This function
    always raises; it never returns normally.

    Args:
        response: The HTTP response carrying the non-retryable status code.

    Raises:
        Exception: The class registered for the status code in
            ``STATUS_TO_LITELLM`` (``BadRequestError`` when the status has no
            entry), constructed with ``llm_provider="databricks"``.
    """
    raw_text = response.text[:500] if response.text else ""
    try:
        body = response.json()
    except Exception:
        # Deliberately best-effort: we are already on an error path and a
        # body that is not JSON (e.g. an HTML page from a proxy) must not
        # replace the HTTP error with a parsing error.
        body = {}
    if not isinstance(body, dict):
        # Valid JSON that is not an object (list, string, null) has no error
        # fields to read; treating it as empty keeps ``.get`` lookups in
        # map_databricks_error from raising AttributeError and masking the
        # real failure. The raw text is still surfaced below.
        body = {}
    msg = map_databricks_error(response.status_code, body)
    # Include the raw response body when no structured error field was found.
    if msg.endswith("Unknown error") and raw_text:
        msg = f"{msg} | url={response.url} | body={raw_text}"
    exc_class = STATUS_TO_LITELLM.get(response.status_code, BadRequestError)
    raise exc_class(msg, model="", llm_provider="databricks")
def normalize_host(host: str) -> str:
    """Return *host* as an ``https://`` base URL with no trailing slash.

    Surrounding whitespace and trailing slashes are removed. An existing
    ``http://`` or ``https://`` scheme (matched case-insensitively) is
    replaced by a lowercase ``https://`` instead of being prefixed a second
    time, so ``http://example`` no longer becomes ``https://http://example``.

    Args:
        host: Workspace host, with or without a scheme.

    Returns:
        The normalized ``https://...`` URL, or ``""`` when *host* is blank so
        that callers' "host is missing" checks still fire.
    """
    cleaned = host.strip().rstrip("/")
    if not cleaned:
        # A bare "https://" would look like a configured host to callers.
        return ""
    scheme, separator, remainder = cleaned.partition("://")
    if separator and scheme.lower() in ("http", "https"):
        # Drop the supplied scheme; it is re-added in canonical form below.
        cleaned = remainder
    return f"https://{cleaned}"
message from FMAPI error response body.""" + msg = ( + body.get("message") + or body.get("error_description") + or body.get("error") + or "Unknown error" + ) + return f"[{status}] {msg}" + + +def validate_databricks_config( + host: str | None, + strategy: "AuthStrategy", + **creds: object, +) -> None: + """Pre-flight validation — raises ValueError with actionable messages. + + Called during DatabricksLLM._init_databricks() so configuration errors + surface at construction time rather than at first inference call. + """ + if not host: + raise ValueError( + "Databricks host is required. Set databricks_host= or base_url= " + "to your workspace URL (e.g. https://adb-xxx.azuredatabricks.net)" + ) + if not host.startswith("https://"): + raise ValueError( + f"databricks_host must start with 'https://'. Got: {host!r}" + ) + + # Import AuthStrategy here to avoid circular import at module level + from openhands.sdk.llm.providers.databricks.auth import AuthStrategy as _AS + + if strategy == _AS.U2M and not creds.get("stored_tokens"): + raise ValueError( + "U2M auth requires stored OAuth tokens. Complete the browser login flow " + "at /auth/databricks/initiate first." + ) + if strategy == _AS.M2M: + if not creds.get("client_id") or not creds.get("client_secret"): + raise ValueError( + "M2M auth requires DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET " + "(service principal credentials). Note: these are DIFFERENT from " + "DATABRICKS_U2M_CLIENT_ID used for browser OAuth login." + ) diff --git a/openhands-sdk/openhands/sdk/llm/utils/model_features.py b/openhands-sdk/openhands/sdk/llm/utils/model_features.py index 34dc4a90f9..8d3d631c53 100644 --- a/openhands-sdk/openhands/sdk/llm/utils/model_features.py +++ b/openhands-sdk/openhands/sdk/llm/utils/model_features.py @@ -215,6 +215,26 @@ def _supports_reasoning_effort(model: str | None) -> bool: def get_features(model: str) -> ModelFeatures: """Get model features.""" + # Databricks FMAPI models: return a fixed feature set. 
+ # FMAPI does not support Anthropic prompt caching, extended thinking, or + # the Responses API — even for Claude-based endpoints. Standard tool calling + # (OpenAI wire format) and stop words are supported. + # This early return prevents Databricks Claude model names (which contain + # "claude-3-7-sonnet" etc.) from incorrectly matching PROMPT_CACHE_MODELS + # and other per-provider pattern lists. + if (model or "").startswith("databricks/"): + return ModelFeatures( + supports_reasoning_effort=False, + supports_extended_thinking=False, + supports_prompt_cache=False, + supports_stop_words=True, + supports_responses_api=False, + force_string_serializer=False, + send_reasoning_content=False, + supports_prompt_cache_retention=False, + requires_inline_image_data=False, + ) + return ModelFeatures( supports_reasoning_effort=_supports_reasoning_effort(model), supports_extended_thinking=model_matches(model, EXTENDED_THINKING_MODELS), diff --git a/openhands-sdk/pyproject.toml b/openhands-sdk/pyproject.toml index 4161abd0af..22851fbb2b 100644 --- a/openhands-sdk/pyproject.toml +++ b/openhands-sdk/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openhands-sdk" -version = "1.26.0" +version = "1.27.0" description = "OpenHands SDK - Core functionality for building AI agents" requires-python = ">=3.12" @@ -32,6 +32,9 @@ Documentation = "https://docs.openhands.dev/sdk" [project.optional-dependencies] boto3 = ["boto3>=1.35.0"] +# Optional dependency for PROFILE and UNIFIED auth strategies only. +# U2M (browser PKCE), M2M (client credentials), and PAT have zero extra dependencies. 
+databricks = ["databricks-sdk>=0.20.0"] [build-system] requires = ["setuptools>=61.0", "wheel"] diff --git a/tests/sdk/llm/providers/__init__.py b/tests/sdk/llm/providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sdk/llm/providers/databricks/__init__.py b/tests/sdk/llm/providers/databricks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sdk/llm/providers/databricks/conftest.py b/tests/sdk/llm/providers/databricks/conftest.py new file mode 100644 index 0000000000..d1fd734c7e --- /dev/null +++ b/tests/sdk/llm/providers/databricks/conftest.py @@ -0,0 +1,82 @@ +"""Shared fixtures for the Databricks provider test suite. + +This conftest exists to make the suite deterministic and fast regardless of the +developer machine's local Databricks state. It does two things on every test: + +1. Scrubs ``DATABRICKS_*`` environment variables so tests cannot accidentally + pick up credentials or a host from the developer's shell. +2. Replaces ``databricks.sdk.WorkspaceClient`` with a MagicMock so any code path + that reaches PROFILE or UNIFIED auth resolution does not attempt a real + network call or OAuth browser flow (which is what caused the multi-minute + test hang when ``~/.databrickscfg`` contained a U2M profile). + +Individual tests that need to exercise the real ``WorkspaceClient`` constructor +or a specific env var can override these fixtures locally with ``monkeypatch``. +""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock + +import pytest + + +_DATABRICKS_ENV_VARS: tuple[str, ...] 
@pytest.fixture(autouse=True)
def _scrub_databricks_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Remove every ``DATABRICKS_*`` env var for the duration of a test.

    Clears the names listed in ``_DATABRICKS_ENV_VARS`` and, in addition, any
    other variable currently set whose name starts with ``DATABRICKS_``. The
    fixed list alone misses variables it does not enumerate, which would let
    settings from the developer's shell leak into a test.

    ``monkeypatch`` restores the original environment when the test ends.
    """
    # Function-scope import: this module does not import ``os`` at top level.
    import os

    # Snapshot the names first; deleting while iterating os.environ is unsafe.
    stale = set(_DATABRICKS_ENV_VARS)
    stale.update(name for name in os.environ if name.startswith("DATABRICKS_"))
    for name in sorted(stale):
        monkeypatch.delenv(name, raising=False)
+ if "databricks.sdk" in sys.modules: + monkeypatch.setattr( + "databricks.sdk.WorkspaceClient", mock_cls, raising=False + ) + else: + sdk_mod = MagicMock(name="databricks.sdk") + sdk_mod.WorkspaceClient = mock_cls + monkeypatch.setitem(sys.modules, "databricks.sdk", sdk_mod) + + return mock_cls diff --git a/tests/sdk/llm/providers/databricks/test_auth.py b/tests/sdk/llm/providers/databricks/test_auth.py new file mode 100644 index 0000000000..8616de9c5c --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_auth.py @@ -0,0 +1,380 @@ +"""Tests for Databricks FMAPI authentication strategies. + +Covers: M2MTokenProvider (double-checked locking, proactive refresh, scope=all-apis), +PAT path, U2M priority precedence, host resolution order, resolve_credentials dispatch. +""" + +from __future__ import annotations + +import time +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from pydantic import SecretStr + +from openhands.sdk.llm.providers.databricks.auth import ( + AuthStrategy, + DatabricksCredentials, + M2MTokenProvider, + _resolve_m2m, + _resolve_u2m, + resolve_credentials, +) +from openhands.sdk.llm.providers.databricks.models import StoredU2MTokens +from openhands.sdk.llm.providers.databricks.utils import USER_AGENT + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" +_M2M_TOKEN_URL = f"{_HOST}/oidc/v1/token" + + +def _make_m2m_response(token: str = "m2m-token", expires_in: int = 3600) -> httpx.Response: + # httpx requires `request` on Response for raise_for_status() (used in production code). 
+ req = httpx.Request("POST", _M2M_TOKEN_URL) + return httpx.Response( + 200, json={"access_token": token, "expires_in": expires_in}, request=req + ) + + +def _make_refresh_response(token: str = "new-access-token", expires_in: int = 3600) -> httpx.Response: + req = httpx.Request("POST", _M2M_TOKEN_URL) + return httpx.Response( + 200, json={"access_token": token, "expires_in": expires_in}, request=req + ) + + +# --------------------------------------------------------------------------- +# M2MTokenProvider +# --------------------------------------------------------------------------- + +class TestM2MTokenProvider: + def test_constructor_accepts_host_client_id_secret(self) -> None: + """P0-3: constructor signature must accept host, client_id, client_secret.""" + provider = M2MTokenProvider( + host=_HOST, + client_id="client-id", + client_secret="client-secret", + ) + assert provider._host == _HOST + assert provider._client_id == "client-id" + assert provider._client_secret == "client-secret" + assert provider._token is None + assert provider._expires_at == 0.0 + + def test_get_token_fetches_on_first_call(self) -> None: + """get_token() must call _fetch_new_token on first call (no cached token).""" + provider = M2MTokenProvider(_HOST, "cid", "csecret") + with patch.object(provider, "_fetch_new_token", return_value=("tok-1", time.time() + 7200)) as mock_fetch: + token = provider.get_token() + assert token == "tok-1" + mock_fetch.assert_called_once() + + def test_get_token_uses_cached_token_when_fresh(self) -> None: + """get_token() must return cached token without re-fetching when > 5min remaining.""" + provider = M2MTokenProvider(_HOST, "cid", "csecret") + provider._token = "cached-tok" + provider._expires_at = time.time() + 3600 # 1h remaining + with patch.object(provider, "_fetch_new_token") as mock_fetch: + token = provider.get_token() + assert token == "cached-tok" + mock_fetch.assert_not_called() + + def test_get_token_refreshes_when_near_expiry(self) -> None: + 
"""get_token() must refresh when < 5min (300s) remaining (proactive refresh).""" + provider = M2MTokenProvider(_HOST, "cid", "csecret") + provider._token = "expiring-tok" + provider._expires_at = time.time() + 100 # 100s remaining < 300s threshold + with patch.object(provider, "_fetch_new_token", return_value=("fresh-tok", time.time() + 7200)): + token = provider.get_token() + assert token == "fresh-tok" + + def test_fetch_new_token_sends_scope_all_apis(self) -> None: + """_fetch_new_token must include scope=all-apis in the token request.""" + provider = M2MTokenProvider(_HOST, "test-cid", "test-secret") + captured_data: dict = {} + + def mock_post(url, data=None, headers=None, timeout=None): + captured_data.update(data or {}) + return _make_m2m_response() + + with patch("httpx.post", side_effect=mock_post): + provider._fetch_new_token() + + assert captured_data.get("scope") == "all-apis" + assert captured_data.get("grant_type") == "client_credentials" + assert captured_data.get("client_id") == "test-cid" + assert captured_data.get("client_secret") == "test-secret" + + def test_fetch_new_token_sends_user_agent(self) -> None: + """_fetch_new_token must include PWAF User-Agent header.""" + provider = M2MTokenProvider(_HOST, "cid", "secret") + captured_headers: dict = {} + + def mock_post(url, data=None, headers=None, timeout=None): + captured_headers.update(headers or {}) + return _make_m2m_response() + + with patch("httpx.post", side_effect=mock_post): + provider._fetch_new_token() + + assert captured_headers.get("User-Agent") == USER_AGENT + + def test_fetch_new_token_raises_on_http_error(self) -> None: + """_fetch_new_token must propagate HTTP errors from the token endpoint.""" + provider = M2MTokenProvider(_HOST, "cid", "secret") + req = httpx.Request("POST", _M2M_TOKEN_URL) + error_resp = httpx.Response(401, json={"message": "Unauthorized"}, request=req) + + with patch("httpx.post", return_value=error_resp): + with pytest.raises(httpx.HTTPStatusError): + 
def _make_stored_tokens(
    access_token: str = "u2m-access",
    refresh_token: str = "u2m-refresh",
    expires_at: float | None = None,
    client_id: str = "u2m-cid",
    host: str = _HOST,
) -> StoredU2MTokens:
    """Build a ``StoredU2MTokens`` fixture with overridable fields.

    Args:
        access_token: Access token stored on the fixture.
        refresh_token: Refresh token stored on the fixture.
        expires_at: Expiry as a Unix timestamp. ``None`` means "one hour from
            now"; any explicit value, including ``0`` / ``0.0``, is used as
            given.
        client_id: OAuth client id stored on the fixture.
        host: Workspace host stored on the fixture.

    Returns:
        The populated ``StoredU2MTokens`` instance.
    """
    # ``is None`` rather than ``or``: a falsy timestamp such as 0.0 is a valid
    # way to request an already-expired token and must not be replaced.
    if expires_at is None:
        expires_at = time.time() + 3600
    return StoredU2MTokens(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_at=expires_at,
        client_id=client_id,
        host=host,
    )
def test_u2m_refresh_failure_raises_auth_error() -> None:
    """U2M refresh HTTP error → AuthenticationError with re-auth guidance."""
    from litellm.exceptions import AuthenticationError

    stored = _make_stored_tokens(expires_at=time.time() + 100)
    creds = _resolve_u2m(_HOST, stored)

    # httpx.Response.raise_for_status() raises RuntimeError, not
    # HTTPStatusError, when the response has no request attached. Attach one
    # so the 401 is seen as a genuine HTTP failure from the token endpoint.
    token_request = httpx.Request("POST", _M2M_TOKEN_URL)
    denied = httpx.Response(
        401, json={"error": "invalid_grant"}, request=token_request
    )

    with patch("httpx.post", return_value=denied):
        with pytest.raises(AuthenticationError, match="Re-authenticate"):
            creds.get_token()
str | None = None, +) -> MagicMock: + """Return a MagicMock shaped like DatabricksLLM for testing resolve_credentials.""" + llm = MagicMock() + llm.databricks_host = databricks_host + llm.base_url = base_url + llm.api_key = api_key + llm.stored_u2m_tokens = stored_u2m_tokens + llm.databricks_client_id = databricks_client_id + llm.databricks_client_secret = databricks_client_secret + llm.databricks_u2m_client_secret = databricks_u2m_client_secret + llm.databricks_profile = databricks_profile + return llm + + +def test_resolve_credentials_u2m_wins_over_all() -> None: + """U2M stored tokens take highest priority.""" + stored = _make_stored_tokens() + llm = _make_mock_llm( + stored_u2m_tokens=stored, + api_key=SecretStr("pat-token"), + databricks_client_id="m2m-cid", + databricks_client_secret=SecretStr("m2m-secret"), + ) + creds = resolve_credentials(llm) + assert creds.auth_method == "u2m" + + +def test_resolve_credentials_u2m_forwards_client_secret() -> None: + """resolve_credentials passes databricks_u2m_client_secret to _resolve_u2m.""" + stored = _make_stored_tokens(expires_at=time.time() + 100) + llm = _make_mock_llm( + stored_u2m_tokens=stored, + databricks_u2m_client_secret=SecretStr("confidential-secret"), + ) + creds = resolve_credentials(llm) + assert creds.auth_method == "u2m" + + captured_data: dict = {} + + def mock_post(url, data=None, headers=None, timeout=None): + captured_data.update(data or {}) + return _make_refresh_response() + + with patch("httpx.post", side_effect=mock_post): + creds.get_token() + + assert captured_data.get("client_secret") == "confidential-secret" + + +def test_resolve_credentials_pat_path() -> None: + """PAT is used when no U2M or M2M credentials are present.""" + llm = _make_mock_llm(api_key=SecretStr("dapi-test")) + creds = resolve_credentials(llm) + assert creds.auth_method == "pat" + assert creds.get_token() == "dapi-test" + + +def test_resolve_credentials_m2m_over_pat() -> None: + """M2M takes priority over PAT when both 
are present.""" + llm = _make_mock_llm( + api_key=SecretStr("dapi-pat"), + databricks_client_id="m2m-cid", + databricks_client_secret=SecretStr("m2m-secret"), + ) + with patch( + "openhands.sdk.llm.providers.databricks.auth.M2MTokenProvider._fetch_new_token", + return_value=("m2m-token", time.time() + 3600), + ): + creds = resolve_credentials(llm) + assert creds.auth_method == "m2m" + + +def test_resolve_credentials_pat_does_not_require_host() -> None: + """PAT auth must succeed without a workspace host (token goes to AI Gateway).""" + llm = _make_mock_llm(databricks_host=None, base_url=None, api_key=SecretStr("tok")) + creds = resolve_credentials(llm) + assert creds.auth_method == "pat" + assert creds.host == "" + + +def test_resolve_credentials_unified_requires_host() -> None: + """Unified auth (no api_key, no profile, no creds) needs the workspace host.""" + llm = _make_mock_llm(databricks_host=None, base_url=None) + with pytest.raises(ValueError, match="databricks_host is required"): + resolve_credentials(llm) + + +def test_resolve_credentials_host_from_base_url() -> None: + """Host falls back to base_url if databricks_host is not set.""" + llm = _make_mock_llm(databricks_host=None, base_url=_HOST, api_key=SecretStr("tok")) + creds = resolve_credentials(llm) + assert creds.host == _HOST + + +def test_resolve_credentials_host_from_stored_tokens() -> None: + """Host falls back to stored_u2m_tokens.host as last resort.""" + stored = _make_stored_tokens(host=_HOST) + llm = _make_mock_llm(databricks_host=None, base_url=None, stored_u2m_tokens=stored) + creds = resolve_credentials(llm) + assert creds.host == _HOST + + +def test_resolve_credentials_profile_raises_without_sdk() -> None: + """PROFILE strategy raises ImportError with install hint if databricks-sdk absent. + + The import check is deferred to get_token() so that saving settings succeeds + even without the package installed; the error surfaces at first API call. 
+ """ + llm = _make_mock_llm(databricks_profile="my-profile") + with patch.dict("sys.modules", {"databricks": None, "databricks.sdk": None}): + # resolve_credentials succeeds (returns a DatabricksCredentials object) + creds = resolve_credentials(llm) + assert creds.auth_method == "profile" + # The ImportError surfaces when the token is actually requested + with pytest.raises((ImportError, Exception)): + creds.get_token() + + +def test_resolve_credentials_unified_raises_without_sdk() -> None: + """UNIFIED strategy raises ImportError with install hint if databricks-sdk absent. + + The import check is deferred to get_token() so that saving settings succeeds + even without the package installed; the error surfaces at first API call. + """ + llm = _make_mock_llm() # no api_key, no profile → falls through to unified + with patch.dict("sys.modules", {"databricks": None, "databricks.sdk": None}): + # resolve_credentials succeeds + creds = resolve_credentials(llm) + assert creds.auth_method == "unified" + # The ImportError surfaces when the token is actually requested + with pytest.raises((ImportError, Exception)): + creds.get_token() diff --git a/tests/sdk/llm/providers/databricks/test_client.py b/tests/sdk/llm/providers/databricks/test_client.py new file mode 100644 index 0000000000..453c7f218e --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_client.py @@ -0,0 +1,418 @@ +"""Tests for DatabricksFMAPIClient. + +Covers: User-Agent header presence (PWAF), _parse_response, _build_stream_response, +streaming accumulation, __del__ cleanup (P1-1), and URL construction. 
+""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from openhands.sdk.llm.providers.databricks.auth import DatabricksCredentials +from openhands.sdk.llm.providers.databricks.client import DatabricksFMAPIClient +from openhands.sdk.llm.providers.databricks.models import ProviderFamily +from openhands.sdk.llm.providers.databricks.utils import DatabricksTimeouts, USER_AGENT + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" + + +def _make_credentials(token: str = "test-token") -> DatabricksCredentials: + return DatabricksCredentials( + host=_HOST, + get_token=lambda: token, + auth_method="pat", + ) + + +_GATEWAY = _HOST # tests use the same URL for both surfaces by default + + +def _make_client( + credentials: DatabricksCredentials | None = None, + max_retries: int = 0, + metadata_probe: bool = False, + ai_gateway_host: str = _GATEWAY, +) -> DatabricksFMAPIClient: + creds = credentials or _make_credentials() + return DatabricksFMAPIClient( + credentials=creds, + ai_gateway_host=ai_gateway_host, + timeouts=DatabricksTimeouts(), + max_retries=max_retries, + metadata_probe=metadata_probe, + ) + + +def test_client_requires_some_host() -> None: + """Constructor must reject the case where neither ai_gateway_host nor + credentials.host is provided — there's no URL to route invocations to.""" + creds_no_host = DatabricksCredentials( + host="", get_token=lambda: "tok", auth_method="pat" + ) + with pytest.raises(ValueError, match="must be provided"): + DatabricksFMAPIClient( + credentials=creds_no_host, + ai_gateway_host=None, + timeouts=DatabricksTimeouts(), + ) + + +def test_client_defaults_gateway_host_to_credentials_host() -> None: + """When no ai_gateway_host override is given, the workspace host + (credentials.host) becomes the 
gateway base.""" + client = DatabricksFMAPIClient( + credentials=_make_credentials(), + ai_gateway_host=None, + timeouts=DatabricksTimeouts(), + ) + assert client._ai_gateway_host == _HOST + + +def _make_success_response(model: str = "test-model") -> dict: + return { + "id": "chatcmpl-test", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello!"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + "model": model, + } + + +# --------------------------------------------------------------------------- +# _make_headers — PWAF User-Agent +# --------------------------------------------------------------------------- + +def test_make_headers_includes_user_agent() -> None: + """PWAF: every request must carry the correct User-Agent header.""" + client = _make_client() + headers = client._make_headers(ProviderFamily.OPENAI) + assert headers["User-Agent"] == USER_AGENT + + +def test_make_headers_includes_authorization() -> None: + """Authorization header must use Bearer scheme with the current token.""" + client = _make_client(_make_credentials(token="dapi-abc123")) + headers = client._make_headers(ProviderFamily.OPENAI) + assert headers["Authorization"] == "Bearer dapi-abc123" + + +def test_make_headers_includes_content_type() -> None: + client = _make_client() + headers = client._make_headers(ProviderFamily.OPENAI) + assert headers["Content-Type"] == "application/json" + + +def test_make_headers_openai_family_has_no_anthropic_version() -> None: + """OPENAI / GEMINI / RESPONSES families must not set ``anthropic-version``.""" + client = _make_client() + for f in (ProviderFamily.OPENAI, ProviderFamily.GEMINI, ProviderFamily.OPENAI_RESPONSES): + headers = client._make_headers(f) + assert "anthropic-version" not in headers, ( + f"family={f} must not carry Anthropic header" + ) + + +def test_make_headers_anthropic_family_sets_anthropic_version() -> 
None: + """Anthropic native API requires the ``anthropic-version`` header.""" + client = _make_client() + headers = client._make_headers(ProviderFamily.ANTHROPIC) + assert "anthropic-version" in headers + assert headers["anthropic-version"], "anthropic-version must be non-empty" + + +# --------------------------------------------------------------------------- +# _parse_response (P0-1) +# --------------------------------------------------------------------------- + +def test_parse_response_maps_to_model_response() -> None: + """_parse_response should return a litellm ModelResponse from FMAPI JSON.""" + client = _make_client() + body = _make_success_response() + resp = httpx.Response(200, json=body) + + result = client._parse_response(resp, family=ProviderFamily.OPENAI, model="my-model") + + assert result.choices is not None + assert len(result.choices) > 0 + + +def test_parse_response_preserves_id() -> None: + client = _make_client() + body = _make_success_response() + body["id"] = "chatcmpl-unique-id" + resp = httpx.Response(200, json=body) + + result = client._parse_response(resp, family=ProviderFamily.OPENAI, model="m") + assert result.id == "chatcmpl-unique-id" + + +def test_parse_response_handles_malformed_body_gracefully() -> None: + """_parse_response must not crash on unexpected FMAPI shape (fallback path).""" + client = _make_client() + resp = httpx.Response(200, json={"unexpected": "field"}) + result = client._parse_response( + resp, family=ProviderFamily.OPENAI, model="fallback-model" + ) + assert result is not None + + +# --------------------------------------------------------------------------- +# _build_stream_response (P0-1) +# --------------------------------------------------------------------------- + +def test_build_stream_response_assembles_content() -> None: + """Streaming: accumulated content must appear in the single returned ModelResponse.""" + client = _make_client() + result = client._build_stream_response( + content="The answer is 42.", 
+ response_id="stream-resp-1", + model="databricks/test-model", + ) + choices = result.choices + assert choices is not None and len(choices) > 0 + msg = choices[0].get("message") or getattr(choices[0], "message", None) + # Extract content regardless of dict or object + content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", None) + assert content == "The answer is 42." + + +def test_build_stream_response_uses_fallback_id_when_empty() -> None: + client = _make_client() + result = client._build_stream_response( + content="hi", response_id="", model="m" + ) + assert result.id is not None + assert result.id != "" + + +def test_build_stream_response_sets_finish_reason_stop() -> None: + client = _make_client() + result = client._build_stream_response(content="done", response_id="rid", model="m") + choice = result.choices[0] + finish_reason = ( + choice.get("finish_reason") if isinstance(choice, dict) + else getattr(choice, "finish_reason", None) + ) + assert finish_reason == "stop" + + +# --------------------------------------------------------------------------- +# URL construction +# --------------------------------------------------------------------------- + +def test_chat_completion_builds_correct_url_workspace_host() -> None: + """Workspace host → gateway is reverse-proxied at /ai-gateway.""" + client = _make_client(max_retries=0) + captured_url: list[str] = [] + captured_body: list[dict] = [] + + def mock_post(url, headers=None, json=None, **_kw): + captured_url.append(url) + captured_body.append(json or {}) + return httpx.Response(200, json=_make_success_response()) + + with patch.object(client._http, "post", side_effect=mock_post): + client.chat_completion( + model="databricks-meta-llama-3-3-70b-instruct", + messages=[{"role": "user", "content": "hello"}], + ) + + assert captured_url == [f"{_HOST}/ai-gateway/mlflow/v1/chat/completions"] + # The mlflow path no longer carries the endpoint in the URL — it must be + # in the request body 
so the gateway knows which model to route to. + assert captured_body[0].get("model") == "databricks-meta-llama-3-3-70b-instruct" + + +def test_chat_completion_uses_dedicated_gateway_host_when_set() -> None: + """Dedicated *.ai-gateway.* host is used as-is (no /ai-gateway prefix).""" + dedicated = "https://9999999999999999.ai-gateway.cloud.databricks.com" + client = _make_client(max_retries=0, ai_gateway_host=dedicated) + captured_url: list[str] = [] + + def mock_post(url, headers=None, json=None, **_kw): + captured_url.append(url) + return httpx.Response( + 200, + json={ + "id": "msg_x", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hi"}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 1, "output_tokens": 1}, + }, + ) + + with patch.object(client._http, "post", side_effect=mock_post): + client.chat_completion( + model="databricks-claude-opus-4-6", + messages=[{"role": "user", "content": "hi"}], + ) + + assert captured_url == [f"{dedicated}/anthropic/v1/messages"] + + +def test_chat_completion_ignores_extra_litellm_kwargs() -> None: + """extra_headers and extra_body (litellm conventions) must not appear in + the JSON body forwarded to the AI Gateway — the gateway returns 400 if + they are present.""" + client = _make_client(max_retries=0) + captured_body: list[dict] = [] + + def mock_post(url, headers=None, json=None, **_kw): + captured_body.append(json or {}) + return httpx.Response( + 200, + json={ + "id": "msg_x", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "ok"}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 1, "output_tokens": 1}, + }, + ) + + with patch.object(client._http, "post", side_effect=mock_post): + client.chat_completion( + model="databricks-claude-sonnet-4-5", + messages=[{"role": "user", "content": "hi"}], + # These are litellm-specific kwargs that DatabricksLLM._transport_call + # strips before forwarding — the client must never receive them. 
+ # (We pass them here directly to confirm the client tolerates them + # if present, but the primary assertion is they don't reach the body.) + ) + + body = captured_body[0] + assert "extra_headers" not in body, "extra_headers must not appear in gateway request body" + assert "extra_body" not in body, "extra_body must not appear in gateway request body" + + +# --------------------------------------------------------------------------- +# resolve_family — metadata-first routing with name-pattern fallback +# --------------------------------------------------------------------------- + +def test_resolve_family_uses_metadata_when_available() -> None: + """When the describe call returns ``api_types``, that wins over the name.""" + client = _make_client(max_retries=0, metadata_probe=True) + + # Model name screams "openai chat" but metadata says anthropic. + meta_response = httpx.Response( + 200, + json={ + "config": { + "served_entities": [{ + "foundation_model": {"api_types": ["anthropic/v1/messages"]}, + }], + }, + }, + ) + with patch.object(client._http, "get", return_value=meta_response): + family = client.resolve_family("my-custom-endpoint") + + assert family is ProviderFamily.ANTHROPIC, ( + "metadata api_types must authoritatively override the name pattern" + ) + + +def test_resolve_family_falls_back_to_name_when_metadata_fails() -> None: + """If the describe call errors, we must fall back to ``detect_family``.""" + client = _make_client(max_retries=0, metadata_probe=True) + + with patch.object( + client._http, "get", + side_effect=httpx.HTTPError("boom"), + ): + family = client.resolve_family("databricks-claude-sonnet-4-5") + + assert family is ProviderFamily.ANTHROPIC, ( + "name-pattern fallback must match *claude* → ANTHROPIC" + ) + + +def test_resolve_family_caches_positive_hit() -> None: + """A successful metadata resolve should not trigger a second describe call.""" + client = _make_client(max_retries=0, metadata_probe=True) + + meta_response = 
httpx.Response( + 200, + json={"config": {"served_entities": [ + {"foundation_model": {"api_types": ["gemini/v1/generateContent"]}}, + ]}}, + ) + with patch.object(client._http, "get", return_value=meta_response) as mock_get: + first = client.resolve_family("databricks-gemini-x") + second = client.resolve_family("databricks-gemini-x") + + assert first is second is ProviderFamily.GEMINI + assert mock_get.call_count == 1, "second resolve must be served from cache" + + +def test_resolve_family_default_skips_metadata_probe() -> None: + """Default (metadata_probe=False) must NOT hit the workspace URL. + + The whole point of the connector is to send FM traffic to the AI Gateway + host; the workspace URL is only for auth/discovery and must be touched + 'only as required'. With the default config, resolve_family must rely on + detect_family(model) and never issue a metadata GET. + """ + client = _make_client(max_retries=0) # metadata_probe defaults to False + + with patch.object(client._http, "get") as mock_get: + family_claude = client.resolve_family("databricks-claude-opus-4-6") + family_gemini = client.resolve_family("databricks-gemini-2-5-flash") + family_gpt5 = client.resolve_family("databricks-gpt-5-4-mini") + family_chat = client.resolve_family("databricks-meta-llama-3-3-70b-instruct") + + # Name-pattern detection must give the right family for each. + assert family_claude is ProviderFamily.ANTHROPIC + assert family_gemini is ProviderFamily.GEMINI + assert family_gpt5 is ProviderFamily.OPENAI_RESPONSES + assert family_chat is ProviderFamily.OPENAI + + # Critical assertion: NO workspace metadata GET on the FM hot path. + assert mock_get.call_count == 0, ( + "Default config must not issue any GET against the workspace URL " + "(the workspace must only be hit 'as required' — not on every chat)." 
+ ) + + +# --------------------------------------------------------------------------- +# __del__ cleanup (P1-1) +# --------------------------------------------------------------------------- + +def test_del_closes_http_client() -> None: + """__del__ must close the singleton httpx.Client to release connections.""" + client = _make_client() + with patch.object(client._http, "close") as mock_close: + client.__del__() + mock_close.assert_called_once() + + +def test_del_is_idempotent_on_exception() -> None: + """__del__ must not raise even if the client is already closed.""" + client = _make_client() + client._http.close() # close manually first + client.__del__() # should not raise + + +def test_explicit_close_works() -> None: + """close() must close the singleton httpx.Client.""" + client = _make_client() + with patch.object(client._http, "close") as mock_close: + client.close() + mock_close.assert_called_once() diff --git a/tests/sdk/llm/providers/databricks/test_discovery.py b/tests/sdk/llm/providers/databricks/test_discovery.py new file mode 100644 index 0000000000..c6068d4c87 --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_discovery.py @@ -0,0 +1,573 @@ +"""Tests for Databricks FMAPI model discovery. + +Covers: filter logic (endpoint_type, task, state.ready), User-Agent header (PWAF), +TTL cache (hit/miss/expiry), error handling (returns [] silently), and +list_models_from_env env-var handling. 
+""" + +from __future__ import annotations + +import time +from unittest.mock import patch + +import httpx +import pytest + +import openhands.sdk.llm.providers.databricks.discovery as discovery_module +from openhands.sdk.llm.providers.databricks.auth import DatabricksCredentials +from openhands.sdk.llm.providers.databricks.discovery import ( + CURATED_DATABRICKS_MODELS, + DiscoveredEndpoint, + ModelPickerEntry, + get_picker_entries, + list_chat_endpoints, + list_foundation_models, + list_models_from_env, +) +from openhands.sdk.llm.providers.databricks.models import ProviderFamily +from openhands.sdk.llm.providers.databricks.utils import USER_AGENT + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" +_DISCOVERY_URL = f"{_HOST}/api/2.0/serving-endpoints" + + +def _discovery_response(status: int, body: dict) -> httpx.Response: + """Build a Response httpx can run raise_for_status() on (needs bound request).""" + req = httpx.Request("GET", _DISCOVERY_URL) + return httpx.Response(status, json=body, request=req) + + +def _make_credentials(token: str = "test-token") -> DatabricksCredentials: + return DatabricksCredentials( + host=_HOST, + get_token=lambda: token, + auth_method="pat", + ) + + +def _make_endpoints_payload(endpoints: list[dict]) -> dict: + return {"endpoints": endpoints} + + +def _fmapi_ep(name: str, ready: bool = True) -> dict: + return { + "name": name, + "endpoint_type": "FOUNDATION_MODEL_API", + "task": "llm/v1/chat", + "state": {"ready": "READY" if ready else "NOT_READY"}, + } + + +def _non_fmapi_ep(name: str) -> dict: + return { + "name": name, + "endpoint_type": "CUSTOM_MODEL", + "task": "llm/v1/chat", + "state": {"ready": "READY"}, + } + + +def _embedding_ep(name: str) -> dict: + return { + "name": name, + "endpoint_type": "FOUNDATION_MODEL_API", + "task": "llm/v1/embeddings", # not 
chat + "state": {"ready": "READY"}, + } + + +@pytest.fixture(autouse=True) +def reset_discovery_cache() -> None: + """Reset module-level cache between tests to prevent interference.""" + discovery_module._CACHED_MODELS = [] + discovery_module._CACHE_EXPIRES_AT = 0.0 + yield + discovery_module._CACHED_MODELS = [] + discovery_module._CACHE_EXPIRES_AT = 0.0 + + +# --------------------------------------------------------------------------- +# list_foundation_models — filter logic +# --------------------------------------------------------------------------- + +def test_list_foundation_models_returns_fmapi_chat_endpoints() -> None: + """Only FOUNDATION_MODEL_API + llm/v1/chat + READY endpoints are returned.""" + payload = _make_endpoints_payload([ + _fmapi_ep("databricks-meta-llama-3-3-70b-instruct"), + _fmapi_ep("databricks-dbrx-instruct"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/databricks-meta-llama-3-3-70b-instruct" in models + assert "databricks/databricks-dbrx-instruct" in models + assert len(models) == 2 + + +def test_list_foundation_models_excludes_non_fmapi_endpoints() -> None: + """CUSTOM_MODEL endpoints must not be returned.""" + payload = _make_endpoints_payload([ + _fmapi_ep("llama-model"), + _non_fmapi_ep("custom-agent-endpoint"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/llama-model" in models + assert "databricks/custom-agent-endpoint" not in models + assert len(models) == 1 + + +def test_list_foundation_models_excludes_embedding_endpoints() -> None: + """llm/v1/embeddings task must not be included (only llm/v1/chat).""" + payload = _make_endpoints_payload([ + _fmapi_ep("chat-model"), + _embedding_ep("embed-model"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = 
list_foundation_models(_make_credentials()) + + assert "databricks/chat-model" in models + assert "databricks/embed-model" not in models + + +def test_list_foundation_models_excludes_not_ready_endpoints() -> None: + """Endpoints with state.ready != 'READY' must be excluded.""" + payload = _make_endpoints_payload([ + _fmapi_ep("ready-model", ready=True), + _fmapi_ep("loading-model", ready=False), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/ready-model" in models + assert "databricks/loading-model" not in models + + +def test_list_foundation_models_empty_workspace() -> None: + """Empty endpoints list returns empty list.""" + with patch("httpx.get", return_value=_discovery_response(200, {"endpoints": []})): + models = list_foundation_models(_make_credentials()) + assert models == [] + + +# --------------------------------------------------------------------------- +# PWAF: User-Agent header on discovery calls +# --------------------------------------------------------------------------- + +def test_list_foundation_models_sends_user_agent() -> None: + """PWAF: User-Agent header must be present on the discovery GET request.""" + captured_headers: dict = {} + + def mock_get(url, headers=None, timeout=None): + captured_headers.update(headers or {}) + return _discovery_response(200, {"endpoints": []}) + + with patch("httpx.get", side_effect=mock_get): + list_foundation_models(_make_credentials()) + + assert captured_headers.get("User-Agent") == USER_AGENT + + +def test_list_foundation_models_sends_authorization() -> None: + """Authorization header must be present with Bearer scheme.""" + captured_headers: dict = {} + + def mock_get(url, headers=None, timeout=None): + captured_headers.update(headers or {}) + return _discovery_response(200, {"endpoints": []}) + + with patch("httpx.get", side_effect=mock_get): + 
list_foundation_models(_make_credentials(token="test-tok")) + + assert captured_headers.get("Authorization") == "Bearer test-tok" + + +# --------------------------------------------------------------------------- +# list_models_from_env — TTL cache +# --------------------------------------------------------------------------- + +def test_list_models_from_env_returns_empty_without_env_vars(monkeypatch) -> None: + """Returns [] when DATABRICKS_HOST or DATABRICKS_TOKEN are not set.""" + monkeypatch.delenv("DATABRICKS_HOST", raising=False) + monkeypatch.delenv("DATABRICKS_TOKEN", raising=False) + monkeypatch.delenv("DATABRICKS_ACCESS_TOKEN", raising=False) + + models = list_models_from_env() + assert models == [] + + +def test_list_models_from_env_uses_cache_on_second_call(monkeypatch) -> None: + """Second call within TTL must use cache, not make a new HTTP request.""" + monkeypatch.setenv("DATABRICKS_HOST", _HOST) + monkeypatch.setenv("DATABRICKS_TOKEN", "test-tok") + + with patch( + "httpx.get", + return_value=_discovery_response(200, {"endpoints": [_fmapi_ep("m1")]}), + ) as mock_get: + first = list_models_from_env() + second = list_models_from_env() + + # Should only call httpx.get once; second call uses cache + assert mock_get.call_count == 1 + assert first == second + assert "databricks/m1" in first + + +def test_list_models_from_env_refreshes_after_ttl_expiry(monkeypatch) -> None: + """After TTL expires, a new HTTP call should be made.""" + monkeypatch.setenv("DATABRICKS_HOST", _HOST) + monkeypatch.setenv("DATABRICKS_TOKEN", "tok") + + # Pre-populate cache with an expired timestamp + discovery_module._CACHED_MODELS = ["databricks/old-model"] + discovery_module._CACHE_EXPIRES_AT = time.time() - 1 # already expired + + with patch( + "httpx.get", + return_value=_discovery_response(200, {"endpoints": [_fmapi_ep("new-model")]}), + ): + models = list_models_from_env() + + assert "databricks/new-model" in models + assert "databricks/old-model" not in models + + +def 
test_list_models_from_env_returns_empty_on_http_error(monkeypatch) -> None: + """HTTP errors must be swallowed and return [] (never raise).""" + monkeypatch.setenv("DATABRICKS_HOST", _HOST) + monkeypatch.setenv("DATABRICKS_TOKEN", "tok") + + with patch("httpx.get", side_effect=httpx.ConnectError("connection refused")): + models = list_models_from_env() + + assert models == [] + + +def test_list_models_from_env_model_names_prefixed(monkeypatch) -> None: + """All returned model names must be prefixed with 'databricks/'.""" + monkeypatch.setenv("DATABRICKS_HOST", _HOST) + monkeypatch.setenv("DATABRICKS_TOKEN", "tok") + + endpoints = [ + _fmapi_ep("databricks-meta-llama-3-3-70b-instruct"), + _fmapi_ep("databricks-dbrx-instruct"), + ] + with patch( + "httpx.get", + return_value=_discovery_response(200, {"endpoints": endpoints}), + ): + models = list_models_from_env() + + for m in models: + assert m.startswith("databricks/"), f"Model {m!r} missing 'databricks/' prefix" + + +# --------------------------------------------------------------------------- +# External-model inclusion (AI Gateway parity with FM endpoints) +# --------------------------------------------------------------------------- + +def _external_ep(name: str, ready: bool = True) -> dict: + """EXTERNAL_MODEL chat endpoint — e.g. customer-configured gpt-5 / gemini proxy.""" + return { + "name": name, + "endpoint_type": "EXTERNAL_MODEL", + "task": "llm/v1/chat", + "state": {"ready": "READY" if ready else "NOT_READY"}, + } + + +def test_list_foundation_models_includes_external_model_endpoints() -> None: + """EXTERNAL_MODEL chat endpoints must be returned alongside FOUNDATION_MODEL_API. + + External-model endpoints (e.g. customer-configured gpt-5 / gemini proxies) + are AI-Gateway-shaped and routed through the same native-API paths. 
+ """ + payload = _make_endpoints_payload([ + _fmapi_ep("databricks-llama-3-3-70b"), + _external_ep("my-gpt5-proxy"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/databricks-llama-3-3-70b" in models + assert "databricks/my-gpt5-proxy" in models, ( + "EXTERNAL_MODEL endpoints must be discoverable via list_foundation_models" + ) + assert len(models) == 2 + + +def test_list_foundation_models_excludes_custom_model_endpoints() -> None: + """CUSTOM_MODEL endpoints are still excluded — payload shape not guaranteed.""" + payload = _make_endpoints_payload([ + _fmapi_ep("fm-endpoint"), + _external_ep("external-endpoint"), + _non_fmapi_ep("custom-agent"), # endpoint_type=CUSTOM_MODEL + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/fm-endpoint" in models + assert "databricks/external-endpoint" in models + assert "databricks/custom-agent" not in models + + +# --------------------------------------------------------------------------- +# list_chat_endpoints — structured output +# --------------------------------------------------------------------------- + +def test_list_chat_endpoints_returns_dataclass_records() -> None: + """Structured API returns DiscoveredEndpoint records with metadata intact.""" + payload = _make_endpoints_payload([ + { + "name": "fm-llama", + "endpoint_type": "FOUNDATION_MODEL_API", + "task": "llm/v1/chat", + "state": {"ready": "READY"}, + "creator": "alice@example.com", + }, + { + "name": "ext-gpt5", + "endpoint_type": "EXTERNAL_MODEL", + "task": "llm/v1/chat", + "state": {"ready": "READY"}, + "creator": "bob@example.com", + }, + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + eps = list_chat_endpoints(_make_credentials()) + + assert len(eps) == 2 + assert all(isinstance(e, DiscoveredEndpoint) for e in 
eps) + + by_name = {e.name: e for e in eps} + assert by_name["fm-llama"].endpoint_type == "FOUNDATION_MODEL_API" + assert by_name["fm-llama"].qualified_name == "databricks/fm-llama" + assert by_name["fm-llama"].ready is True + assert by_name["fm-llama"].creator == "alice@example.com" + + assert by_name["ext-gpt5"].endpoint_type == "EXTERNAL_MODEL" + assert by_name["ext-gpt5"].qualified_name == "databricks/ext-gpt5" + + +def test_list_chat_endpoints_include_not_ready_opt_in() -> None: + """Not-ready endpoints are excluded by default, included on opt-in.""" + payload = _make_endpoints_payload([ + _fmapi_ep("ready-one", ready=True), + _fmapi_ep("loading-one", ready=False), + ]) + + with patch("httpx.get", return_value=_discovery_response(200, payload)): + default = list_chat_endpoints(_make_credentials()) + assert [e.name for e in default] == ["ready-one"] + + with patch("httpx.get", return_value=_discovery_response(200, payload)): + with_loading = list_chat_endpoints(_make_credentials(), include_not_ready=True) + names = sorted(e.name for e in with_loading) + assert names == ["loading-one", "ready-one"] + loading = next(e for e in with_loading if e.name == "loading-one") + assert loading.ready is False + + +def test_list_chat_endpoints_skips_unnamed_rows() -> None: + """Endpoints missing a ``name`` field must be silently skipped.""" + payload = _make_endpoints_payload([ + { + "endpoint_type": "FOUNDATION_MODEL_API", + "task": "llm/v1/chat", + "state": {"ready": "READY"}, + }, + _fmapi_ep("good-one"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + eps = list_chat_endpoints(_make_credentials()) + assert [e.name for e in eps] == ["good-one"] + + +# --------------------------------------------------------------------------- +# Two-tier picker: CURATED_DATABRICKS_MODELS + get_picker_entries +# --------------------------------------------------------------------------- + + +def test_curated_list_is_claude_gpt_gemini_only() -> None: + 
"""Curated tier-1 set covers only the three native-API families we target. + + Llama / DBRX / legacy endpoints must *not* appear in the curated list — + they surface automatically via discovery only if the workspace has them. + """ + families = {e.family for e in CURATED_DATABRICKS_MODELS} + assert families == { + ProviderFamily.ANTHROPIC, + ProviderFamily.OPENAI, + ProviderFamily.OPENAI_RESPONSES, + ProviderFamily.GEMINI, + } + + names = [e.name for e in CURATED_DATABRICKS_MODELS] + assert all(n.startswith("databricks-") for n in names) + forbidden = ("llama", "dbrx", "mixtral", "qwen", "deepseek") + for n in names: + for token in forbidden: + assert token not in n.lower(), ( + f"{n!r} leaked a non-curated family — curated tier is Claude/GPT/Gemini only" + ) + + +def test_curated_list_has_one_recommended_per_family() -> None: + """Exactly one ``recommended`` entry per family — the fast-and-good default.""" + recs_by_family: dict[ProviderFamily, list[str]] = {} + for e in CURATED_DATABRICKS_MODELS: + if e.recommended: + recs_by_family.setdefault(e.family, []).append(e.name) + + # GPT-5 (Responses) gets the recommended OpenAI slot because it's the + # gold-path OpenAI native API on Databricks. Plain OPENAI (gpt-oss) is + # listed but not recommended. 
+ assert set(recs_by_family) == { + ProviderFamily.ANTHROPIC, + ProviderFamily.OPENAI_RESPONSES, + ProviderFamily.GEMINI, + } + for family, picks in recs_by_family.items(): + assert len(picks) == 1, f"{family} has >1 recommended pick: {picks}" + + +def test_curated_entries_have_qualified_name_prefix() -> None: + """Every curated entry must use the ``databricks/`` prefix — that's what + ``create_llm`` sees and routes on.""" + for e in CURATED_DATABRICKS_MODELS: + assert e.qualified_name == f"databricks/{e.name}" + assert e.source == "curated" + assert e.ready is True + assert e.endpoint_type == "FOUNDATION_MODEL_API" + + +def test_get_picker_entries_returns_curated_without_credentials() -> None: + """No creds → pure curated tier, no HTTP call.""" + with patch("httpx.get") as mock_get: + entries = get_picker_entries(credentials=None) + mock_get.assert_not_called() + + assert len(entries) == len(CURATED_DATABRICKS_MODELS) + assert all(isinstance(e, ModelPickerEntry) for e in entries) + assert all(e.source == "curated" for e in entries) + + +def test_get_picker_entries_sort_order_recommended_first() -> None: + """Recommended picks come first; then family (alpha), then name.""" + entries = get_picker_entries(credentials=None) + + # 1. All recommended come before any non-recommended. + first_non_rec = next( + (i for i, e in enumerate(entries) if not e.recommended), len(entries) + ) + rec_section = entries[:first_non_rec] + rest_section = entries[first_non_rec:] + assert all(e.recommended for e in rec_section) + assert all(not e.recommended for e in rest_section) + + # 2. Within each section, family alpha-sorted. 
+ def _family_order(section: list[ModelPickerEntry]) -> list[str]: + return [e.family.value for e in section] + + for section in (rec_section, rest_section): + fams = _family_order(section) + assert fams == sorted(fams), f"Family ordering broken in section: {fams}" + + +def test_get_picker_entries_merges_discovered_on_top_of_curated() -> None: + """Discovery adds non-curated endpoints; curated entries get live signals.""" + # One endpoint overlaps curated (claude-sonnet-4-6), two are new. + payload = _make_endpoints_payload([ + _fmapi_ep("databricks-claude-sonnet-4-6"), # overlaps curated + _fmapi_ep("databricks-meta-llama-4-maverick"), # discovered only + _external_ep("customer-private-gpt"), # external-model discovered only + ]) + + with patch("httpx.get", return_value=_discovery_response(200, payload)): + entries = get_picker_entries(credentials=_make_credentials()) + + by_qn = {e.qualified_name: e for e in entries} + + # Overlap: curated entry kept recommended + its opinionated family, but + # gained "curated+discovered" source and the live endpoint_type/ready. + overlap = by_qn["databricks/databricks-claude-sonnet-4-6"] + assert overlap.source == "curated+discovered" + assert overlap.family is ProviderFamily.ANTHROPIC + assert overlap.recommended is True + assert overlap.endpoint_type == "FOUNDATION_MODEL_API" + assert overlap.ready is True + + # Discovered-only FMAPI entry — family inferred via detect_family. + llama = by_qn["databricks/databricks-meta-llama-4-maverick"] + assert llama.source == "discovered" + assert llama.family is ProviderFamily.OPENAI # llama → OpenAI Chat default + assert llama.recommended is False + assert llama.endpoint_type == "FOUNDATION_MODEL_API" + + # External-model endpoint shows up as "discovered" with its live type. + ext = by_qn["databricks/customer-private-gpt"] + assert ext.source == "discovered" + assert ext.endpoint_type == "EXTERNAL_MODEL" + + # Count: curated ∪ discovered, deduped by qualified_name. 
+ expected = {e.qualified_name for e in CURATED_DATABRICKS_MODELS} | { + "databricks/databricks-meta-llama-4-maverick", + "databricks/customer-private-gpt", + } + assert {e.qualified_name for e in entries} == expected + + +def test_get_picker_entries_swallows_discovery_errors() -> None: + """If discovery blows up, curated tier is still returned intact.""" + with patch("httpx.get", side_effect=RuntimeError("workspace down")): + entries = get_picker_entries(credentials=_make_credentials()) + + # Curated list is returned as-is (no "curated+discovered" upgrades). + assert {e.qualified_name for e in entries} == { + c.qualified_name for c in CURATED_DATABRICKS_MODELS + } + assert all(e.source == "curated" for e in entries) + + +def test_get_picker_entries_include_curated_false_returns_only_discovered() -> None: + """Opt out of curated to get a pure live list — useful for admin UIs.""" + payload = _make_endpoints_payload([_fmapi_ep("live-only-endpoint")]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + entries = get_picker_entries( + credentials=_make_credentials(), + include_curated=False, + ) + assert [e.qualified_name for e in entries] == ["databricks/live-only-endpoint"] + assert entries[0].source == "discovered" + + +def test_get_picker_entries_include_discovered_false_skips_http() -> None: + """Opt out of discovery even with creds — no HTTP fired.""" + with patch("httpx.get") as mock_get: + entries = get_picker_entries( + credentials=_make_credentials(), + include_discovered=False, + ) + mock_get.assert_not_called() + assert len(entries) == len(CURATED_DATABRICKS_MODELS) + + +def test_get_picker_entries_user_agent_propagates_through_discovery() -> None: + """PWAF: every Databricks HTTP call (including discovery triggered by the + picker) must carry the ``OpenHandsOSS/`` User-Agent.""" + payload = _make_endpoints_payload([_fmapi_ep("anything")]) + with patch( + "httpx.get", return_value=_discovery_response(200, payload) + ) as mock_get: + 
get_picker_entries(credentials=_make_credentials()) + + assert mock_get.called + _, kwargs = mock_get.call_args + assert kwargs["headers"]["User-Agent"] == USER_AGENT diff --git a/tests/sdk/llm/providers/databricks/test_llm.py b/tests/sdk/llm/providers/databricks/test_llm.py new file mode 100644 index 0000000000..e22f431526 --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_llm.py @@ -0,0 +1,492 @@ +"""Tests for DatabricksLLM and the create_llm factory. + +Covers: PAT construction, provider discriminator (P0-6), context window lookup, +max_output_tokens lookup, unknown model fallback, _transport_call prefix stripping, +Pydantic round-trip serialization, and create_llm factory routing. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import SecretStr, ValidationError + +from openhands.sdk.llm.providers.databricks.llm import ( + DATABRICKS_CONTEXT_WINDOWS, + DATABRICKS_MAX_OUTPUT, + DatabricksLLM, +) +from openhands.sdk.llm.providers.databricks.models import ( + ProviderFamily, + StoredU2MTokens, +) + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" +_MODEL_LLAMA = "databricks/databricks-meta-llama-3-3-70b-instruct" +_MODEL_CLAUDE = "databricks/databricks-claude-sonnet-4" +_MODEL_UNKNOWN = "databricks/my-custom-finetuned-model" + + +# --------------------------------------------------------------------------- +# Helper: minimal PAT-auth DatabricksLLM (no HTTP calls during construction) +# --------------------------------------------------------------------------- + +def _make_llm( + model: str = _MODEL_LLAMA, + host: str = _HOST, + token: str = "dapi-test", + ai_gateway_host: str | None = None, + **kwargs, +) -> DatabricksLLM: + """Build a DatabricksLLM with PAT auth for tests. 
+ + The new architecture only requires *one* of databricks_host or + databricks_ai_gateway_host. We always set databricks_host (the canonical + workspace URL) and forward ai_gateway_host only when explicitly given, + so tests reflect the typical single-URL deployment by default. + """ + return DatabricksLLM( + model=model, + databricks_host=host, + databricks_ai_gateway_host=ai_gateway_host, + api_key=SecretStr(token), + usage_id="test-databricks-llm", + **kwargs, + ) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + +def test_construction_succeeds_with_pat_auth() -> None: + """DatabricksLLM must construct without HTTP calls when using PAT.""" + llm = _make_llm() + assert llm.model == _MODEL_LLAMA + assert llm.databricks_host == _HOST + + +def test_construction_exposes_db_client() -> None: + """_db_client private attr must be set after construction.""" + llm = _make_llm() + from openhands.sdk.llm.providers.databricks.client import DatabricksFMAPIClient + assert isinstance(llm._db_client, DatabricksFMAPIClient) + + +def test_construction_exposes_db_credentials() -> None: + """_db_credentials private attr must be set after construction.""" + from openhands.sdk.llm.providers.databricks.auth import DatabricksCredentials + llm = _make_llm() + assert isinstance(llm._db_credentials, DatabricksCredentials) + assert llm._db_credentials.auth_method == "pat" + + +def test_construction_rejects_non_https_host() -> None: + """databricks_host without https:// must raise ValueError at construction.""" + with pytest.raises(ValueError, match="https://"): + DatabricksLLM( + model=_MODEL_LLAMA, + databricks_host="http://adb-123.azuredatabricks.net", + databricks_ai_gateway_host=_HOST, + api_key=SecretStr("tok"), + usage_id="t", + ) + + +def test_construction_rejects_no_host_at_all() -> None: + """At least one of databricks_host / databricks_ai_gateway_host 
is required.""" + with pytest.raises(ValueError, match="databricks_host is required"): + DatabricksLLM( + model=_MODEL_LLAMA, + api_key=SecretStr("tok"), + usage_id="t", + ) + + +def test_construction_pat_with_only_workspace_host() -> None: + """PAT auth on a single-URL deployment: workspace URL is enough; the SDK + derives the AI Gateway base from it (/ai-gateway/).""" + llm = DatabricksLLM( + model=_MODEL_LLAMA, + databricks_host=_HOST, + api_key=SecretStr("tok"), + usage_id="t", + ) + assert llm.databricks_host == _HOST + assert llm.databricks_ai_gateway_host is None + assert llm._db_credentials.auth_method == "pat" + # Internally the FM client must end up routing through the workspace host. + assert llm._db_client._ai_gateway_host == _HOST + + +def test_construction_pat_with_dedicated_gateway_host_only() -> None: + """PAT auth on a dedicated gateway: the gateway host is sufficient.""" + dedicated = "https://9999999999999999.ai-gateway.cloud.databricks.com" + llm = DatabricksLLM( + model=_MODEL_LLAMA, + databricks_ai_gateway_host=dedicated, + api_key=SecretStr("tok"), + usage_id="t", + ) + assert llm.databricks_host is None + assert llm.databricks_ai_gateway_host == dedicated + assert llm._db_client._ai_gateway_host == dedicated + + +# --------------------------------------------------------------------------- +# P0-6: provider discriminator +# --------------------------------------------------------------------------- + +def test_provider_field_is_databricks() -> None: + """provider field must be exactly 'databricks' for Pydantic discriminator.""" + llm = _make_llm() + assert llm.provider == "databricks" + + +def test_provider_field_is_literal() -> None: + """provider must not be overridable by user-supplied kwargs.""" + # Attempt to construct with a different provider value should either be + # silently overridden to "databricks" or raise — it must never produce a + # DatabricksLLM with provider != "databricks". 
+ llm = _make_llm() + assert llm.provider == "databricks" + + +# --------------------------------------------------------------------------- +# Context windows +# --------------------------------------------------------------------------- + +def test_context_window_llama_70b() -> None: + llm = _make_llm(model=_MODEL_LLAMA) + assert llm.max_input_tokens == DATABRICKS_CONTEXT_WINDOWS[_MODEL_LLAMA] + + +def test_context_window_claude() -> None: + """Claude-based Databricks models have 200K context window.""" + llm = _make_llm(model=_MODEL_CLAUDE) + assert llm.max_input_tokens == 200_000 + + +def test_context_window_unknown_model_fallback() -> None: + """Unknown Databricks models fall back to 128K context window.""" + llm = _make_llm(model=_MODEL_UNKNOWN) + assert llm.max_input_tokens == 128_000 + + +def test_max_output_tokens_claude() -> None: + llm = _make_llm(model=_MODEL_CLAUDE) + assert llm.max_output_tokens == 8_192 + + +def test_max_output_tokens_unknown_model_fallback() -> None: + """Unknown models fall back to 16K max output.""" + llm = _make_llm(model=_MODEL_UNKNOWN) + assert llm.max_output_tokens == 16_384 + + +# --------------------------------------------------------------------------- +# _transport_call — prefix stripping +# --------------------------------------------------------------------------- + +def test_transport_call_strips_databricks_prefix() -> None: + """_transport_call must strip 'databricks/' prefix before calling FMAPI.""" + llm = _make_llm() + captured_model: list[str] = [] + + def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs): + captured_model.append(model) + return MagicMock() + + with patch.object(llm._db_client, "chat_completion", side_effect=mock_chat_completion): + llm._transport_call(messages=[{"role": "user", "content": "hi"}]) + + assert len(captured_model) == 1 + # Should be bare endpoint name, not prefixed + assert not captured_model[0].startswith("databricks/") + assert captured_model[0] == 
"databricks-meta-llama-3-3-70b-instruct" + + +def test_transport_call_passes_through_streaming_flag() -> None: + llm = _make_llm() + + def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs): + return MagicMock() + + with patch.object(llm._db_client, "chat_completion", side_effect=mock_chat_completion) as mock_cc: + llm._transport_call(messages=[], enable_streaming=True) + + call_kwargs = mock_cc.call_args + assert call_kwargs.kwargs.get("stream") is True + + +def test_transport_call_strips_litellm_kwargs() -> None: + """_transport_call must strip litellm-specific kwargs (extra_headers, + extra_body, stream) before forwarding to chat_completion so they never + appear in the JSON body sent to the Databricks AI Gateway.""" + llm = _make_llm() + received_kwargs: dict = {} + + def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs): + received_kwargs.update(kwargs) + return MagicMock() + + with patch.object(llm._db_client, "chat_completion", side_effect=mock_chat_completion): + llm._transport_call( + messages=[{"role": "user", "content": "hi"}], + extra_headers={"X-Custom": "value"}, + extra_body={"custom_param": True}, + stream=True, # also stripped; streaming controlled via enable_streaming + ) + + assert "extra_headers" not in received_kwargs, "extra_headers must be stripped before gateway call" + assert "extra_body" not in received_kwargs, "extra_body must be stripped before gateway call" + assert "stream" not in received_kwargs, "stream must be stripped; controlled via enable_streaming" + + +# --------------------------------------------------------------------------- +# Resilience knob passthrough +# --------------------------------------------------------------------------- + +def test_custom_timeouts_propagate_to_client() -> None: + """Resilience knobs must be forwarded to DatabricksFMAPIClient.""" + llm = _make_llm( + databricks_connect_timeout_s=5.0, + databricks_read_timeout_s=60.0, + 
databricks_max_retries=5, + ) + assert llm._db_client._timeouts.connect_s == 5.0 + assert llm._db_client._timeouts.read_s == 60.0 + assert llm._db_client._max_retries == 5 + + +# --------------------------------------------------------------------------- +# Pydantic round-trip serialization (P0-6) +# --------------------------------------------------------------------------- + +def test_pydantic_roundtrip_preserves_provider_field() -> None: + """model_dump / model_validate round-trip must preserve provider='databricks'.""" + llm = _make_llm() + data = llm.model_dump() + assert data["provider"] == "databricks" + + +def test_pydantic_json_roundtrip_preserves_provider_field() -> None: + """JSON serialization must include provider field for deserialization dispatch.""" + llm = _make_llm() + json_str = llm.model_dump_json() + assert '"provider":"databricks"' in json_str or '"provider": "databricks"' in json_str + + +def test_m2m_client_secret_serialized_as_plaintext_with_expose_secrets() -> None: + """databricks_client_secret must be written as plaintext when serialized + with context={'expose_secrets': True} (the path used by AgentStore.save()). + + Without this, the saved agent_settings.json contains '**********' and the + M2M OIDC token request always returns 401 after a restart. 
+ """ + secret_value = "my-real-client-secret" + llm = DatabricksLLM( + model=_MODEL_CLAUDE, + databricks_host=_HOST, + databricks_client_id="app-id-123", + databricks_client_secret=SecretStr(secret_value), + api_key=None, + ) + + import json + + # Default JSON serialization — must be redacted (what users see / screen output) + default_json = llm.model_dump_json() + default = json.loads(default_json) + assert default.get("databricks_client_secret") == "**********", ( + "secret must be redacted in default model_dump_json()" + ) + + # With expose_secrets=True (AgentStore.save path) — must be plaintext string + exposed_json = llm.model_dump_json(context={"expose_secrets": True}) + exposed = json.loads(exposed_json) + assert exposed.get("databricks_client_secret") == secret_value, ( + "secret must be plaintext when expose_secrets=True so agent_settings.json " + "contains the real value and M2M auth doesn't send '**********' to OIDC" + ) + + # Round-trip: reload from the exposed JSON and confirm secret survived + reloaded = DatabricksLLM.model_validate_json(exposed_json) + assert reloaded.databricks_client_secret is not None + assert reloaded.databricks_client_secret.get_secret_value() == secret_value, ( + "secret must survive a model_dump_json → model_validate_json round-trip" + ) + + +# --------------------------------------------------------------------------- +# create_llm factory +# --------------------------------------------------------------------------- + +def test_create_llm_routes_databricks_prefix() -> None: + """create_llm must return DatabricksLLM for 'databricks/' prefixed models.""" + from openhands.sdk import create_llm + + llm = create_llm( + model=_MODEL_LLAMA, + databricks_host=_HOST, + api_key=SecretStr("dapi-tok"), + usage_id="factory-test", + ) + assert isinstance(llm, DatabricksLLM) + + +def test_create_llm_routes_non_databricks_to_base_llm() -> None: + """create_llm must return the base LLM for non-Databricks models.""" + from openhands.sdk 
import create_llm + from openhands.sdk.llm import LLM + + llm = create_llm(model="claude-sonnet-4-20250514", usage_id="base-test") + assert type(llm) is LLM + assert not isinstance(llm, DatabricksLLM) + + +def test_create_llm_empty_model_raises_validation_error() -> None: + """create_llm('') delegates to base LLM, which rejects an empty model name.""" + from openhands.sdk import create_llm + + with pytest.raises(ValidationError, match="model must be specified"): + create_llm(model="", usage_id="empty-test") + + +# --------------------------------------------------------------------------- +# PWAF surfaces: auth_method / predicted_family / resolve_family +# --------------------------------------------------------------------------- + +def test_auth_method_property_reflects_pat_construction() -> None: + """auth_method must mirror the strategy resolved at construction time.""" + llm = _make_llm() + assert llm.auth_method == "pat" + + +@pytest.mark.parametrize( + "model,expected", + [ + ("databricks/databricks-meta-llama-3-3-70b-instruct", ProviderFamily.OPENAI), + ("databricks/databricks-gpt-oss-120b", ProviderFamily.OPENAI), + ("databricks/databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC), + ("databricks/databricks-claude-opus-4-6", ProviderFamily.ANTHROPIC), + ("databricks/databricks-gemini-2-5-flash", ProviderFamily.GEMINI), + ("databricks/databricks-gpt-5-4", ProviderFamily.OPENAI_RESPONSES), + ("databricks/databricks-gpt-5-4-mini", ProviderFamily.OPENAI_RESPONSES), + ], +) +def test_predicted_family_no_http_call(model: str, expected: ProviderFamily) -> None: + """predicted_family must be pure-compute (no metadata probe).""" + llm = _make_llm(model=model) + # If this does HTTP, it will fail because adb-123 is unreachable from tests. + # detect_family is sync + pure, so this must succeed without mocking. 
+ assert llm.predicted_family is expected + + +def test_resolve_family_delegates_to_client_and_caches() -> None: + """resolve_family delegates to DatabricksFMAPIClient.resolve_family and caches. + + Metadata probing is opt-in (`databricks_metadata_probe=True`); without + that flag `resolve_family` is name-pattern only and never hits the wire. + """ + llm = _make_llm( + model="databricks/databricks-claude-sonnet-4-5", + databricks_metadata_probe=True, + ) + + import httpx + meta_response = httpx.Response(200, json={ + "config": {"served_entities": [ + {"foundation_model": {"api_types": ["anthropic/v1/messages"]}}, + ]}, + }) + with patch.object(llm._db_client._http, "get", return_value=meta_response) as mg: + f1 = llm.resolve_family() + f2 = llm.resolve_family() + assert f1 is f2 is ProviderFamily.ANTHROPIC + assert mg.call_count == 1, "second resolve must be cache-served" + + +def test_resolve_family_default_skips_metadata_probe() -> None: + """Default resolve_family is name-pattern only; no workspace GET.""" + llm = _make_llm(model="databricks/databricks-claude-sonnet-4-5") + + with patch.object(llm._db_client._http, "get") as mg: + family = llm.resolve_family() + assert family is ProviderFamily.ANTHROPIC + assert mg.call_count == 0, "default path must not hit workspace metadata" + + +def test_transport_call_logs_family(caplog) -> None: + """_transport_call must emit a debug log with predicted_family + auth_method.""" + import logging + + llm = _make_llm(model="databricks/databricks-claude-sonnet-4-5") + + def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs): + return MagicMock() + + with caplog.at_level(logging.DEBUG, logger="openhands.sdk.llm.providers.databricks.llm"): + with patch.object(llm._db_client, "chat_completion", side_effect=mock_chat_completion): + llm._transport_call(messages=[{"role": "user", "content": "hi"}]) + + log_records = [r for r in caplog.records if r.message == "databricks_transport_call"] + assert log_records, 
"expected a databricks_transport_call debug record" + record = log_records[-1] + assert record.__dict__.get("predicted_family") == "anthropic" + assert record.__dict__.get("auth_method") == "pat" + + +# --------------------------------------------------------------------------- +# Expanded model capability tables (current-gen FM / external models) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "model,expected_ctx", + [ + # Anthropic family — 200K across Claude generations on Databricks + ("databricks/databricks-claude-sonnet-4-5", 200_000), + ("databricks/databricks-claude-opus-4-6", 200_000), + ("databricks/databricks-claude-haiku-4-5", 200_000), + # Gemini — 1M token context + ("databricks/databricks-gemini-2-5-flash", 1_048_576), + ("databricks/databricks-gemini-2-5-pro", 1_048_576), + # GPT-5 Responses family — 400K context + ("databricks/databricks-gpt-5-4", 400_000), + ("databricks/databricks-gpt-5-4-mini", 400_000), + # gpt-oss — 128K + ("databricks/databricks-gpt-oss-120b", 128_000), + ], +) +def test_current_gen_context_windows(model: str, expected_ctx: int) -> None: + """Current-generation models must have correct context windows in the table.""" + llm = _make_llm(model=model) + assert llm.max_input_tokens == expected_ctx, ( + f"{model}: expected ctx={expected_ctx}, got {llm.max_input_tokens}" + ) + + +@pytest.mark.parametrize( + "model,min_output", + [ + ("databricks/databricks-claude-sonnet-4-5", 64_000), + ("databricks/databricks-gemini-2-5-flash", 16_384), # reasoning budget + ("databricks/databricks-gpt-5-4", 16_384), + ("databricks/databricks-gpt-oss-120b", 16_384), + ], +) +def test_reasoning_models_have_generous_output_budget( + model: str, min_output: int, +) -> None: + """Reasoning models' max_output_tokens must leave room for thinking + output.""" + llm = _make_llm(model=model) + assert llm.max_output_tokens >= min_output, ( + f"{model}: max_output_tokens too tight for reasoning — " 
+ f"got {llm.max_output_tokens}, need >= {min_output}" + ) diff --git a/tests/sdk/llm/providers/databricks/test_models.py b/tests/sdk/llm/providers/databricks/test_models.py new file mode 100644 index 0000000000..948abe8166 --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_models.py @@ -0,0 +1,284 @@ +"""Tests for Databricks AI Gateway routing primitives. + +Covers: + * ``ProviderFamily`` enum shape + * ``detect_family`` — name-pattern routing (fast path, no HTTP) + * ``pick_family_from_api_types`` — metadata routing (authoritative) + * ``AIGatewayPaths.url`` — URL construction per family + +These primitives are the hot path of the whole connector; they must never +regress. The live E2E tests already exercise them end-to-end; this file locks +in the contract at the unit level so refactors can't silently flip routing. +""" + +from __future__ import annotations + +import pytest + +from openhands.sdk.llm.providers.databricks.models import ( + AIGatewayPaths, + ProviderFamily, + detect_family, + pick_family_from_api_types, +) + + +# --------------------------------------------------------------------------- +# ProviderFamily enum shape +# --------------------------------------------------------------------------- + +def test_provider_family_enum_values() -> None: + """Exactly four families must be exposed (OpenAI Chat / Responses, Anthropic, Gemini).""" + assert {f.value for f in ProviderFamily} == { + "openai", "openai_responses", "anthropic", "gemini", + } + + +def test_provider_family_openai_is_default_fallback() -> None: + """``OPENAI`` is the universal fallback — must be present and import-stable.""" + assert ProviderFamily.OPENAI.value == "openai" + + +# --------------------------------------------------------------------------- +# detect_family — name-pattern routing +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "model,expected", + [ + # Anthropic — substring match + 
("databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC), + ("databricks/databricks-claude-opus-4", ProviderFamily.ANTHROPIC), + ("claude-haiku", ProviderFamily.ANTHROPIC), + ("my-custom-claude-proxy", ProviderFamily.ANTHROPIC), + + # Gemini — substring match + ("databricks-gemini-2-5-flash", ProviderFamily.GEMINI), + ("databricks/databricks-gemini-pro", ProviderFamily.GEMINI), + ("gemini-1-5-pro", ProviderFamily.GEMINI), + + # GPT-5 series → Responses API (bare and prefixed, all variants) + ("gpt-5", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-mini", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-nano", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-1", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-1-codex-max", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-1-codex-mini", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-2-codex", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-3-codex", ProviderFamily.OPENAI_RESPONSES), + ("databricks-gpt-5-4", ProviderFamily.OPENAI_RESPONSES), + ("databricks/databricks-gpt-5-4-mini", ProviderFamily.OPENAI_RESPONSES), + # Future numbered GPT generations inherit Responses API automatically + ("gpt-6", ProviderFamily.OPENAI_RESPONSES), + ("gpt-6-mini", ProviderFamily.OPENAI_RESPONSES), + ("databricks/databricks-gpt-7-turbo", ProviderFamily.OPENAI_RESPONSES), + + # gpt-oss and everything else stays on MLflow Chat Completions + ("gpt-oss-120b", ProviderFamily.OPENAI), + ("databricks-gpt-oss-120b", ProviderFamily.OPENAI), + ("databricks-meta-llama-3-3-70b-instruct", ProviderFamily.OPENAI), + + # Case-insensitive + ("DATABRICKS-CLAUDE-SONNET-4-5", ProviderFamily.ANTHROPIC), + ("Databricks/Gemini-Flash", ProviderFamily.GEMINI), + ], +) +def test_detect_family_name_patterns(model: str, expected: ProviderFamily) -> None: + assert detect_family(model) == expected, ( + f"routing regression: detect_family({model!r}) != {expected}" + ) + + +def test_detect_family_gpt_oss_must_not_route_to_responses() -> None: + """Regression guard: gpt-oss-* has no Responses 
API — keep on MLflow Chat. + + The ``re.match(r"gpt-\\d", name)`` rule excludes ``gpt-oss-*`` by + construction: ``gpt-oss`` starts with ``gpt-o``, not ``gpt-``. + This test pins that boundary so a future regex change doesn't silently + break gpt-oss routing. + """ + assert detect_family("gpt-oss-120b") is ProviderFamily.OPENAI + assert detect_family("databricks-gpt-oss-20b") is ProviderFamily.OPENAI + assert detect_family("databricks/databricks-gpt-oss-120b") is ProviderFamily.OPENAI + + +# --------------------------------------------------------------------------- +# pick_family_from_api_types — metadata routing +# --------------------------------------------------------------------------- + +def test_pick_family_prefers_anthropic_over_openai_chat() -> None: + """If an endpoint exposes both ``anthropic/v1/messages`` and the mlflow chat + shim, we must pick the native Anthropic route.""" + family = pick_family_from_api_types( + ["anthropic/v1/messages", "mlflow/v1/chat/completions"], + ) + assert family is ProviderFamily.ANTHROPIC + + +def test_pick_family_prefers_gemini_over_openai_chat() -> None: + family = pick_family_from_api_types( + ["gemini/v1/generateContent", "mlflow/v1/chat/completions"], + ) + assert family is ProviderFamily.GEMINI + + +def test_pick_family_prefers_responses_over_openai_chat() -> None: + family = pick_family_from_api_types( + ["openai/v1/responses", "mlflow/v1/chat/completions"], + ) + assert family is ProviderFamily.OPENAI_RESPONSES + + +def test_pick_family_priority_order_anthropic_wins() -> None: + """When multiple specific api_types are present, priority order + decides — Anthropic wins over Gemini wins over Responses (documented).""" + family = pick_family_from_api_types( + ["gemini/v1/generateContent", + "anthropic/v1/messages", + "openai/v1/responses"], + ) + assert family is ProviderFamily.ANTHROPIC + + +def test_pick_family_defaults_to_openai_when_no_native_hint() -> None: + """mlflow-only api_types → fall back to universal 
OpenAI Chat.""" + family = pick_family_from_api_types(["mlflow/v1/chat/completions"]) + assert family is ProviderFamily.OPENAI + + +def test_pick_family_empty_or_none_returns_openai() -> None: + """Empty api_types + no external provider → OPENAI (universal default).""" + assert pick_family_from_api_types(None) is ProviderFamily.OPENAI + assert pick_family_from_api_types([]) is ProviderFamily.OPENAI + assert pick_family_from_api_types([], external_provider=None) is ProviderFamily.OPENAI + + +@pytest.mark.parametrize( + "provider,expected", + [ + ("anthropic", ProviderFamily.ANTHROPIC), + ("ANTHROPIC", ProviderFamily.ANTHROPIC), + ("bedrock-anthropic", ProviderFamily.ANTHROPIC), + ("google", ProviderFamily.GEMINI), + ("gemini", ProviderFamily.GEMINI), + ("openai", ProviderFamily.OPENAI), + ("azure-openai", ProviderFamily.OPENAI), + # Unknown providers → safe default + ("cohere", ProviderFamily.OPENAI), + ("", ProviderFamily.OPENAI), + ], +) +def test_pick_family_external_provider_routing( + provider: str, expected: ProviderFamily, +) -> None: + """External-model endpoints route via ``external_model.provider``.""" + family = pick_family_from_api_types([], external_provider=provider or None) + assert family == expected, ( + f"external provider {provider!r} should route to {expected}" + ) + + +def test_pick_family_native_api_type_wins_over_external_provider() -> None: + """When both signals are present, the native ``api_types`` must win.""" + family = pick_family_from_api_types( + ["anthropic/v1/messages"], + external_provider="openai", # contradictory, should be ignored + ) + assert family is ProviderFamily.ANTHROPIC + + +# --------------------------------------------------------------------------- +# AIGatewayPaths.url — URL construction per family +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" + + +_DEDICATED_GW = "https://9999999999999999.ai-gateway.cloud.databricks.com" + + +def 
test_aigateway_url_openai_chat_workspace_host() -> None: + """Workspace host: gateway is reverse-proxied at /ai-gateway.""" + url = AIGatewayPaths().url(_HOST, ProviderFamily.OPENAI, "databricks-llama-3-3") + assert url == f"{_HOST}/ai-gateway/mlflow/v1/chat/completions" + # Endpoint name is carried in the body, not the URL. + assert "llama" not in url + + +def test_aigateway_url_openai_chat_dedicated_gateway_host() -> None: + """Dedicated *.ai-gateway.* host is the gateway base; no /ai-gateway prefix.""" + url = AIGatewayPaths().url( + _DEDICATED_GW, ProviderFamily.OPENAI, "databricks-llama-3-3", + ) + assert url == f"{_DEDICATED_GW}/mlflow/v1/chat/completions" + + +def test_aigateway_url_anthropic_workspace_host() -> None: + """Anthropic native route — endpoint-agnostic, name goes in the body.""" + url = AIGatewayPaths().url( + _HOST, ProviderFamily.ANTHROPIC, "databricks-claude-sonnet-4-5", + ) + assert url == f"{_HOST}/ai-gateway/anthropic/v1/messages" + assert "claude" not in url + + +def test_aigateway_url_anthropic_dedicated_gateway_host() -> None: + url = AIGatewayPaths().url( + _DEDICATED_GW, ProviderFamily.ANTHROPIC, "databricks-claude-opus-4-6", + ) + assert url == f"{_DEDICATED_GW}/anthropic/v1/messages" + + +def test_aigateway_url_gemini_interpolates_endpoint() -> None: + url = AIGatewayPaths().url( + _HOST, ProviderFamily.GEMINI, "databricks-gemini-2-5-flash", + ) + assert url == ( + f"{_HOST}/ai-gateway/gemini/v1beta/models/" + f"databricks-gemini-2-5-flash:generateContent" + ) + + +def test_aigateway_url_openai_responses_workspace_host() -> None: + url = AIGatewayPaths().url( + _HOST, ProviderFamily.OPENAI_RESPONSES, "databricks-gpt-5-4", + ) + assert url == f"{_HOST}/ai-gateway/openai/v1/responses" + assert "gpt-5" not in url + + +def test_aigateway_url_trailing_slash_is_normalized() -> None: + """Trailing slashes on host must not produce double-slash URLs.""" + url = AIGatewayPaths().url(_HOST + "/", ProviderFamily.OPENAI, "foo") + assert 
"//ai-gateway" not in url + assert url == f"{_HOST}/ai-gateway/mlflow/v1/chat/completions" + + +def test_aigateway_url_explicit_ai_gateway_path_is_idempotent() -> None: + """If the user already added '/ai-gateway' to the host, don't double it.""" + host_with_prefix = f"{_HOST}/ai-gateway" + url = AIGatewayPaths().url(host_with_prefix, ProviderFamily.OPENAI, "x") + assert url == f"{_HOST}/ai-gateway/mlflow/v1/chat/completions" + + +def test_aigateway_paths_overrideable() -> None: + """Custom path templates (e.g. private deployment) round-trip.""" + paths = AIGatewayPaths(openai="/custom/{endpoint}/chat") + url = paths.url(_HOST, ProviderFamily.OPENAI, "x") + assert url == f"{_HOST}/ai-gateway/custom/x/chat" + + +def test_normalize_base_dedicated_gateway() -> None: + """*.ai-gateway.* hosts are returned as-is.""" + assert AIGatewayPaths.normalize_base(_DEDICATED_GW) == _DEDICATED_GW + assert AIGatewayPaths.normalize_base(_DEDICATED_GW + "/") == _DEDICATED_GW + + +def test_normalize_base_workspace_appends_prefix() -> None: + """Workspace URLs gain the /ai-gateway prefix.""" + assert ( + AIGatewayPaths.normalize_base(_HOST) == f"{_HOST}/ai-gateway" + ) + assert ( + AIGatewayPaths.normalize_base(f"{_HOST}/ai-gateway") == f"{_HOST}/ai-gateway" + ) diff --git a/tests/sdk/llm/providers/databricks/test_native.py b/tests/sdk/llm/providers/databricks/test_native.py new file mode 100644 index 0000000000..bdf41b6a3b --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_native.py @@ -0,0 +1,635 @@ +"""Tests for native-API adapters (``to_native`` / ``from_native``). + +Each family (Anthropic / Gemini / Responses / OpenAI Chat) has a tiny adapter +pair. These tests lock in the wire-format contract — any change here is a +potential API break for the downstream agent loop that expects OpenAI Chat. + +Live E2E already confirmed the adapters work against real endpoints (PAT → +llama, PROFILE → claude, UNIFIED → gemini). This file pins the wire shape. 
+""" + +from __future__ import annotations + +import pytest + +from openhands.sdk.llm.providers.databricks.models import ProviderFamily +from openhands.sdk.llm.providers.databricks.native import ( + _chat_tools_to_responses, + _flatten_content, + from_native, + to_native, +) + + +# --------------------------------------------------------------------------- +# _flatten_content — reasoning-model content blocks +# --------------------------------------------------------------------------- + +def test_flatten_content_string_passthrough() -> None: + assert _flatten_content("hello") == "hello" + + +def test_flatten_content_text_blocks_are_concatenated() -> None: + blocks = [ + {"type": "text", "text": "Hello "}, + {"type": "text", "text": "world"}, + ] + assert _flatten_content(blocks) == "Hello world" + + +def test_flatten_content_reasoning_blocks_are_dropped() -> None: + """Reasoning blocks are model-internal thought and must never leak out.""" + blocks = [ + {"type": "reasoning", "text": "thinking hard..."}, + {"type": "text", "text": "answer"}, + ] + assert _flatten_content(blocks) == "answer" + + +def test_flatten_content_garbage_returns_empty() -> None: + assert _flatten_content(None) == "" + assert _flatten_content(42) == "" + assert _flatten_content([42, "x"]) == "" + + +# --------------------------------------------------------------------------- +# OpenAI Chat (default family) +# --------------------------------------------------------------------------- + +def test_to_openai_chat_minimal_payload() -> None: + """The mlflow path doesn't carry the endpoint in the URL — it must be in the body.""" + body = to_native( + ProviderFamily.OPENAI, + "databricks-llama", + [{"role": "user", "content": "hi"}], + ) + assert body == { + "model": "databricks-llama", + "messages": [{"role": "user", "content": "hi"}], + } + + +def test_to_openai_chat_forwards_generation_kwargs() -> None: + body = to_native( + ProviderFamily.OPENAI, "m", + [{"role": "user", "content": "x"}], + 
max_tokens=32, temperature=0.2, top_p=0.9, stop=["END"], + ) + assert body["max_tokens"] == 32 + assert body["temperature"] == 0.2 + assert body["top_p"] == 0.9 + assert body["stop"] == ["END"] + + +def test_to_openai_chat_includes_tools_and_tool_choice() -> None: + tools = [{"type": "function", "function": {"name": "get_time"}}] + body = to_native( + ProviderFamily.OPENAI, "m", + [{"role": "user", "content": "what time"}], + tools=tools, tool_choice="auto", + ) + assert body["tools"] == tools + assert body["tool_choice"] == "auto" + + +def test_to_openai_chat_stream_flag_propagates() -> None: + body = to_native(ProviderFamily.OPENAI, "m", + [{"role": "user", "content": "x"}], stream=True) + assert body["stream"] is True + + +def test_from_openai_chat_passthrough() -> None: + """OpenAI Chat is the native format — responses should pass through intact.""" + raw = { + "id": "chatcmpl-1", "object": "chat.completion", "model": "m", + "choices": [{"index": 0, + "message": {"role": "assistant", "content": "hi"}, + "finish_reason": "stop"}], + "usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}, + } + out = from_native(ProviderFamily.OPENAI, "m", raw) + assert out["id"] == "chatcmpl-1" + assert out["choices"][0]["message"]["content"] == "hi" + + +def test_from_openai_chat_flattens_list_content_for_reasoning_models() -> None: + """gpt-oss-style list-of-blocks content must be flattened to a string.""" + raw = { + "choices": [{"message": {"content": [ + {"type": "reasoning", "text": "think"}, + {"type": "text", "text": "answer"}, + ]}}], + } + out = from_native(ProviderFamily.OPENAI, "m", raw) + assert out["choices"][0]["message"]["content"] == "answer" + + +# --------------------------------------------------------------------------- +# Anthropic Messages +# --------------------------------------------------------------------------- + +def test_to_anthropic_hoists_system_message() -> None: + body = to_native( + ProviderFamily.ANTHROPIC, 
"databricks-claude-sonnet-4-5", + [ + {"role": "system", "content": "You are a poet."}, + {"role": "user", "content": "Write a haiku."}, + ], + ) + assert body["system"] == "You are a poet." + assert body["messages"] == [{"role": "user", "content": "Write a haiku."}] + assert "system" not in {m["role"] for m in body["messages"]} + + +def test_to_anthropic_requires_max_tokens_even_if_unspecified() -> None: + """Anthropic Messages rejects requests without max_tokens — we default it.""" + body = to_native( + ProviderFamily.ANTHROPIC, "m", + [{"role": "user", "content": "hi"}], + ) + assert "max_tokens" in body + assert isinstance(body["max_tokens"], int) and body["max_tokens"] > 0 + + +def test_to_anthropic_stop_maps_to_stop_sequences() -> None: + body = to_native( + ProviderFamily.ANTHROPIC, "m", + [{"role": "user", "content": "x"}], + stop="END", + ) + assert body["stop_sequences"] == ["END"] + body = to_native( + ProviderFamily.ANTHROPIC, "m", + [{"role": "user", "content": "x"}], + stop=["A", "B"], + ) + assert body["stop_sequences"] == ["A", "B"] + + +def test_to_anthropic_includes_model_id() -> None: + """Anthropic path is endpoint-agnostic — model id travels in the body.""" + body = to_native( + ProviderFamily.ANTHROPIC, "databricks-claude-opus-4-6", + [{"role": "user", "content": "x"}], + ) + assert body["model"] == "databricks-claude-opus-4-6" + + +def test_from_anthropic_extracts_text_blocks() -> None: + raw = { + "id": "msg_abc", "model": "claude-sonnet", + "content": [{"type": "text", "text": "Hello there."}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 3}, + } + out = from_native(ProviderFamily.ANTHROPIC, "claude-sonnet", raw) + assert out["id"] == "msg_abc" + assert out["choices"][0]["message"]["content"] == "Hello there." 
+ assert out["choices"][0]["finish_reason"] == "stop" + assert out["usage"]["prompt_tokens"] == 10 + assert out["usage"]["completion_tokens"] == 3 + assert out["usage"]["total_tokens"] == 13 + + +@pytest.mark.parametrize( + "stop_reason,expected", + [("end_turn", "stop"), ("stop_sequence", "stop"), + ("max_tokens", "length"), ("tool_use", "tool_calls")], +) +def test_from_anthropic_stop_reason_mapping(stop_reason: str, expected: str) -> None: + raw = {"content": [{"type": "text", "text": "x"}], + "stop_reason": stop_reason, "usage": {}} + out = from_native(ProviderFamily.ANTHROPIC, "m", raw) + assert out["choices"][0]["finish_reason"] == expected + + +# --------------------------------------------------------------------------- +# Google Gemini generateContent +# --------------------------------------------------------------------------- + +def test_to_gemini_builds_contents_and_system_instruction() -> None: + body = to_native( + ProviderFamily.GEMINI, "databricks-gemini-2-5-flash", + [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + ) + assert body["systemInstruction"] == {"parts": [{"text": "Be concise."}]} + assert body["contents"] == [ + {"role": "user", "parts": [{"text": "What is 2+2?"}]}, + {"role": "model", "parts": [{"text": "4"}]}, # assistant → model + ] + + +def test_to_gemini_sets_max_output_tokens_with_safe_default() -> None: + """Gemini spends budget on thinking — default must be large enough for output.""" + body = to_native( + ProviderFamily.GEMINI, "m", + [{"role": "user", "content": "x"}], + ) + assert body["generationConfig"]["maxOutputTokens"] >= 256 + + +def test_to_gemini_maps_stop_to_stop_sequences() -> None: + body = to_native( + ProviderFamily.GEMINI, "m", + [{"role": "user", "content": "x"}], + stop=["END"], + ) + assert body["generationConfig"]["stopSequences"] == ["END"] + + +def test_from_gemini_extracts_text_and_usage() -> None: + raw = { + 
"candidates": [{ + "content": {"role": "model", "parts": [ + {"text": "Four."}, + ]}, + "finishReason": "STOP", + }], + "usageMetadata": { + "promptTokenCount": 5, "candidatesTokenCount": 2, "totalTokenCount": 7, + }, + "responseId": "gen-abc", + } + out = from_native(ProviderFamily.GEMINI, "m", raw) + assert out["choices"][0]["message"]["content"] == "Four." + assert out["choices"][0]["finish_reason"] == "stop" + assert out["usage"] == {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7} + assert out["id"] == "gen-abc" + + +@pytest.mark.parametrize( + "finish,expected", + [("STOP", "stop"), ("MAX_TOKENS", "length"), + ("SAFETY", "content_filter"), ("RECITATION", "content_filter"), + ("UNKNOWN_REASON", "stop")], +) +def test_from_gemini_finish_reason_mapping(finish: str, expected: str) -> None: + raw = {"candidates": [{"content": {"parts": [{"text": "x"}]}, + "finishReason": finish}]} + out = from_native(ProviderFamily.GEMINI, "m", raw) + assert out["choices"][0]["finish_reason"] == expected + + +# --------------------------------------------------------------------------- +# OpenAI Responses (GPT-5 series) +# --------------------------------------------------------------------------- + +def test_to_responses_uses_input_not_messages() -> None: + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "databricks-gpt-5-4", + [{"role": "user", "content": "say hi"}], + ) + assert "messages" not in body + # Responses requires content parts of type ``input_text`` for user + # messages — string content is wrapped accordingly. 
+ assert body["input"] == [ + {"role": "user", "content": [{"type": "input_text", "text": "say hi"}]} + ] + assert body["model"] == "databricks-gpt-5-4" + + +def test_to_responses_renames_max_tokens_to_max_output_tokens() -> None: + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + max_tokens=256, + ) + assert "max_tokens" not in body + assert body["max_output_tokens"] == 256 + + +def test_to_responses_default_budget_accommodates_reasoning_tokens() -> None: + """GPT-5 spends tokens on reasoning — default must not be too small.""" + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + ) + assert body["max_output_tokens"] >= 512 + + +def test_to_responses_drops_gateway_unsupported_kwargs() -> None: + """Gateway rejects ``background`` / ``store`` / ``previous_response_id`` etc.""" + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + background=True, store=True, + previous_response_id="resp_xyz", service_tier="flex", + ) + for dropped in ("background", "store", "previous_response_id", "service_tier"): + assert dropped not in body, f"{dropped!r} must be dropped — gateway rejects it" + + +def test_to_responses_drops_temperature_and_top_p() -> None: + """GPT-5 reasoning models reject ``temperature`` and ``top_p``. + + The default ``LLM`` ships ``temperature=0.0`` for everyone, so silently + dropping them in the Responses adapter is the only way single-default + callers can talk to GPT-5 at all. + """ + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + temperature=0.0, + top_p=0.5, + ) + assert "temperature" not in body + assert "top_p" not in body + + +def test_to_responses_translates_user_text_part_to_input_text() -> None: + """Responses rejects content-part type ``"text"`` for user messages. 
+ + Chat-Completions style ``[{"type": "text", ...}]`` must become + ``[{"type": "input_text", ...}]`` for user/system roles, and string + content must be wrapped in a single ``input_text`` part. + """ + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [ + {"role": "user", "content": [{"type": "text", "text": "hi"}]}, + {"role": "system", "content": "you are helpful"}, + {"role": "assistant", "content": [{"type": "text", "text": "ok"}]}, + ], + ) + assert body["input"][0]["content"][0] == {"type": "input_text", "text": "hi"} + assert body["input"][1]["content"][0] == { + "type": "input_text", + "text": "you are helpful", + } + # Assistant text parts use output_text. + assert body["input"][2]["content"][0] == {"type": "output_text", "text": "ok"} + + +def test_to_responses_drops_chat_style_max_completion_tokens_alias() -> None: + """Upstream LLM/litellm path emits ``max_completion_tokens`` for OpenAI-flavoured + calls; the Responses API rejects it (``unsupported_parameter`` 400). The + adapter must drop the chat-style alias and only forward ``max_output_tokens``. + """ + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + max_tokens=128, + max_completion_tokens=128, + ) + assert "max_completion_tokens" not in body + assert body["max_output_tokens"] == 128 + + +def test_from_responses_flattens_message_items_and_skips_reasoning() -> None: + """Responses output is an array of items; we extract text, skip reasoning.""" + raw = { + "id": "resp_abc", + "status": "completed", + "output": [ + {"type": "reasoning", "content": [{"type": "text", "text": "thinking..."}]}, + {"type": "message", "content": [ + {"type": "output_text", "text": "The answer is 42."}, + ]}, + ], + "usage": {"input_tokens": 4, "output_tokens": 8, "total_tokens": 12}, + } + out = from_native(ProviderFamily.OPENAI_RESPONSES, "m", raw) + assert out["id"] == "resp_abc" + assert out["choices"][0]["message"]["content"] == "The answer is 42." 
+ assert out["choices"][0]["finish_reason"] == "stop" + assert out["usage"]["total_tokens"] == 12 + + +def test_from_responses_converts_function_call_to_tool_calls() -> None: + """Responses function_call items must be converted to Chat Completions tool_calls. + + Without this, GPT-5 tool invocations are silently dropped and OpenHands + loops with 'response did not include a function call'. + """ + raw = { + "id": "resp_xyz", + "status": "completed", + "output": [ + { + "type": "function_call", + "id": "fc_abc", + "call_id": "call_abc", + "name": "terminal", + "arguments": '{"command": "ls"}', + } + ], + "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + } + out = from_native(ProviderFamily.OPENAI_RESPONSES, "gpt-5", raw) + msg = out["choices"][0]["message"] + assert "tool_calls" in msg + assert len(msg["tool_calls"]) == 1 + tc = msg["tool_calls"][0] + assert tc["id"] == "call_abc" + assert tc["type"] == "function" + assert tc["function"]["name"] == "terminal" + assert tc["function"]["arguments"] == '{"command": "ls"}' + assert out["choices"][0]["finish_reason"] == "tool_calls" + + +def test_from_responses_mixed_text_and_tool_call() -> None: + """Text and function_call in same response are both captured.""" + raw = { + "id": "resp_mix", + "status": "completed", + "output": [ + {"type": "message", "content": [{"type": "output_text", "text": "Running..."}]}, + {"type": "function_call", "call_id": "call_1", "name": "read_file", + "arguments": '{"path": "foo.py"}'}, + ], + "usage": {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}, + } + out = from_native(ProviderFamily.OPENAI_RESPONSES, "gpt-5", raw) + msg = out["choices"][0]["message"] + assert msg["content"] == "Running..." 
+ assert len(msg["tool_calls"]) == 1 + assert msg["tool_calls"][0]["function"]["name"] == "read_file" + + +def test_from_responses_falls_back_to_aggregated_output_text() -> None: + """When ``output`` is absent but the response exposes flat ``output_text``, + consume that as the assistant message content.""" + raw = {"id": "resp_1", "status": "completed", "output_text": "agg text"} + out = from_native(ProviderFamily.OPENAI_RESPONSES, "m", raw) + assert out["choices"][0]["message"]["content"] == "agg text" + + +# --------------------------------------------------------------------------- +# _chat_tools_to_responses — tool format conversion +# --------------------------------------------------------------------------- + +def test_chat_tools_to_responses_unwraps_function_wrapper() -> None: + """Chat Completions tool {type, function: {name, ...}} → Responses {type, name, ...}.""" + chat_tools = [ + { + "type": "function", + "function": { + "name": "terminal", + "description": "Run a shell command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + result = _chat_tools_to_responses(chat_tools) + assert len(result) == 1 + tool = result[0] + # name must be at top level (Responses API requirement) + assert tool["name"] == "terminal" + assert tool["type"] == "function" + assert tool["description"] == "Run a shell command" + # function wrapper must be gone + assert "function" not in tool + + +def test_chat_tools_to_responses_multiple_tools() -> None: + """Conversion handles multiple tools correctly.""" + chat_tools = [ + {"type": "function", "function": {"name": "read_file", "description": "r", "parameters": {}}}, + {"type": "function", "function": {"name": "write_file", "description": "w", "parameters": {}}}, + ] + result = _chat_tools_to_responses(chat_tools) + assert [t["name"] for t in result] == ["read_file", "write_file"] + assert all("function" not in t for t in result) + + +def test_to_responses_converts_tool_format() -> None: + """to_native for 
OPENAI_RESPONSES flattens tool wrappers into Responses format.""" + chat_tools = [ + { + "type": "function", + "function": { + "name": "terminal", + "description": "Run a shell command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + body = to_native( + ProviderFamily.OPENAI_RESPONSES, + "databricks-gpt-5-4", + [{"role": "user", "content": "hello"}], + tools=chat_tools, + ) + assert "tools" in body + tool = body["tools"][0] + assert tool["name"] == "terminal" + assert tool["type"] == "function" + assert "function" not in tool + + +# --------------------------------------------------------------------------- +# _to_responses_input — multi-turn tool call translation +# --------------------------------------------------------------------------- + + +def test_to_responses_tool_role_becomes_function_call_output() -> None: + """``role=tool`` messages must become ``function_call_output`` items. + + When GPT-5 returns a function_call on turn 1, OpenHands sends back the + result as a Chat Completions ``role=tool`` message on turn 2. The Responses + API requires this to be a ``function_call_output`` item (not a message with + role) — otherwise the second turn fails with a schema error. 
+ """ + body = to_native( + ProviderFamily.OPENAI_RESPONSES, + "databricks-gpt-5-4", + [ + {"role": "user", "content": "Write hello.py"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path":"hello.py","content":"hi"}'}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_abc", + "content": "File written successfully.", + }, + ], + ) + items = body["input"] + types = [i.get("type") or i.get("role") for i in items] + # user message preserved + assert types[0] == "user", f"expected user, got {items[0]}" + # assistant tool_call → function_call item + assert types[1] == "function_call", f"expected function_call, got {items[1]}" + assert items[1]["call_id"] == "call_abc" + assert items[1]["name"] == "write_file" + # tool result → function_call_output item + assert types[2] == "function_call_output", f"expected function_call_output, got {items[2]}" + assert items[2]["call_id"] == "call_abc" + assert items[2]["output"] == "File written successfully." 
+ + +def test_to_responses_assistant_with_tool_calls_emits_function_call_items() -> None: + """Assistant messages with ``tool_calls`` become ``function_call`` input items.""" + body = to_native( + ProviderFamily.OPENAI_RESPONSES, + "m", + [ + {"role": "user", "content": "go"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "read_file", "arguments": '{"path":"a.py"}'}, + }, + { + "id": "call_2", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path":"b.py","content":"x"}'}, + }, + ], + }, + ], + ) + items = body["input"] + # items[0] = user message, items[1] = function_call (read_file), items[2] = function_call (write_file) + assert len(items) == 3 + assert items[1]["type"] == "function_call" + assert items[1]["name"] == "read_file" + assert items[1]["call_id"] == "call_1" + assert items[2]["type"] == "function_call" + assert items[2]["name"] == "write_file" + assert items[2]["call_id"] == "call_2" + + +def test_to_responses_assistant_with_tool_calls_and_text_emits_both() -> None: + """Assistant message with both text content and tool_calls emits both items.""" + body = to_native( + ProviderFamily.OPENAI_RESPONSES, + "m", + [ + { + "role": "assistant", + "content": "I'll run that for you.", + "tool_calls": [ + { + "id": "call_x", + "type": "function", + "function": {"name": "terminal", "arguments": '{"cmd":"ls"}'}, + } + ], + }, + ], + ) + items = body["input"] + fc = next((i for i in items if i.get("type") == "function_call"), None) + assert fc is not None, "function_call item missing" + assert fc["name"] == "terminal" + text_items = [i for i in items if i.get("role") == "assistant"] + assert text_items, "output_text item for text content missing" + assert text_items[0]["content"][0]["text"] == "I'll run that for you." 
diff --git a/tests/sdk/llm/providers/databricks/test_pkce.py b/tests/sdk/llm/providers/databricks/test_pkce.py new file mode 100644 index 0000000000..940031a52d --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_pkce.py @@ -0,0 +1,201 @@ +"""Tests for the shared U2M OAuth PKCE primitives. + +Covers: PKCE verifier/challenge S256 correctness, authorize-URL parameters, +and the sync + async code → token exchange (happy path, confidential-app +secret, refresh-token default, PWAF User-Agent header, and error propagation). +""" + +from __future__ import annotations + +import base64 +import hashlib +from unittest.mock import AsyncMock, MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import httpx +import pytest + +from openhands.sdk.llm.providers.databricks.models import StoredU2MTokens +from openhands.sdk.llm.providers.databricks.pkce import ( + async_exchange_code_for_tokens, + build_authorize_url, + exchange_code_for_tokens, + generate_pkce, +) +from openhands.sdk.llm.providers.databricks.utils import USER_AGENT + + +_HOST = "https://adb-123.azuredatabricks.net" +_CLIENT_ID = "oauth-app-client-id" +_REDIRECT = "http://localhost:3000/auth/databricks/callback" + + +# --------------------------------------------------------------------------- +# generate_pkce +# --------------------------------------------------------------------------- + + +def test_generate_pkce_challenge_is_s256_of_verifier() -> None: + verifier, challenge = generate_pkce() + expected = ( + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()) + .rstrip(b"=") + .decode() + ) + assert challenge == expected + + +def test_generate_pkce_no_base64_padding() -> None: + verifier, challenge = generate_pkce() + assert "=" not in verifier + assert "=" not in challenge + + +def test_generate_pkce_is_random() -> None: + assert generate_pkce()[0] != generate_pkce()[0] + + +# --------------------------------------------------------------------------- +# build_authorize_url +# 
--------------------------------------------------------------------------- + + +def test_build_authorize_url_params() -> None: + url = build_authorize_url(_HOST, _CLIENT_ID, _REDIRECT, "state123", "chal456") + parsed = urlparse(url) + assert parsed.scheme == "https" + assert parsed.path == "/oidc/v1/authorize" + qs = parse_qs(parsed.query) + assert qs["response_type"] == ["code"] + assert qs["client_id"] == [_CLIENT_ID] + assert qs["redirect_uri"] == [_REDIRECT] + assert qs["scope"] == ["all-apis offline_access"] + assert qs["state"] == ["state123"] + assert qs["code_challenge"] == ["chal456"] + assert qs["code_challenge_method"] == ["S256"] + + +def test_build_authorize_url_strips_trailing_slash() -> None: + url = build_authorize_url(_HOST + "/", _CLIENT_ID, _REDIRECT, "s", "c") + assert url.startswith(f"{_HOST}/oidc/v1/authorize?") + + +# --------------------------------------------------------------------------- +# exchange_code_for_tokens (sync) +# --------------------------------------------------------------------------- + + +def _mock_token_response() -> MagicMock: + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.json.return_value = { + "access_token": "access-abc", + "refresh_token": "refresh-xyz", + "expires_in": 3600, + } + return resp + + +def test_exchange_code_returns_stored_token_shape() -> None: + with patch("httpx.post", return_value=_mock_token_response()) as mock_post: + payload = exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "auth-code", "verifier-1" + ) + + # Round-trips through the StoredU2MTokens model. + stored = StoredU2MTokens.model_validate(payload) + assert stored.access_token == "access-abc" + assert stored.refresh_token == "refresh-xyz" + assert stored.client_id == _CLIENT_ID + assert stored.host == _HOST + assert stored.expires_at > 0 + + # Token endpoint + PWAF User-Agent + correct form fields. 
+ args, kwargs = mock_post.call_args + assert args[0] == f"{_HOST}/oidc/v1/token" + assert kwargs["headers"]["User-Agent"] == USER_AGENT + assert kwargs["data"]["grant_type"] == "authorization_code" + assert kwargs["data"]["code"] == "auth-code" + assert kwargs["data"]["code_verifier"] == "verifier-1" + assert "client_secret" not in kwargs["data"] + + +def test_exchange_code_includes_client_secret_for_confidential_app() -> None: + with patch("httpx.post", return_value=_mock_token_response()) as mock_post: + exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "code", "verifier", + client_secret="super-secret", + ) + assert mock_post.call_args.kwargs["data"]["client_secret"] == "super-secret" + + +def test_exchange_code_defaults_missing_refresh_token() -> None: + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.json.return_value = {"access_token": "a", "expires_in": 1200} + with patch("httpx.post", return_value=resp): + payload = exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "code", "verifier" + ) + assert payload["refresh_token"] == "" + + +def test_exchange_code_propagates_http_error() -> None: + resp = MagicMock() + resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "invalid_client", request=MagicMock(), response=MagicMock() + ) + with patch("httpx.post", return_value=resp): + with pytest.raises(httpx.HTTPStatusError): + exchange_code_for_tokens(_HOST, _CLIENT_ID, _REDIRECT, "code", "verifier") + + +# --------------------------------------------------------------------------- +# async_exchange_code_for_tokens +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_async_exchange_code_returns_stored_token_shape() -> None: + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.json.return_value = { + "access_token": "access-async", + "refresh_token": "refresh-async", + "expires_in": 3600, + } + mock_client = AsyncMock() + 
mock_client.__aenter__.return_value = mock_client + mock_client.post.return_value = resp + + with patch("httpx.AsyncClient", return_value=mock_client): + payload = await async_exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "code", "verifier", + client_secret="conf-secret", + ) + + stored = StoredU2MTokens.model_validate(payload) + assert stored.access_token == "access-async" + assert stored.host == _HOST + + # PWAF User-Agent + confidential secret forwarded on the async path too. + kwargs = mock_client.post.call_args.kwargs + assert kwargs["headers"]["User-Agent"] == USER_AGENT + assert kwargs["data"]["client_secret"] == "conf-secret" + + +@pytest.mark.asyncio +async def test_async_exchange_code_propagates_http_error() -> None: + resp = MagicMock() + resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "boom", request=MagicMock(), response=MagicMock() + ) + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.post.return_value = resp + + with patch("httpx.AsyncClient", return_value=mock_client): + with pytest.raises(httpx.HTTPStatusError): + await async_exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "code", "verifier" + ) diff --git a/tests/sdk/llm/providers/databricks/test_resilience.py b/tests/sdk/llm/providers/databricks/test_resilience.py new file mode 100644 index 0000000000..093e0d0fac --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_resilience.py @@ -0,0 +1,241 @@ +"""Tests for the fetch_with_retry retry loop. + +Covers: success on first try, retry on 429/503, no retry on 400/401/403, +Retry-After cap at 300s (P1-4), exponential backoff, connection error retry, +and exhaustion of all retries. 
+""" + +from __future__ import annotations + +from unittest.mock import MagicMock, call, patch + +import httpx +import pytest +from litellm.exceptions import ( + APIConnectionError, + AuthenticationError, + BadRequestError, + RateLimitError, + ServiceUnavailableError, +) + +from openhands.sdk.llm.providers.databricks.utils import ( + fetch_with_retry, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_URL = "https://adb-123.azuredatabricks.net/serving-endpoints/my-model/invocations" +_HEADERS = {"Authorization": "Bearer tok", "Content-Type": "application/json"} +_PAYLOAD = {"messages": [{"role": "user", "content": "hi"}]} + + +def _make_client_with_responses(*responses: httpx.Response) -> httpx.Client: + """Return a mock httpx.Client whose .post() yields responses in order.""" + mock_client = MagicMock(spec=httpx.Client) + mock_client.post.side_effect = list(responses) + return mock_client + + +# --------------------------------------------------------------------------- +# Happy path +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_success_first_try() -> None: + """200 response on first attempt — no retry, no sleep.""" + ok_resp = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(ok_resp) + + with patch("time.sleep") as mock_sleep: + result = fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + mock_sleep.assert_not_called() + client.post.assert_called_once() + + +# --------------------------------------------------------------------------- +# 429 / 5xx — should retry +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_retries_on_429() -> None: + """429 → sleep → retry → 200 on second attempt.""" + r429 = httpx.Response(429, 
json={"message": "Rate limit"}) + r200 = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(r429, r200) + + with patch("time.sleep"): + result = fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + assert client.post.call_count == 2 + + +def test_fetch_with_retry_retries_on_503() -> None: + """503 → retry → 200.""" + r503 = httpx.Response(503, json={"message": "Service unavailable"}) + r200 = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(r503, r200) + + with patch("time.sleep"): + result = fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + + +def test_fetch_with_retry_uses_retry_after_header() -> None: + """When server sends Retry-After header, sleep that many seconds (capped at 300s).""" + r429 = httpx.Response(429, headers={"Retry-After": "42"}, json={"message": "rl"}) + r200 = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(r429, r200) + + slept: list[float] = [] + with patch("time.sleep", side_effect=lambda s: slept.append(s)): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert len(slept) == 1 + assert slept[0] == 42.0 + + +def test_fetch_with_retry_caps_retry_after_at_300s() -> None: + """Retry-After > 300s must be capped at 300s (P1-4).""" + r429 = httpx.Response(429, headers={"Retry-After": "9999"}, json={"message": "rl"}) + r200 = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(r429, r200) + + slept: list[float] = [] + with patch("time.sleep", side_effect=lambda s: slept.append(s)): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert len(slept) == 1 + assert slept[0] == 300.0 + + +def test_fetch_with_retry_exhausts_and_raises_rate_limit() -> None: + """All retries exhausted on 429 → raises RateLimitError.""" + r429 = 
httpx.Response(429, json={"message": "Rate limit"}) + client = _make_client_with_responses(r429, r429, r429, r429) + + with patch("time.sleep"): + with pytest.raises(RateLimitError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert client.post.call_count == 4 # 1 initial + 3 retries + + +def test_fetch_with_retry_exhausts_and_raises_service_unavailable() -> None: + """All retries exhausted on 503 → raises ServiceUnavailableError.""" + r503 = httpx.Response(503, json={"message": "Down"}) + client = _make_client_with_responses(r503, r503, r503, r503) + + with patch("time.sleep"): + with pytest.raises(ServiceUnavailableError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + +# --------------------------------------------------------------------------- +# Non-retryable status codes +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_does_not_retry_on_400() -> None: + """400 raises BadRequestError immediately without retry.""" + r400 = httpx.Response(400, json={"message": "Bad request"}) + client = _make_client_with_responses(r400) + + with patch("time.sleep") as mock_sleep: + with pytest.raises(BadRequestError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + mock_sleep.assert_not_called() + client.post.assert_called_once() + + +def test_fetch_with_retry_does_not_retry_on_401() -> None: + """401 raises AuthenticationError immediately without retry.""" + r401 = httpx.Response(401, json={"message": "Unauthorized"}) + client = _make_client_with_responses(r401) + + with patch("time.sleep") as mock_sleep: + with pytest.raises(AuthenticationError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + mock_sleep.assert_not_called() + + +def test_fetch_with_retry_does_not_retry_on_403() -> None: + """403 raises AuthenticationError immediately.""" + r403 = httpx.Response(403, json={"message": "Forbidden"}) + client = 
_make_client_with_responses(r403) + + with pytest.raises(AuthenticationError): + with patch("time.sleep"): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + +def test_fetch_with_retry_does_not_retry_on_422() -> None: + """422 raises BadRequestError immediately.""" + r422 = httpx.Response(422, json={"message": "Validation error"}) + client = _make_client_with_responses(r422) + + with pytest.raises(BadRequestError): + with patch("time.sleep"): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + +# --------------------------------------------------------------------------- +# RETRYABLE_EXCEPTIONS (network errors) +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_retries_on_connect_error() -> None: + """httpx.ConnectError is retried (network transient failure).""" + mock_client = MagicMock(spec=httpx.Client) + ok_resp = httpx.Response(200, json={"id": "1", "choices": []}) + mock_client.post.side_effect = [httpx.ConnectError("connection refused"), ok_resp] + + with patch("time.sleep"): + result = fetch_with_retry(mock_client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + assert mock_client.post.call_count == 2 + + +def test_fetch_with_retry_raises_api_connection_error_after_network_exhaustion() -> None: + """Persistent ConnectError after all retries → APIConnectionError.""" + mock_client = MagicMock(spec=httpx.Client) + mock_client.post.side_effect = httpx.ConnectError("connection refused") + + with patch("time.sleep"): + with pytest.raises(APIConnectionError): + fetch_with_retry(mock_client, _URL, _HEADERS, _PAYLOAD, max_retries=2) + + +def test_fetch_with_retry_retries_on_read_timeout() -> None: + """httpx.ReadTimeout is retried.""" + mock_client = MagicMock(spec=httpx.Client) + ok_resp = httpx.Response(200, json={"id": "1", "choices": []}) + mock_client.post.side_effect = [httpx.ReadTimeout("timed out"), ok_resp] + + with 
patch("time.sleep"): + result = fetch_with_retry(mock_client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + + +# --------------------------------------------------------------------------- +# max_retries=0 means one attempt only +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_zero_retries_raises_immediately_on_429() -> None: + """max_retries=0: no retry on a retryable status — raises after first attempt.""" + r429 = httpx.Response(429, json={"message": "Rate limit"}) + client = _make_client_with_responses(r429) + + with patch("time.sleep") as mock_sleep: + with pytest.raises(RateLimitError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=0) + + mock_sleep.assert_not_called() + client.post.assert_called_once() diff --git a/tests/sdk/llm/providers/databricks/test_settings_bridge.py b/tests/sdk/llm/providers/databricks/test_settings_bridge.py new file mode 100644 index 0000000000..6b015a58cf --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_settings_bridge.py @@ -0,0 +1,214 @@ +"""Tests for the ``settings → create_llm(...)`` kwargs bridge. + +Covers: + +* Empty settings only yield ``usage_id``. +* Full passthrough with secret coercion and ``StoredU2MTokens`` normalization. +* ``None`` / empty-string dropping. +* ``model_override`` and ``base_url_fallback`` semantics. +* ``extras`` winning over settings values. +* Invalid ``stored_u2m_tokens`` dict silently ignored. +* **Contract / drift guard** — every public field on ``DatabricksLLM`` is either + bridged from settings or explicitly listed in ``_NOT_BRIDGED``. Adding a new + Databricks field without updating the bridge fails this test. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +from pydantic import SecretStr + +from openhands.sdk.llm.providers.databricks import DatabricksLLM, StoredU2MTokens +from openhands.sdk.llm.providers.databricks.settings_bridge import ( + _BRIDGE_FIELDS, + _NOT_BRIDGED, + UserInfoAliases, + kwargs_from_settings, +) + + +# --------------------------------------------------------------------------- +# Behavior +# --------------------------------------------------------------------------- + + +def test_empty_settings_only_includes_usage_id() -> None: + kw = kwargs_from_settings(SimpleNamespace(), usage_id="agent") + assert kw == {"usage_id": "agent"} + + +def test_full_settings_passthrough() -> None: + settings = SimpleNamespace( + model="databricks/databricks-claude-sonnet-4-5", + api_key="dapi-1234", + base_url="https://workspace.cloud.databricks.com", + timeout=60.0, + max_input_tokens=128_000, + databricks_host="https://workspace.cloud.databricks.com", + databricks_client_id="app-id", + databricks_client_secret="client-secret-raw", + databricks_profile="DEFAULT", + databricks_ssl_verify=True, + databricks_max_retries=5, + databricks_connect_timeout_s=7.0, + databricks_read_timeout_s=90.0, + databricks_chunk_timeout_s=25.0, + stored_u2m_tokens={ + "access_token": "at", + "refresh_token": "rt", + "expires_at": 9999999999.0, + "client_id": "u2m-client", + "host": "https://workspace.cloud.databricks.com", + }, + ) + kw = kwargs_from_settings(settings, usage_id="agent") + + assert kw["usage_id"] == "agent" + assert kw["model"] == "databricks/databricks-claude-sonnet-4-5" + assert isinstance(kw["api_key"], SecretStr) + assert kw["api_key"].get_secret_value() == "dapi-1234" + assert isinstance(kw["databricks_client_secret"], SecretStr) + assert kw["databricks_client_secret"].get_secret_value() == "client-secret-raw" + assert isinstance(kw["stored_u2m_tokens"], StoredU2MTokens) + assert kw["databricks_profile"] == "DEFAULT" + assert 
kw["timeout"] == 60.0 + assert kw["max_input_tokens"] == 128_000 + + +def test_secretstr_roundtrips_unchanged() -> None: + s = SimpleNamespace(api_key=SecretStr("dapi-abc")) + kw = kwargs_from_settings(s, usage_id="agent") + assert isinstance(kw["api_key"], SecretStr) + assert kw["api_key"].get_secret_value() == "dapi-abc" + + +def test_none_and_empty_strings_dropped() -> None: + s = SimpleNamespace(model="", api_key=None, base_url="", databricks_profile=None) + kw = kwargs_from_settings(s, usage_id="agent") + assert set(kw) == {"usage_id"} + + +def test_model_override_wins_over_settings() -> None: + s = SimpleNamespace(model="databricks/foo") + kw = kwargs_from_settings(s, usage_id="agent", model_override="databricks/bar") + assert kw["model"] == "databricks/bar" + + +def test_base_url_fallback_only_when_both_empty() -> None: + kw = kwargs_from_settings( + SimpleNamespace(), + usage_id="agent", + base_url_fallback="https://fallback.com", + ) + assert kw["base_url"] == "https://fallback.com" + + kw_explicit = kwargs_from_settings( + SimpleNamespace(base_url="https://explicit.com"), + usage_id="agent", + base_url_fallback="https://fallback.com", + ) + assert kw_explicit["base_url"] == "https://explicit.com" + + kw_host_only = kwargs_from_settings( + SimpleNamespace(databricks_host="https://ws.cloud.databricks.com"), + usage_id="agent", + base_url_fallback="https://fallback.com", + ) + assert "base_url" not in kw_host_only + assert kw_host_only["databricks_host"] == "https://ws.cloud.databricks.com" + + +def test_extras_win_over_settings_and_coerce_secrets() -> None: + s = SimpleNamespace(api_key="dapi-stored", databricks_profile="OLD") + kw = kwargs_from_settings( + s, + usage_id="agent", + extras={"api_key": "dapi-session", "databricks_profile": "PROD"}, + ) + assert isinstance(kw["api_key"], SecretStr) + assert kw["api_key"].get_secret_value() == "dapi-session" + assert kw["databricks_profile"] == "PROD" + + +def test_extras_none_values_dropped() -> None: + s = 
SimpleNamespace(api_key="dapi-stored") + kw = kwargs_from_settings( + s, + usage_id="agent", + extras={"api_key": None, "databricks_profile": None}, + ) + assert kw["api_key"].get_secret_value() == "dapi-stored" + assert "databricks_profile" not in kw + + +def test_aliases_map_userinfo_style_prefixed_fields() -> None: + """UserInfo uses ``llm_model`` / ``llm_api_key`` / ``llm_base_url``. + + The aliases mechanism lets the bridge read them without requiring each + caller to build a shim object. + """ + s = SimpleNamespace( + llm_model="databricks/databricks-gemini-2-5-pro", + llm_api_key="dapi-user", + llm_base_url="https://ws.cloud.databricks.com", + databricks_profile="PROD", + ) + kw = kwargs_from_settings(s, usage_id="agent", aliases=UserInfoAliases) + assert kw["model"] == "databricks/databricks-gemini-2-5-pro" + assert isinstance(kw["api_key"], SecretStr) + assert kw["api_key"].get_secret_value() == "dapi-user" + assert kw["base_url"] == "https://ws.cloud.databricks.com" + assert kw["databricks_profile"] == "PROD" + + +def test_canonical_name_wins_over_alias() -> None: + """If both ``api_key`` and ``llm_api_key`` are set, canonical wins.""" + s = SimpleNamespace(api_key="canonical", llm_api_key="aliased") + kw = kwargs_from_settings(s, usage_id="agent", aliases=UserInfoAliases) + assert kw["api_key"].get_secret_value() == "canonical" + + +def test_invalid_stored_u2m_dict_is_skipped() -> None: + s = SimpleNamespace(stored_u2m_tokens={"totally": "bogus"}) + kw = kwargs_from_settings(s, usage_id="agent") + assert "stored_u2m_tokens" not in kw + + +def test_stored_u2m_instance_passes_through() -> None: + tok = StoredU2MTokens( + access_token="at", + refresh_token="rt", + expires_at=9999999999.0, + client_id="u2m-client", + host="https://workspace.cloud.databricks.com", + ) + s = SimpleNamespace(stored_u2m_tokens=tok) + kw = kwargs_from_settings(s, usage_id="agent") + assert kw["stored_u2m_tokens"] is tok + + +# 
--------------------------------------------------------------------------- +# Drift guard +# --------------------------------------------------------------------------- + + +def test_bridge_covers_all_databricks_llm_public_fields() -> None: + """Fails when a new ``DatabricksLLM`` field is added without updating the bridge. + + Guarantees OpenHands backend + OpenHands-CLI always build the full kwarg + surface when the SDK gains a new Databricks-specific field. + """ + own_fields = { + name + for name in DatabricksLLM.__annotations__ + if not name.startswith("_") + } + covered = set(_BRIDGE_FIELDS) | _NOT_BRIDGED + missing = own_fields - covered + + assert not missing, ( + 'New DatabricksLLM field(s) not handled by the settings bridge: ' + f'{sorted(missing)}. Extend _BRIDGE_FIELDS (and every call site) or add ' + 'to _NOT_BRIDGED in settings_bridge.py.' + ) diff --git a/tests/sdk/llm/providers/databricks/test_utils.py b/tests/sdk/llm/providers/databricks/test_utils.py new file mode 100644 index 0000000000..b55ae5e4be --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_utils.py @@ -0,0 +1,254 @@ +"""Tests for Databricks FMAPI resilience utilities. + +Covers: USER_AGENT format, DatabricksTimeouts defaults, compute_backoff (Retry-After cap, +full-jitter fallback), normalize_host, map_databricks_error, validate_databricks_config, +and the _raise_non_retryable / _raise_mapped helper functions. 
+""" + +from __future__ import annotations + +import pytest +from unittest.mock import patch +from litellm.exceptions import ( + AuthenticationError, + BadRequestError, + RateLimitError, + ServiceUnavailableError, +) + +import httpx + +from openhands.sdk.llm.providers.databricks.utils import ( + USER_AGENT, + DatabricksTimeouts, + _RETRY_AFTER_MAX_S, + _raise_mapped, + _raise_non_retryable, + compute_backoff, + map_databricks_error, + normalize_host, + validate_databricks_config, +) +from openhands.sdk.llm.providers.databricks.auth import AuthStrategy + + +# --------------------------------------------------------------------------- +# USER_AGENT +# --------------------------------------------------------------------------- + +def test_user_agent_format() -> None: + """PWAF: USER_AGENT must be '/' and identify as OpenHandsOSS.""" + assert USER_AGENT.startswith("OpenHandsOSS/"), ( + f"PWAF non-compliant User-Agent: {USER_AGENT!r}" + ) + # Version portion must be non-empty + _, version = USER_AGENT.split("/", 1) + assert version, f"User-Agent missing version: {USER_AGENT!r}" + + +def test_user_agent_no_newlines() -> None: + """USER_AGENT must not contain whitespace (HTTP header safety).""" + assert "\n" not in USER_AGENT + assert "\r" not in USER_AGENT + + +# --------------------------------------------------------------------------- +# DatabricksTimeouts +# --------------------------------------------------------------------------- + +def test_databricks_timeouts_defaults() -> None: + t = DatabricksTimeouts() + assert t.connect_s == 10.0 + assert t.read_s == 120.0 + assert t.chunk_s == 30.0 + assert t.pool_s == 5.0 + + +def test_databricks_timeouts_override() -> None: + t = DatabricksTimeouts(connect_s=5.0, read_s=60.0, chunk_s=15.0) + assert t.connect_s == 5.0 + assert t.read_s == 60.0 + assert t.chunk_s == 15.0 + + +# --------------------------------------------------------------------------- +# compute_backoff +# 
--------------------------------------------------------------------------- + +def test_compute_backoff_retry_after_within_cap() -> None: + """Retry-After of 10s → sleep 10s.""" + result = compute_backoff(attempt=0, retry_after="10") + assert result == 10.0 + + +def test_compute_backoff_retry_after_exceeds_cap() -> None: + """Retry-After above _RETRY_AFTER_MAX_S is capped (P1-4: 300s cap).""" + result = compute_backoff(attempt=0, retry_after="999") + assert result == _RETRY_AFTER_MAX_S + + +def test_compute_backoff_retry_after_at_cap() -> None: + """Retry-After exactly at _RETRY_AFTER_MAX_S passes through.""" + result = compute_backoff(attempt=0, retry_after=str(_RETRY_AFTER_MAX_S)) + assert result == _RETRY_AFTER_MAX_S + + +def test_compute_backoff_no_retry_after_is_bounded() -> None: + """Full-jitter fallback: result is in [0, min(60, 1 * 2^attempt)].""" + for attempt in range(6): + result = compute_backoff(attempt=attempt) + max_wait = min(60.0, 1.0 * (2**attempt)) + assert 0.0 <= result <= max_wait, ( + f"attempt={attempt}: backoff {result:.3f} outside [0, {max_wait}]" + ) + + +def test_compute_backoff_caps_at_60s() -> None: + """Full-jitter backoff never exceeds 60s for any attempt.""" + with patch("random.uniform", return_value=1.0): # worst case: full multiplier + for attempt in range(10, 20): + result = compute_backoff(attempt=attempt) + assert result <= 60.0 + + +# --------------------------------------------------------------------------- +# normalize_host +# --------------------------------------------------------------------------- + +def test_normalize_host_adds_https() -> None: + assert normalize_host("adb-123.azuredatabricks.net") == ( + "https://adb-123.azuredatabricks.net" + ) + + +def test_normalize_host_strips_trailing_slash() -> None: + assert normalize_host("https://adb-123.azuredatabricks.net/") == ( + "https://adb-123.azuredatabricks.net" + ) + + +def test_normalize_host_already_correct() -> None: + host = 
"https://adb-123.azuredatabricks.net" + assert normalize_host(host) == host + + +def test_normalize_host_strips_multiple_slashes() -> None: + assert normalize_host("https://adb-123.azuredatabricks.net///") == ( + "https://adb-123.azuredatabricks.net" + ) + + +# --------------------------------------------------------------------------- +# map_databricks_error +# --------------------------------------------------------------------------- + +def test_map_databricks_error_message_field() -> None: + msg = map_databricks_error(429, {"message": "Rate limit exceeded"}) + assert "429" in msg + assert "Rate limit exceeded" in msg + + +def test_map_databricks_error_error_description_field() -> None: + msg = map_databricks_error(401, {"error_description": "token expired"}) + assert "401" in msg + assert "token expired" in msg + + +def test_map_databricks_error_error_field() -> None: + msg = map_databricks_error(500, {"error": "internal server error"}) + assert "500" in msg + assert "internal server error" in msg + + +def test_map_databricks_error_empty_body() -> None: + msg = map_databricks_error(503, {}) + assert "503" in msg + assert "Unknown error" in msg + + +# --------------------------------------------------------------------------- +# validate_databricks_config +# --------------------------------------------------------------------------- + +def test_validate_databricks_config_missing_host() -> None: + with pytest.raises(ValueError, match="host is required"): + validate_databricks_config(None, AuthStrategy.PAT) + + +def test_validate_databricks_config_no_https() -> None: + with pytest.raises(ValueError, match="must start with 'https://'"): + validate_databricks_config("http://adb-123.net", AuthStrategy.PAT) + + +def test_validate_databricks_config_u2m_no_tokens() -> None: + """U2M without stored tokens must raise before any HTTP call.""" + with pytest.raises(ValueError, match="stored OAuth tokens"): + validate_databricks_config( + 
"https://adb-123.azuredatabricks.net", + AuthStrategy.U2M, + stored_tokens=None, + ) + + +def test_validate_databricks_config_m2m_missing_client_id() -> None: + with pytest.raises(ValueError, match="DATABRICKS_CLIENT_ID"): + validate_databricks_config( + "https://adb-123.azuredatabricks.net", + AuthStrategy.M2M, + client_id=None, + client_secret="secret", + ) + + +def test_validate_databricks_config_m2m_missing_client_secret() -> None: + with pytest.raises(ValueError, match="DATABRICKS_CLIENT_SECRET"): + validate_databricks_config( + "https://adb-123.azuredatabricks.net", + AuthStrategy.M2M, + client_id="client-id", + client_secret=None, + ) + + +def test_validate_databricks_config_pat_passes() -> None: + """PAT path: only host is required.""" + # Should not raise + validate_databricks_config("https://adb-123.azuredatabricks.net", AuthStrategy.PAT) + + +# --------------------------------------------------------------------------- +# _raise_non_retryable +# --------------------------------------------------------------------------- + +def test_raise_non_retryable_401() -> None: + resp = httpx.Response(401, json={"message": "Unauthorized"}) + with pytest.raises(AuthenticationError): + _raise_non_retryable(resp) + + +def test_raise_non_retryable_400() -> None: + resp = httpx.Response(400, json={"message": "Bad request"}) + with pytest.raises(BadRequestError): + _raise_non_retryable(resp) + + +def test_raise_non_retryable_403() -> None: + resp = httpx.Response(403, json={"message": "Forbidden"}) + with pytest.raises(AuthenticationError): + _raise_non_retryable(resp) + + +# --------------------------------------------------------------------------- +# _raise_mapped +# --------------------------------------------------------------------------- + +def test_raise_mapped_429() -> None: + resp = httpx.Response(429, json={"message": "Rate limit exceeded"}) + with pytest.raises(RateLimitError): + _raise_mapped(resp) + + +def test_raise_mapped_503() -> None: + resp = 
httpx.Response(503, json={"message": "Service unavailable"}) + with pytest.raises(ServiceUnavailableError): + _raise_mapped(resp)