From ad37f4872edbcfbcc6a38b07c8561ab6b65e581b Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Tue, 21 Apr 2026 14:41:49 -0700 Subject: [PATCH 01/21] feat(llm): add native Databricks Foundation Model API provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds DatabricksLLM — a native provider for the Databricks AI Gateway that bypasses LiteLLM and routes directly to the correct per-family endpoint: - Anthropic Claude → /anthropic/v1/messages - Google Gemini → /gemini/v1/generateContent - OpenAI GPT-5+ → /openai/v1/responses (gpt-\d routing rule) - All others → /mlflow/v1/chat/completions Auth: PAT, M2M (service-principal), CLI profile, and U2M (browser SSO via databricks-sdk). All auth strategies resolve credentials lazily so saving settings succeeds before the optional databricks-sdk package is installed. Base class changes are minimal and PR-friendly: - `LLM`: slim 15-line dispatch validator (generic subclass discovery, no hardcoded names); no new fields on the base class - `AgentBase.llm` + `LLMSummarizingCondenser.llm`: `SerializeAsAny` annotation so DatabricksLLM fields survive agent save/load round-trips - `model_features.py`: early-return guard for `databricks/` prefix - `__init__.py`: additive `create_llm` factory Includes 275 unit tests covering auth, client, discovery, routing, native API translation (multi-turn tool calls, Responses API format), resilience, and settings bridge. --- openhands-sdk/openhands/sdk/__init__.py | 22 + openhands-sdk/openhands/sdk/agent/base.py | 5 +- .../condenser/llm_summarizing_condenser.py | 5 +- openhands-sdk/openhands/sdk/llm/llm.py | 48 ++ .../openhands/sdk/llm/providers/__init__.py | 3 + .../sdk/llm/providers/databricks/README.md | 167 +++++ .../sdk/llm/providers/databricks/__init__.py | 93 +++ .../sdk/llm/providers/databricks/auth.py | 344 ++++++++++ .../sdk/llm/providers/databricks/client.py | 404 +++++++++++ .../sdk/llm/providers/databricks/discovery.py | 329 +++++++++ .../sdk/llm/providers/databricks/llm.py | 352 ++++++++++ .../sdk/llm/providers/databricks/models.py | 287 ++++++++ .../sdk/llm/providers/databricks/native.py | 570 ++++++++++++++++ .../providers/databricks/settings_bridge.py | 181 +++++ .../sdk/llm/providers/databricks/utils.py | 269 ++++++++ .../openhands/sdk/llm/utils/model_features.py | 19 + openhands-sdk/pyproject.toml | 3 + tests/sdk/llm/providers/__init__.py | 0 .../sdk/llm/providers/databricks/__init__.py | 0 .../sdk/llm/providers/databricks/conftest.py | 82 +++ .../sdk/llm/providers/databricks/test_auth.py | 339 ++++++++++ .../llm/providers/databricks/test_client.py | 382 +++++++++++ .../providers/databricks/test_discovery.py | 573 ++++++++++++++++ .../sdk/llm/providers/databricks/test_llm.py | 441 ++++++++++++ .../llm/providers/databricks/test_models.py | 286 ++++++++ .../llm/providers/databricks/test_native.py | 635 ++++++++++++++++++ .../providers/databricks/test_resilience.py | 241 +++++++ .../databricks/test_settings_bridge.py | 214 ++++++ .../llm/providers/databricks/test_utils.py | 255 +++++++ 29 files changed, 6545 insertions(+), 4 deletions(-) create mode 100644 openhands-sdk/openhands/sdk/llm/providers/__init__.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/README.md create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/client.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/models.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/native.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py create mode 100644 tests/sdk/llm/providers/__init__.py create mode 100644 tests/sdk/llm/providers/databricks/__init__.py create mode 100644 tests/sdk/llm/providers/databricks/conftest.py create mode 100644 tests/sdk/llm/providers/databricks/test_auth.py create mode 100644 tests/sdk/llm/providers/databricks/test_client.py create mode 100644 tests/sdk/llm/providers/databricks/test_discovery.py create mode 100644 tests/sdk/llm/providers/databricks/test_llm.py create mode 100644 tests/sdk/llm/providers/databricks/test_models.py create mode 100644 tests/sdk/llm/providers/databricks/test_native.py create mode 100644 tests/sdk/llm/providers/databricks/test_resilience.py create mode 100644 tests/sdk/llm/providers/databricks/test_settings_bridge.py create mode 100644 tests/sdk/llm/providers/databricks/test_utils.py diff --git a/openhands-sdk/openhands/sdk/__init__.py b/openhands-sdk/openhands/sdk/__init__.py index e2fb6f0bb6..16e9398bd3 100644 --- a/openhands-sdk/openhands/sdk/__init__.py +++ b/openhands-sdk/openhands/sdk/__init__.py @@ -148,6 +148,27 @@ def __getattr__(name: str) -> Any: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +def create_llm(model: str = "", **kwargs) -> LLM: + """Factory function that routes to the correct LLM subclass. + + Routes models with the "databricks/" prefix to DatabricksLLM (native FMAPI + provider, bypassing LiteLLM). All other models use the base LLM class. + + Uses lazy import for DatabricksLLM to avoid circular import and to keep + Databricks dependencies optional (install via: pip install openhands-sdk[databricks]). + + Example: + llm = create_llm("databricks/databricks-meta-llama-3-3-70b-instruct", + databricks_host="https://adb-xxx.azuredatabricks.net", + api_key=SecretStr("dapi...")) + llm = create_llm("claude-sonnet-4-20250514", api_key=SecretStr("sk-ant-...")) + """ + if model.startswith("databricks/"): + from openhands.sdk.llm.providers.databricks.llm import DatabricksLLM + + return DatabricksLLM(model=model, **kwargs) + return LLM(model=model, **kwargs) + __all__ = [ "LLM", "LLM_PROFILE_SCHEMA_VERSION", @@ -232,4 +253,5 @@ def __getattr__(name: str) -> Any: "load_user_skills", "page_iterator", "__version__", + "create_llm", ] diff --git a/openhands-sdk/openhands/sdk/agent/base.py b/openhands-sdk/openhands/sdk/agent/base.py index dbd580e2d1..109b9ffbc7 100644 --- a/openhands-sdk/openhands/sdk/agent/base.py +++ b/openhands-sdk/openhands/sdk/agent/base.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Iterable, Sequence from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal from pydantic import ( BaseModel, @@ -17,6 +17,7 @@ PrivateAttr, SecretStr, SerializationInfo, + SerializeAsAny, ValidationInfo, model_serializer, model_validator, @@ -110,7 +111,7 @@ class AgentBase(DiscriminatedUnionMixin, ABC): arbitrary_types_allowed=True, ) - llm: LLM = Field( + llm: Annotated[LLM, SerializeAsAny()] = Field( ..., description="LLM configuration for the agent.", examples=[ diff --git a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py index aae387be9b..1544843054 100644 --- a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py +++ b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py @@ -1,8 +1,9 @@ import os from collections.abc import Sequence from enum import Enum +from typing import Annotated -from pydantic import Field, model_validator +from pydantic import Field, SerializeAsAny, model_validator from openhands.sdk.context.condenser.base import ( CondensationRequirement, @@ -43,7 +44,7 @@ class LLMSummarizingCondenser(RollingCondenser): it is the same as the one defined in this condenser. """ - llm: LLM + llm: Annotated[LLM, SerializeAsAny()] max_size: int = Field(default=240, gt=0) max_tokens: int | None = None diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 3bee201e0b..2f4a83a9ec 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -14,8 +14,10 @@ BaseModel, ConfigDict, Field, + ModelWrapValidatorHandler, PrivateAttr, SecretStr, + ValidationInfo, field_serializer, field_validator, model_validator, @@ -477,6 +479,52 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): # ========================================================================= # Validators # ========================================================================= + @model_validator(mode="wrap") + @classmethod + def _dispatch_to_provider_subclass( + cls, + data: Any, + handler: ModelWrapValidatorHandler["LLM"], + info: ValidationInfo, + ) -> "LLM": + """Route persisted agent JSON to the matching LLM subclass. + + When a saved agent is rehydrated the ``llm`` dict may carry a + ``provider`` key written by a subclass (e.g. ``"provider": + "databricks"``). Without this, ``LLM.model_validate(data)`` would + build a base ``LLM`` and silently drop the subclass-only fields. + + Only fires when called directly on the base ``LLM`` class with a + plain dict — subclass validators and non-dict inputs pass through + unchanged. Subclasses are discovered generically via + ``__subclasses__()`` keyed off their ``provider`` Literal annotation, + so no provider names are hardcoded here. + """ + if cls is not LLM or not isinstance(data, dict): + return handler(data) + provider = data.get("provider") + if not provider: + return handler(data) + for sub in LLM.__subclasses__(): + f = sub.model_fields.get("provider") + if f and provider in getattr(f.annotation, "__args__", ()): + return sub.model_validate(data, context=info.context) + return handler(data) + + @field_validator("safety_settings", mode="before") + @classmethod + def _warn_safety_settings_deprecated( + cls, v: list[dict[str, str]] | None + ) -> list[dict[str, str]] | None: + if v is not None: + warn_deprecated( + "LLM.safety_settings", + deprecated_in="1.15.0", + removed_in="1.20.0", + details="Safety settings are no longer applied.", + ) + return v + @field_validator( "api_key", "aws_access_key_id", "aws_secret_access_key", "aws_session_token" ) diff --git a/openhands-sdk/openhands/sdk/llm/providers/__init__.py b/openhands-sdk/openhands/sdk/llm/providers/__init__.py new file mode 100644 index 0000000000..9432c80ee1 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/__init__.py @@ -0,0 +1,3 @@ +# Provider implementations for the OpenHands V1 SDK LLM layer. +# Each provider subpackage contains a DatabricksLLM (or equivalent) subclass +# that bypasses LiteLLM for direct, PWAF-compliant API communication. diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md b/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md new file mode 100644 index 0000000000..8cb6b1788d --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md @@ -0,0 +1,167 @@ +# Databricks AI Gateway provider + +Native provider for Databricks Foundation Model APIs, routed through the +**Databricks AI Gateway**. Uses a direct `httpx` transport against the +gateway instead of routing HTTP through `litellm.completion`. + +Implements the Partner Well-Architected Framework (PWAF) contract for +Databricks OSS connectors: isolated auth strategies, `_/` +User-Agent on every request, typed errors, retry/backoff, and metadata-first +routing with a safe name-pattern fallback. + +## TL;DR — construct an LLM and use it + +```python +from pydantic import SecretStr +from openhands.sdk import create_llm +from openhands.sdk.llm.message import Message, TextContent + +llm = create_llm( + model="databricks/databricks-claude-sonnet-4-5", + databricks_host="https://adb-xxx.cloud.databricks.com", + api_key=SecretStr("dapi..."), + usage_id="my-agent", +) + +resp = llm.completion(messages=[ + Message(role="user", content=[TextContent(text="Hello!")]), +]) +print(resp.message.content[0].text) +``` + +No `databricks/` prefix + no explicit provider keyword means you get the +base `LLM` (default LiteLLM transport). Adding the `databricks/` prefix is +the only signal `create_llm` needs to route to this provider. + +## Supported native APIs + +| Family | AI Gateway path | When selected | +|---|---|---| +| `ProviderFamily.OPENAI` | `POST /serving-endpoints/{endpoint}/invocations` | Default — every `llm/v1/chat` endpoint | +| `ProviderFamily.OPENAI_RESPONSES` | `POST /serving-endpoints/v1/responses` | GPT-5 series (`databricks-gpt-5*`) | +| `ProviderFamily.ANTHROPIC` | `POST /serving-endpoints/anthropic/v1/messages` | Claude models (`*claude*`) | +| `ProviderFamily.GEMINI` | `POST /serving-endpoints/gemini/v1beta/models/{endpoint}:generateContent` | Gemini models (`*gemini*`) | + +Routing is metadata-first +(`GET /api/2.0/serving-endpoints/{name}` → `foundation_model.api_types` / +`external_model.provider`) with a name-pattern fallback (see `models.py`). +Results are cached in-process with a 5-minute TTL. + +See `_local/skills/databricks-ai-gateway-fm-apis/SKILL.md` for the +authoritative routing table, example payloads per family, and a runnable +`probe.py` that self-verifies every native path. + +## Authentication + +Five PWAF-compliant strategies, resolved in this priority order by +`resolve_credentials()`: + +1. **U2M** — OAuth browser PKCE, tokens passed in via `stored_u2m_tokens`. +2. **M2M** — `databricks_client_id` + `databricks_client_secret` + (service principal / client_credentials grant). +3. **PAT** — `api_key=SecretStr("dapi...")`. +4. **PROFILE** — `databricks_profile="DEFAULT"` (reads `~/.databrickscfg`; + requires the `databricks` extra: `pip install openhands-sdk[databricks]`). +5. **UNIFIED** — fallback to the `databricks-sdk` unified auth chain + (env vars, Azure MSI/Entra ID, etc.). + +Defer to the **Databricks Partner PWAF skills** *(URL TBD)* for end-to-end +auth details (token caching, refresh policies, CLI profile selection). + +Use `llm.auth_method` to see which strategy resolved. + +## Discovery (picker UIs) + +Listing AI-Gateway-shaped chat endpoints: + +```python +from openhands.sdk.llm.providers.databricks import ( + DatabricksCredentials, list_chat_endpoints, +) + +creds = DatabricksCredentials(host="...", get_token=lambda: "dapi...", auth_method="pat") +for ep in list_chat_endpoints(creds): + print(ep.qualified_name, ep.endpoint_type) +``` + +`list_chat_endpoints` includes both `FOUNDATION_MODEL_API` and +`EXTERNAL_MODEL` endpoints (customer-configured gpt-5 / gemini / +claude proxies). The lighter `list_foundation_models` returns flat +`databricks/` strings and is what `list_models_from_env` uses +with a 5-minute TTL cache. + +## PWAF surfaces on `DatabricksLLM` + +| Attribute | What it tells you | +|---|---| +| `llm.auth_method` | Resolved strategy: `pat` / `m2m` / `u2m` / `profile` / `unified` / `env` | +| `llm.predicted_family` | Family by name pattern only (pure compute, no HTTP) | +| `llm.resolve_family()` | Authoritative family (metadata probe, cached; falls back to predicted) | +| `llm.max_input_tokens` | Context window from the model-capability table | +| `llm.max_output_tokens` | Output budget (generous on reasoning models — gpt-5, gemini 2.5, gpt-oss) | + +## Module layout + +| Module | Role | +|---|---| +| `__init__.py` | Public API (`DatabricksLLM`, `ProviderFamily`, discovery, auth types) | +| `llm.py` | `DatabricksLLM` — Pydantic subclass of `LLM`; transport override | +| `client.py` | `DatabricksFMAPIClient` — `httpx` transport, family dispatch, retry | +| `native.py` | Per-family `to_native` / `from_native` adapters (request/response shaping) | +| `models.py` | `ProviderFamily`, `AIGatewayPaths`, routing functions, token containers | +| `auth.py` | Five credential strategies, `resolve_credentials()`, token providers | +| `discovery.py` | `list_chat_endpoints` / `list_foundation_models` + TTL cache | +| `utils.py` | `USER_AGENT`, `DatabricksTimeouts`, retry/backoff, error mapping | + +## Relationship to LiteLLM + +The connector owns its own HTTP path — all wire traffic to the Databricks AI +Gateway goes through `client.py` (`httpx`), not through `litellm.completion`. +It does, however, interoperate with LiteLLM at the type boundary to stay +compatible with the OpenHands base `LLM` class: + +- `client.py` returns a `litellm.types.utils.ModelResponse` so callers of the + base `LLM.completion` API get the expected return shape. +- `utils.py` and `auth.py` raise `litellm.exceptions.*` for HTTP error mapping + so retry/backoff behaves consistently with other providers. + +Removing this type-level coupling would require decoupling the OpenHands base +`LLM` class itself from LiteLLM and is tracked as a separate investigation — +it is deliberately out of scope for this provider. + +### Parked: true LiteLLM decoupling (Scope B) + +The work needed to drop the residual LiteLLM type dependency is intentionally +parked and recorded here so it isn't forgotten: + +- **Goal.** Remove both `litellm.types.utils.ModelResponse` and + `litellm.exceptions` from the Databricks provider's public surface so a + future OpenHands deployment could run this connector without LiteLLM in the + dependency graph at all. +- **Blocker.** The OpenHands base `LLM` class (in the shared SDK) still + accepts and returns `ModelResponse`, and its retry/backoff wiring catches + `litellm.exceptions.*`. Changing only the Databricks connector would break + that contract and ripple into every other provider plus the callers. +- **Shape of the fix (for the future investigation).** + 1. Introduce a small SDK-level `LLMResponse` dataclass that `LLM.completion` + returns regardless of provider, with a thin adapter that today wraps + `ModelResponse` for LiteLLM-backed providers. + 2. Replace `litellm.exceptions.*` with a set of SDK-owned error classes + (`LLMAuthError`, `LLMRateLimitError`, …) and translate at the LiteLLM + boundary only. + 3. Switch this provider to the new types, delete the two LiteLLM imports + from `client.py` / `utils.py` / `auth.py`, and remove the dependency + marker from the Databricks provider's optional extras. +- **Test plan.** Re-run the full `tests/sdk/llm/providers/databricks/` suite + plus the shared `LLM` contract tests without `litellm` installed in a + separate venv to prove the connector is truly standalone. +- **Why not now.** Would touch every other provider and the base `LLM` + class; out of scope for the native-Databricks provider milestone. + +## Testing + +Every routing / adapter / auth path has unit coverage under +`tests/sdk/llm/providers/databricks/`. End-to-end calls against +`e2e-demo-field-eng` have been run live across all three native families +(Llama → OpenAI Chat, Claude → Anthropic, Gemini → `generateContent`) and +all five auth strategies. diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py new file mode 100644 index 0000000000..d07f770d94 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py @@ -0,0 +1,93 @@ +"""Databricks AI Gateway native provider for the OpenHands V1 SDK. + +PWAF-compliant. Uses a direct synchronous ``httpx`` transport against the +Databricks AI Gateway rather than routing HTTP through ``litellm.completion``, +and dispatches by provider family to the correct native surface: + +* **OpenAI Chat** → ``/serving-endpoints/{endpoint}/invocations`` (universal) +* **OpenAI Responses** → ``/serving-endpoints/v1/responses`` (GPT-5 series) +* **Anthropic Messages** → ``/serving-endpoints/anthropic/v1/messages`` (Claude) +* **Gemini generateContent** → ``/serving-endpoints/gemini/v1beta/models/{endpoint}:generateContent`` + +Routing is metadata-first (``GET /api/2.0/serving-endpoints/{name}`` → +``foundation_model.api_types`` / ``external_model.provider``) with a +name-pattern fallback. + +Auth: PAT, OAuth M2M (client_credentials), OAuth U2M (browser PKCE), +``~/.databrickscfg`` profile, and unified Databricks SDK chain — see +:mod:`.auth` for details. The PWAF Partner AI Dev Kit skills are the +authoritative reference for credential handling. + +See the companion skill ``databricks-ai-gateway-fm-apis`` (in ``_local/skills``) +for the routing table, worked examples, and a runnable ``probe.py`` that +self-verifies every native path. + +Typical usage (via :func:`openhands.sdk.create_llm` factory): + +.. code-block:: python + + from openhands.sdk import create_llm + from openhands.sdk.llm.message import Message, TextContent + from pydantic import SecretStr + + llm = create_llm( + model="databricks/databricks-claude-sonnet-4-5", + databricks_host="https://adb-xxx.cloud.databricks.com", + api_key=SecretStr("dapi..."), # or pass databricks_profile="DEFAULT" + usage_id="my-agent", + ) + print(llm.predicted_family) # ProviderFamily.ANTHROPIC (no HTTP) + print(llm.resolve_family()) # ANTHROPIC (metadata-confirmed) + + resp = llm.completion(messages=[ + Message(role="user", content=[TextContent(text="Hello!")]), + ]) +""" + +from openhands.sdk.llm.providers.databricks.auth import ( + AuthStrategy, + DatabricksCredentials, +) +from openhands.sdk.llm.providers.databricks.discovery import ( + CURATED_DATABRICKS_MODELS, + DiscoveredEndpoint, + ModelPickerEntry, + get_picker_entries, + list_chat_endpoints, + list_foundation_models, + list_models_from_env, +) +from openhands.sdk.llm.providers.databricks.llm import DatabricksLLM +from openhands.sdk.llm.providers.databricks.models import ( + AIGatewayPaths, + ProviderFamily, + StoredU2MTokens, + detect_family, + pick_family_from_api_types, +) +from openhands.sdk.llm.providers.databricks.settings_bridge import kwargs_from_settings + +__all__ = [ + # LLM + "DatabricksLLM", + # Routing primitives + "ProviderFamily", + "AIGatewayPaths", + "detect_family", + "pick_family_from_api_types", + # Auth + "AuthStrategy", + "DatabricksCredentials", + "StoredU2MTokens", + # Discovery + "DiscoveredEndpoint", + "list_chat_endpoints", + "list_foundation_models", + "list_models_from_env", + # Two-tier model picker + "ModelPickerEntry", + "CURATED_DATABRICKS_MODELS", + "get_picker_entries", + # Settings → create_llm bridge (drift-guarded) + "kwargs_from_settings", +] diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py new file mode 100644 index 0000000000..0ff4f82bc9 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py @@ -0,0 +1,344 @@ +"""Databricks FMAPI authentication strategies. + +Supports all 5 PWAF-required auth paths: + U2M — OAuth browser PKCE flow (PWAF primary interactive auth). + Provider receives StoredU2MTokens from app layer; manages refresh only. + M2M — OAuth client credentials (PWAF primary service auth). + M2MTokenProvider fetches/refreshes tokens with threading.Lock. + PAT — Personal Access Token (additional option only per PWAF). + PROFILE — Databricks CLI profile (~/.databrickscfg). Requires databricks-sdk. + UNIFIED — databricks-sdk unified auth chain (workload identity, Azure AD, etc.). + Requires databricks-sdk. + +Auth priority (PWAF compliant): U2M > M2M > PAT > PROFILE > UNIFIED. + +Client ID distinction (critical): + DATABRICKS_CLIENT_ID — M2M service principal (grant_type=client_credentials) + DATABRICKS_U2M_CLIENT_ID — Custom OAuth app for browser login (PKCE flow) + Using M2M client_id for U2M → "OAuth application not available" from Databricks. +""" + +from __future__ import annotations + +import logging +import threading +import time +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Callable + +import httpx +from litellm.exceptions import AuthenticationError + +from openhands.sdk.llm.providers.databricks.models import StoredU2MTokens +from openhands.sdk.llm.providers.databricks.utils import ( + USER_AGENT, + normalize_host, + validate_databricks_config, +) + +if TYPE_CHECKING: + from openhands.sdk.llm.providers.databricks.llm import DatabricksLLM + +logger = logging.getLogger(__name__) + + +class AuthStrategy(str, Enum): + """Auth strategy discriminator. Used in resolve_credentials() priority chain.""" + + U2M = "u2m" # OAuth browser PKCE — PWAF primary interactive auth + M2M = "m2m" # OAuth client credentials — PWAF primary service auth + PAT = "pat" # Personal Access Token — additional option only per PWAF + PROFILE = "profile" # Databricks CLI profile (~/.databrickscfg) + UNIFIED = "unified" # databricks-sdk unified auth chain (fallback) + + +@dataclass +class DatabricksCredentials: + """Resolved Databricks credentials ready for use in API calls. + + get_token is always synchronous — threading.Lock used internally for M2M/U2M. + auth_method is a plain string logged for observability (never the token value). + """ + + host: str + get_token: Callable[[], str] + auth_method: str = "unknown" # "u2m" / "m2m" / "pat" / "profile" / "unified" + + +# --------------------------------------------------------------------------- +# M2M token provider — thread-safe via threading.Lock (not asyncio) +# --------------------------------------------------------------------------- + +class M2MTokenProvider: + """Thread-safe OAuth client credentials token provider. + + P0-3: constructor accepts host, client_id, client_secret (was missing in P3 plan). + Uses double-checked locking to ensure exactly one _fetch_new_token() call under + concurrent pressure. + """ + + def __init__(self, host: str, client_id: str, client_secret: str) -> None: + self._host = host + self._client_id = client_id + self._client_secret = client_secret + self._token: str | None = None + self._expires_at: float = 0.0 + self._lock = threading.Lock() # threading, NOT asyncio + + def get_token(self) -> str: + """Return a valid access token, refreshing proactively if nearing expiry.""" + # Fast path: check without acquiring lock + if self._token and time.time() < self._expires_at - 300: + return self._token + with self._lock: + # Double-check inside lock to prevent thundering herd + if self._token and time.time() < self._expires_at - 300: + return self._token + self._token, self._expires_at = self._fetch_new_token() + return self._token + + def _fetch_new_token(self) -> tuple[str, float]: + """Fetch a new M2M access token via client credentials grant.""" + resp = httpx.post( + f"{self._host}/oidc/v1/token", + data={ + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + "scope": "all-apis", # required by Databricks OIDC + }, + headers={"User-Agent": USER_AGENT}, # PWAF: UA on ALL Databricks HTTP + timeout=15.0, + ) + resp.raise_for_status() + data = resp.json() + expires_at = time.time() + data.get("expires_in", 3600) + return data["access_token"], expires_at + + +# --------------------------------------------------------------------------- +# Main credential resolver +# --------------------------------------------------------------------------- + +def resolve_credentials(llm: "DatabricksLLM") -> DatabricksCredentials: + """Resolve auth strategy and return DatabricksCredentials. + + Priority (PWAF: OAuth primary, PAT additional only): + 1. U2M stored_u2m_tokens is set (user completed browser login) + 2. M2M databricks_client_id + databricks_client_secret both set + 3. PAT api_key is set + 4. PROFILE databricks_profile named + 5. UNIFIED databricks-sdk auth chain (fallback) + + Host resolution: ``databricks_host`` (workspace) is required for U2M / + M2M / PROFILE / UNIFIED (workspace OAuth) and when + ``databricks_metadata_probe=True``. PAT only needs *some* host for the + FM client to route to — either ``databricks_host`` (canonical) or + ``databricks_ai_gateway_host`` (override). The validator in + ``DatabricksLLM`` enforces that at least one is set. + """ + stored = llm.stored_u2m_tokens + host = ( + llm.databricks_host + or llm.base_url + or (stored.host if stored else None) + ) + if host: + host = normalize_host(host) + + # Path 1: U2M — highest priority (user already browser-logged in) + if stored: + if not host: + raise ValueError("databricks_host is required for U2M auth.") + validate_databricks_config(host, AuthStrategy.U2M, stored_tokens=stored) + return _resolve_u2m(host, stored) + + # Path 2: M2M — service principal client credentials + if llm.databricks_client_id and llm.databricks_client_secret: + if not host: + raise ValueError("databricks_host is required for M2M auth.") + validate_databricks_config( + host, + AuthStrategy.M2M, + client_id=llm.databricks_client_id, + client_secret=llm.databricks_client_secret.get_secret_value(), + ) + return _resolve_m2m(host, llm) + + # Path 3: PAT — token is sent directly to the AI Gateway, no workspace + # host needed. credentials.host is left empty so any accidental use + # surfaces clearly. + if llm.api_key: + token = ( + llm.api_key.get_secret_value() + if hasattr(llm.api_key, "get_secret_value") + else str(llm.api_key) + ) + logger.info("databricks_auth_resolved", extra={"method": "pat"}) + return DatabricksCredentials( + host=host or "", get_token=lambda: token, auth_method="pat" + ) + + # Path 4: Named CLI profile + if llm.databricks_profile: + if not host: + raise ValueError( + "databricks_host is required when using databricks_profile." + ) + return _resolve_profile(host, llm.databricks_profile) + + # Path 5: SDK unified auth chain (workload identity, Azure AD, ~/.databrickscfg) + if not host: + raise ValueError( + "databricks_host is required for unified-SDK auth." + ) + return _resolve_sdk_auth(host) + + +# --------------------------------------------------------------------------- +# Strategy implementations +# --------------------------------------------------------------------------- + +def _resolve_u2m(host: str, stored: StoredU2MTokens) -> DatabricksCredentials: + """U2M: return current access token, refreshing silently via refresh_token. + + Proactive refresh: 5 minutes before expiry (300s buffer). + Uses threading.Lock for thread safety in the synchronous call path. + """ + lock = threading.Lock() + state: dict[str, object] = { + "token": stored.access_token, + "expires_at": stored.expires_at, + } + + def get_token() -> str: + # Fast path — no lock needed + if time.time() < float(state["expires_at"]) - 300: + return str(state["token"]) + with lock: + if time.time() < float(state["expires_at"]) - 300: + return str(state["token"]) + remaining = float(state["expires_at"]) - time.time() + logger.info( + "databricks_u2m_token_refresh", + extra={"remaining_s": round(remaining, 1)}, + ) + resp = httpx.post( + f"{host}/oidc/v1/token", + data={ + "grant_type": "refresh_token", + "refresh_token": stored.refresh_token, + "client_id": stored.client_id, + # No client_secret for PKCE (public client) + }, + headers={"User-Agent": USER_AGENT}, # PWAF: UA on token endpoint + timeout=15.0, + ) + if not resp.is_success: + raise AuthenticationError( + f"U2M token refresh failed [{resp.status_code}]. " + "Re-authenticate at /auth/databricks/initiate.", + model="", + llm_provider="databricks", + ) + data = resp.json() + state["token"] = data["access_token"] + state["expires_at"] = time.time() + data.get("expires_in", 3600) + logger.info("databricks_u2m_token_refreshed", extra={"method": "u2m"}) + return str(state["token"]) + + logger.info("databricks_auth_resolved", extra={"method": "u2m"}) + return DatabricksCredentials(host=host, get_token=get_token, auth_method="u2m") + + +def _resolve_m2m(host: str, llm: "DatabricksLLM") -> DatabricksCredentials: + """M2M: client credentials grant via M2MTokenProvider.""" + assert llm.databricks_client_secret is not None # validated in resolve_credentials + provider = M2MTokenProvider( + host=host, + client_id=llm.databricks_client_id, # type: ignore[arg-type] + client_secret=llm.databricks_client_secret.get_secret_value(), + ) + logger.info("databricks_auth_resolved", extra={"method": "m2m"}) + return DatabricksCredentials( + host=host, get_token=provider.get_token, auth_method="m2m" + ) + + +def _resolve_profile(host: str, profile: str) -> DatabricksCredentials: + """PROFILE: Databricks CLI profile via databricks-sdk. Requires optional dep. + + The import check is deferred into ``get_token()`` so that saving settings + succeeds even if ``databricks-sdk`` is not yet installed; the clear error + with install instructions surfaces only when the agent first makes an API + call. + """ + client_holder: dict[str, object] = {} + lock = threading.Lock() + + def get_token() -> str: + try: + from databricks.sdk import WorkspaceClient as _WC + except ImportError: + raise ImportError( + "PROFILE auth requires the 'databricks-sdk' package.\n" + "Install it: pip install databricks-sdk\n" + f"Then verify: databricks auth profiles " + f"(profile '{profile}' must appear)" + ) from None + + client = client_holder.get("client") + if client is None: + with lock: + client = client_holder.get("client") + if client is None: + client = _WC(host=host, profile=profile) + client_holder["client"] = client + return ( + client.config.authenticate()["Authorization"].replace("Bearer ", "") # type: ignore[attr-defined] + ) + + logger.info("databricks_auth_resolved", extra={"method": "profile", "profile": profile}) + return DatabricksCredentials(host=host, get_token=get_token, auth_method="profile") + + +def _resolve_sdk_auth(host: str) -> DatabricksCredentials: + """UNIFIED: databricks-sdk auth chain (workload identity, Azure AD, ~/.databrickscfg). + + The import check is deferred into ``get_token()`` so that saving settings + succeeds even if ``databricks-sdk`` is not yet installed; the clear error + with install + pre-login instructions surfaces only when the agent first + makes an API call. + + Pre-requisites (outside the agent): + 1. pip install databricks-sdk + 2. databricks auth login --host + """ + client_holder: dict[str, object] = {} + lock = threading.Lock() + + def get_token() -> str: + try: + from databricks.sdk import WorkspaceClient as _WC + except ImportError: + raise ImportError( + "Browser-SSO / unified auth requires the 'databricks-sdk' package.\n" + " Step 1 — install: pip install databricks-sdk\n" + f" Step 2 — login: databricks auth login --host {host}\n" + "Re-run the agent after completing both steps." + ) from None + + client = client_holder.get("client") + if client is None: + with lock: + client = client_holder.get("client") + if client is None: + client = _WC(host=host) + client_holder["client"] = client + return ( + client.config.authenticate()["Authorization"].replace("Bearer ", "") # type: ignore[attr-defined] + ) + + logger.info("databricks_auth_resolved", extra={"method": "unified"}) + return DatabricksCredentials(host=host, get_token=get_token, auth_method="unified") diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/client.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/client.py new file mode 100644 index 0000000000..56e0deeab9 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/client.py @@ -0,0 +1,404 @@ +"""Databricks AI Gateway synchronous HTTP client. + +Two distinct Databricks surfaces, kept strictly separate: + +* **AI Gateway** (FM invocations) → ``ai_gateway_host`` (optional override). + Defaults to ``credentials.host``. All ``POST /ai-gateway/...`` traffic + goes here, exclusively. +* **Workspace** (auth, discovery, metadata probes) → ``credentials.host``. + Used by auth/token resolution and the opt-in metadata probe; never used + for FM invocations. + +Path templates by provider family (relative to the AI Gateway base URL, +which :meth:`AIGatewayPaths.normalize_base` produces from the configured +host — see ``models.py``): + +* :attr:`ProviderFamily.OPENAI` → ``POST {base}/mlflow/v1/chat/completions`` +* :attr:`ProviderFamily.OPENAI_RESPONSES` → ``POST {base}/openai/v1/responses`` +* :attr:`ProviderFamily.ANTHROPIC` → ``POST {base}/anthropic/v1/messages`` +* :attr:`ProviderFamily.GEMINI` → ``POST {base}/gemini/v1beta/models/{endpoint}:generateContent`` + +Family routing: + +* Default — ``detect_family(model)`` resolves the family from the model + name; no network call. +* ``metadata_probe=True`` — issues + ``GET /api/2.0/serving-endpoints/{name}`` against the workspace before + each cache-miss invocation; results cached for 5 minutes per process. + +Streaming stays on the universal OpenAI Chat SSE path +(``{base}/mlflow/v1/chat/completions`` with ``stream=true`` in the body). + +Non-streaming: singleton ``httpx.Client`` for connection pooling, thread-safe. +Streaming: context-managed client per request — always closed, no leak. +""" + +from __future__ import annotations + +import json +import logging +import threading +import time +from typing import Any, Callable + +import httpx +from litellm.types.utils import ModelResponse + +from openhands.sdk.llm.providers.databricks.auth import DatabricksCredentials +from openhands.sdk.llm.providers.databricks.models import ( + AIGatewayPaths, + ProviderFamily, + detect_family, + pick_family_from_api_types, +) +from openhands.sdk.llm.providers.databricks.native import from_native, to_native +from openhands.sdk.llm.providers.databricks.utils import ( + USER_AGENT, + DatabricksTimeouts, + _raise_non_retryable, + fetch_with_retry, +) + +logger = logging.getLogger(__name__) + +TokenCallbackType = Callable[[str], None] + +# Metadata cache: endpoint name -> (family, expires_at_epoch_seconds). +_METADATA_TTL_S: float = 300.0 +_METADATA_NEGATIVE_TTL_S: float = 60.0 + + +def _bare_endpoint(model: str) -> str: + """Strip ``databricks/`` / ``databricks-`` prefix to get the endpoint name.""" + name = model.strip() + for prefix in ("databricks/",): + if name.startswith(prefix): + name = name[len(prefix):] + return name + + +class DatabricksFMAPIClient: + """Synchronous Foundation Model client for Databricks AI Gateway. + + Public API is unchanged from the previous single-route client — + ``chat_completion(model=..., messages=..., stream=..., tools=..., **kwargs)`` — + but internally the request is routed to the correct native surface by + provider family. + + Thread-safety: the singleton ``self._http`` is used for all non-streaming + calls (including metadata lookups). Streaming opens a fresh + context-managed client per request. + """ + + def __init__( + self, + credentials: DatabricksCredentials, + timeouts: DatabricksTimeouts, + ai_gateway_host: str | None = None, + max_retries: int = 3, + ssl_verify: bool = True, + paths: AIGatewayPaths | None = None, + metadata_probe: bool = False, + ) -> None: + # AI Gateway host is an optional override; for the common + # single-URL Databricks deployment the workspace host doubles as + # the gateway base (``/ai-gateway/``). + gateway_host = (ai_gateway_host or credentials.host or "").rstrip("/") + if not gateway_host: + raise ValueError( + "Either ai_gateway_host or credentials.host must be provided; " + "the FM client needs at least one URL to route invocations to." + ) + self._credentials = credentials + self._timeouts = timeouts + self._max_retries = max_retries + self._ssl_verify = ssl_verify + self._paths = paths or AIGatewayPaths() + self._ai_gateway_host = gateway_host + self._metadata_probe = metadata_probe + self._http = httpx.Client( + verify=ssl_verify, + timeout=httpx.Timeout( + connect=timeouts.connect_s, + read=timeouts.read_s, + write=10.0, + pool=timeouts.pool_s, + ), + ) + self._metadata_cache: dict[str, tuple[ProviderFamily, float]] = {} + self._metadata_lock = threading.Lock() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def __del__(self) -> None: # pragma: no cover — best-effort cleanup + try: + self._http.close() + except Exception: + pass + + def close(self) -> None: + """Explicitly close the singleton HTTP client.""" + self._http.close() + + # ------------------------------------------------------------------ + # Headers / auth + # ------------------------------------------------------------------ + + def _make_headers(self, family: ProviderFamily) -> dict[str, str]: + """Build request headers. Never re-import USER_AGENT per request.""" + h = { + "Authorization": f"Bearer {self._credentials.get_token()}", + "Content-Type": "application/json", + "User-Agent": USER_AGENT, + } + if family is ProviderFamily.ANTHROPIC: + h["anthropic-version"] = "2023-06-01" + return h + + # ------------------------------------------------------------------ + # Routing + # ------------------------------------------------------------------ + + def resolve_family(self, model: str) -> ProviderFamily: + """Resolve the AI Gateway path family for ``model``. + + Default: name-pattern only via ``detect_family(model)``. + + Opt-in (``metadata_probe=True``): metadata-first with name-pattern + fallback; issues ``GET /api/2.0/serving-endpoints/{name}`` against + the workspace, cached for 5 minutes per endpoint. + """ + if not self._metadata_probe: + return detect_family(model) + + endpoint = _bare_endpoint(model) + now = time.time() + with self._metadata_lock: + hit = self._metadata_cache.get(endpoint) + if hit and hit[1] > now: + return hit[0] + family = self._probe_metadata(endpoint) or detect_family(model) + ttl = _METADATA_TTL_S if self._probe_succeeded else _METADATA_NEGATIVE_TTL_S + with self._metadata_lock: + self._metadata_cache[endpoint] = (family, now + ttl) + return family + + _probe_succeeded: bool = False # set by _probe_metadata; read by resolve_family + + def _probe_metadata(self, endpoint: str) -> ProviderFamily | None: + """GET /api/2.0/serving-endpoints/{endpoint} → family (or None).""" + if not self._credentials.host: + raise ValueError( + "databricks_host is required when " + "databricks_metadata_probe=True; the metadata probe targets " + "the workspace control plane." + ) + url = f"{self._credentials.host}/api/2.0/serving-endpoints/{endpoint}" + try: + resp = self._http.get( + url, + headers={ + "Authorization": f"Bearer {self._credentials.get_token()}", + "User-Agent": USER_AGENT, + }, + timeout=10.0, + ) + except httpx.HTTPError as exc: + logger.debug("databricks_metadata_probe_failed", extra={ + "endpoint": endpoint, "error": str(exc) + }) + self._probe_succeeded = False + return None + if resp.status_code != 200: + logger.debug("databricks_metadata_probe_nonok", extra={ + "endpoint": endpoint, "status": resp.status_code, + }) + self._probe_succeeded = False + return None + try: + meta = resp.json() + except ValueError: + self._probe_succeeded = False + return None + entities = ((meta.get("config") or {}).get("served_entities")) or [] + fm = (entities[0] if entities else {}).get("foundation_model") or {} + em = (entities[0] if entities else {}).get("external_model") or {} + family = pick_family_from_api_types( + api_types=fm.get("api_types"), + external_provider=em.get("provider"), + ) + self._probe_succeeded = True + return family + + def invalidate_metadata(self, model: str) -> None: + """Drop the cached family for an endpoint (e.g. on 404/410).""" + endpoint = _bare_endpoint(model) + with self._metadata_lock: + self._metadata_cache.pop(endpoint, None) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def chat_completion( + self, + model: str, + messages: list[dict], + stream: bool = False, + tools: list[dict] | None = None, + on_token: TokenCallbackType | None = None, + **kwargs: Any, + ) -> ModelResponse: + """Dispatch a chat-completion call across AI Gateway surfaces.""" + endpoint = _bare_endpoint(model) + + if stream: + family = ProviderFamily.OPENAI + url = self._paths.url(self._ai_gateway_host, family, endpoint) + payload = to_native( + family, endpoint, messages, + tools=tools, stream=True, **kwargs, + ) + return self._handle_stream(url, self._make_headers(family), payload, + endpoint, on_token) + + family = self.resolve_family(model) + url = self._paths.url(self._ai_gateway_host, family, endpoint) + payload = to_native(family, endpoint, messages, tools=tools, **kwargs) + headers = self._make_headers(family) + + response = fetch_with_retry( + client=self._http, + url=url, + headers=headers, + json=payload, + max_retries=self._max_retries, + ) + logger.debug( + "databricks_ai_gateway_response", + extra={ + "status": response.status_code, + "request_id": response.headers.get("x-request-id"), + "endpoint": endpoint, + "family": family.value, + "auth_method": self._credentials.auth_method, + # Intentionally NOT logging: Authorization header, token values, + # request/response bodies (may contain user prompts). + }, + ) + return self._parse_response(response, endpoint, family) + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + def _parse_response( + self, + response: httpx.Response, + model: str, + family: ProviderFamily, + ) -> ModelResponse: + """Convert a native response to ``litellm.ModelResponse``. + + Flow: native JSON → ``from_native`` → OpenAI ChatCompletion dict → + ``ModelResponse``. Fallback builds a minimal response if the native + shape is unexpected. + """ + try: + data = response.json() + except ValueError: + return ModelResponse(id="databricks-response", choices=[], model=model) + + try: + chat = from_native(family, model, data) + return ModelResponse(**chat) + except Exception: + logger.warning( + "databricks_parse_fallback", + extra={"family": family.value, "endpoint": model}, + ) + return ModelResponse( + id=data.get("id", "databricks-response"), + choices=data.get("choices", []), + model=data.get("model", model), + usage=data.get("usage"), + ) + + # ------------------------------------------------------------------ + # Streaming (OpenAI Chat only for V1) + # ------------------------------------------------------------------ + + def _handle_stream( + self, + url: str, + headers: dict[str, str], + payload: dict, + model: str, + on_token: TokenCallbackType | None, + ) -> ModelResponse: + """Stream ``/invocations`` with a fresh context-managed client.""" + chunk_count = 0 + accumulated_content = "" + last_chunk_id = "" + + with httpx.Client( + verify=self._ssl_verify, + timeout=httpx.Timeout( + connect=self._timeouts.connect_s, + read=self._timeouts.chunk_s, + write=10.0, + pool=self._timeouts.pool_s, + ), + ) as stream_client: + with stream_client.stream( + "POST", url, headers=headers, json=payload + ) as resp: + if resp.status_code >= 400: + resp.read() + _raise_non_retryable(resp) + for line in resp.iter_lines(): + if not line.startswith("data: ") or line == "data: [DONE]": + continue + chunk_count += 1 + try: + chunk = json.loads(line[6:]) + except json.JSONDecodeError: + continue + last_chunk_id = chunk.get("id", last_chunk_id) + choices = chunk.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + token = delta.get("content") + if token: + accumulated_content += token + if on_token is not None: + on_token(token) + + logger.debug( + "databricks_stream_complete", + extra={ + "chunks": chunk_count, + "endpoint": model, + "auth_method": self._credentials.auth_method, + }, + ) + return self._build_stream_response(accumulated_content, last_chunk_id, model) + + def _build_stream_response( + self, content: str, response_id: str, model: str, + ) -> ModelResponse: + """Build a ``ModelResponse`` from accumulated stream content.""" + return ModelResponse( + id=response_id or "databricks-stream", + choices=[ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + model=model, + object="chat.completion", + ) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py new file mode 100644 index 0000000000..ae65b86964 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py @@ -0,0 +1,329 @@ +"""Databricks AI Gateway model discovery. + +Queries GET /api/2.0/serving-endpoints and returns the chat-capable endpoints +exposed through the AI Gateway. Three endpoint classes are surfaced: + +* ``FOUNDATION_MODEL_API`` — workspace-hosted DBRX / Llama / Claude / Gemini / + GPT-5 pay-per-token models (native AI Gateway). +* ``EXTERNAL_MODEL`` — customer-configured external model endpoints + proxied through the gateway (still routed to provider-native APIs). +* ``CUSTOM_MODEL`` and ``endpoint_type=None`` endpoints are intentionally + excluded — those are agent / custom-deployment endpoints whose payload + shape is not guaranteed to be OpenAI-Chat-compatible. + +Results are TTL-cached for 5 minutes in ``list_models_from_env``. The structured +``list_chat_endpoints`` call always hits the network because metadata is cheap +and callers may want fresh data. + +PWAF: User-Agent header is included on every discovery call (required on ALL +Databricks HTTP). + +Cache race condition: on cache miss multiple threads may call +``list_foundation_models()`` concurrently (thundering herd). Last writer wins — +no data corruption, just redundant API calls. We prefer this over holding the +lock during the HTTP call, which would serialize all model-picker refreshes. +""" + +from __future__ import annotations + +import logging +import os +import threading +import time +from dataclasses import dataclass + +import httpx + +from openhands.sdk.llm.providers.databricks.auth import DatabricksCredentials +from openhands.sdk.llm.providers.databricks.models import ( + ProviderFamily, + detect_family, +) +from openhands.sdk.llm.providers.databricks.utils import USER_AGENT, normalize_host + +logger = logging.getLogger(__name__) + +_CACHE_LOCK = threading.Lock() +_CACHED_MODELS: list[str] = [] +_CACHE_EXPIRES_AT: float = 0.0 +_CACHE_TTL_S: int = 300 # 5 minutes + +# Endpoint types that expose AI-Gateway-shaped chat payloads. Everything else +# (CUSTOM_MODEL, None, agent/* tasks) is excluded by default — callers that +# really want them can pass a custom filter to list_chat_endpoints. +_GATEWAY_ENDPOINT_TYPES = frozenset({"FOUNDATION_MODEL_API", "EXTERNAL_MODEL"}) + + +@dataclass(frozen=True) +class DiscoveredEndpoint: + """Structured view of a serving endpoint from the list call. + + Only fields the list response reliably returns are captured here. + Authoritative routing metadata (``foundation_model.api_types``, + ``external_model.provider``) only comes from the per-endpoint describe + call and is resolved lazily by ``DatabricksFMAPIClient._probe_metadata``. + """ + + name: str # e.g. "databricks-claude-sonnet-4-5" + qualified_name: str # e.g. "databricks/databricks-claude-sonnet-4-5" + endpoint_type: str | None # FOUNDATION_MODEL_API | EXTERNAL_MODEL | None + task: str # "llm/v1/chat" + ready: bool + creator: str | None = None + + +def list_chat_endpoints( + credentials: DatabricksCredentials, + *, + include_not_ready: bool = False, + allowed_endpoint_types: frozenset[str] = _GATEWAY_ENDPOINT_TYPES, +) -> list[DiscoveredEndpoint]: + """Return all AI-Gateway chat endpoints visible in this workspace. + + Filters: + * ``task == "llm/v1/chat"`` + * ``endpoint_type`` in ``allowed_endpoint_types`` (default: + FOUNDATION_MODEL_API + EXTERNAL_MODEL) + * ``state.ready == "READY"`` unless ``include_not_ready`` is True + + PWAF: ``User-Agent`` header required on ALL Databricks HTTP calls. + """ + resp = httpx.get( + f"{credentials.host}/api/2.0/serving-endpoints", + headers={ + "Authorization": f"Bearer {credentials.get_token()}", + "User-Agent": USER_AGENT, + }, + timeout=30.0, + ) + resp.raise_for_status() + endpoints = resp.json().get("endpoints", []) + + out: list[DiscoveredEndpoint] = [] + for ep in endpoints: + task = ep.get("task") + if task != "llm/v1/chat": + continue + et = ep.get("endpoint_type") + if et not in allowed_endpoint_types: + continue + ready = ep.get("state", {}).get("ready") == "READY" + if not ready and not include_not_ready: + continue + name = ep.get("name") + if not name: + continue + out.append( + DiscoveredEndpoint( + name=name, + qualified_name=f"databricks/{name}", + endpoint_type=et, + task=task, + ready=ready, + creator=ep.get("creator"), + ) + ) + return out + + +def list_foundation_models(credentials: DatabricksCredentials) -> list[str]: + """Return qualified names of gateway chat endpoints (back-compat API). + + Historical name — kept for back-compat. Delegates to ``list_chat_endpoints`` + and returns only the ``databricks/`` strings. Includes both + ``FOUNDATION_MODEL_API`` and ``EXTERNAL_MODEL`` endpoints, READY only. + """ + return [e.qualified_name for e in list_chat_endpoints(credentials)] + + +# --------------------------------------------------------------------------- +# Two-tier model picker: curated (static) + discovered (dynamic) +# --------------------------------------------------------------------------- +# +# Static/global lists go stale fast and paper over per-workspace availability. +# Dynamic discovery solves that but isn't available until the user has entered +# host + credentials. We surface both: +# +# - tier 1: CURATED_DATABRICKS_MODELS — a small, hand-picked, family-balanced +# set (Claude, GPT, Gemini) that known-good works against FMAPI on any +# standard workspace. Used as the picker default before auth. +# - tier 2: list_chat_endpoints(creds) — the actual endpoints this workspace +# exposes (FOUNDATION_MODEL_API + EXTERNAL_MODEL), fetched live. Merged on +# top of the curated set so customer-configured gpt-5 / gemini / claude +# endpoints also show up. +# +# ``get_picker_entries`` is the single call UIs should use. It dedups by +# qualified name, preserves the curated "recommended" flag when the same +# endpoint is also discovered, and sorts recommended-first then by family/name. + + +@dataclass(frozen=True) +class ModelPickerEntry: + """UI-facing model picker row — merged view across curated + discovered. + + Fields: + qualified_name — "databricks/" (use as the model id for create_llm) + name — bare endpoint name + family — predicted provider family (OPENAI / ANTHROPIC / ...) + source — "curated" | "discovered" | "curated+discovered" + endpoint_type — FOUNDATION_MODEL_API | EXTERNAL_MODEL | None (curated-only) + ready — True if discovered and READY; True for curated (optimistic) + recommended — True for the curated "one per family" default picks + """ + + qualified_name: str + name: str + family: ProviderFamily + source: str + endpoint_type: str | None = None + ready: bool = True + recommended: bool = False + + +def _curated_entry( + name: str, family: ProviderFamily, *, recommended: bool = False +) -> ModelPickerEntry: + return ModelPickerEntry( + qualified_name=f"databricks/{name}", + name=name, + family=family, + source="curated", + endpoint_type="FOUNDATION_MODEL_API", + ready=True, + recommended=recommended, + ) + + +# Curated tier-1 set — Claude / GPT / Gemini only. One "recommended" per family +# (fast + capable), plus a couple of siblings. Intentionally excludes Llama, +# DBRX, and legacy endpoints — those surface automatically via discovery if the +# workspace has them enabled. +CURATED_DATABRICKS_MODELS: tuple[ModelPickerEntry, ...] = ( + # Anthropic — Claude (native Anthropic Messages API) + _curated_entry( + "databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC, recommended=True + ), + _curated_entry("databricks-claude-opus-4-1", ProviderFamily.ANTHROPIC), + _curated_entry("databricks-claude-haiku-4-5", ProviderFamily.ANTHROPIC), + # OpenAI — GPT-5 (Responses API) and gpt-oss (OpenAI Chat) + _curated_entry( + "databricks-gpt-5-mini", ProviderFamily.OPENAI_RESPONSES, recommended=True + ), + _curated_entry("databricks-gpt-5", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-oss-120b", ProviderFamily.OPENAI), + # Google — Gemini (native generateContent) + _curated_entry( + "databricks-gemini-2-5-flash", ProviderFamily.GEMINI, recommended=True + ), + _curated_entry("databricks-gemini-2-5-pro", ProviderFamily.GEMINI), +) + + +def get_picker_entries( + credentials: DatabricksCredentials | None = None, + *, + include_curated: bool = True, + include_discovered: bool = True, + include_not_ready: bool = False, +) -> list[ModelPickerEntry]: + """Merged view of curated + discovered Databricks models for picker UIs. + + Dedup rule: if a qualified name is in both tiers, the curated entry wins + on ``recommended`` / ``family`` (our opinion) but picks up the live + ``endpoint_type`` and ``ready`` fields from discovery, and its ``source`` + becomes ``"curated+discovered"`` so UIs can show a "verified + available" + badge. Order: recommended-first, then by family, then by name. + + Network: calls ``list_chat_endpoints`` **only** if ``credentials`` is + provided and ``include_discovered`` is True. Without creds this is a pure + compute over the static curated set and safe to call from sync UI code. + + Errors during discovery are logged and swallowed — the curated tier is + always returned even if the workspace is unreachable. That keeps the + picker usable offline / during outages. + """ + merged: dict[str, ModelPickerEntry] = {} + + if include_curated: + for e in CURATED_DATABRICKS_MODELS: + merged[e.qualified_name] = e + + if include_discovered and credentials is not None: + try: + discovered = list_chat_endpoints( + credentials, include_not_ready=include_not_ready + ) + except Exception as exc: + logger.warning( + "databricks_discovery_failed_in_picker", extra={"error": str(exc)} + ) + discovered = [] + + for d in discovered: + existing = merged.get(d.qualified_name) + if existing is not None: + # Curated entry already present — upgrade with live signals, + # keep our opinion on family + recommended. + merged[d.qualified_name] = ModelPickerEntry( + qualified_name=existing.qualified_name, + name=existing.name, + family=existing.family, + source="curated+discovered", + endpoint_type=d.endpoint_type, + ready=d.ready, + recommended=existing.recommended, + ) + else: + merged[d.qualified_name] = ModelPickerEntry( + qualified_name=d.qualified_name, + name=d.name, + family=detect_family(d.name), + source="discovered", + endpoint_type=d.endpoint_type, + ready=d.ready, + recommended=False, + ) + + return sorted( + merged.values(), + key=lambda e: (not e.recommended, e.family.value, e.name), + ) + + +def list_models_from_env() -> list[str]: + """Convenience wrapper that reads env vars and returns TTL-cached model list. + + Reads: + DATABRICKS_HOST — required + DATABRICKS_TOKEN or DATABRICKS_ACCESS_TOKEN — required + + Returns [] silently if env vars are not set or on any API error. + Results are cached for ``_CACHE_TTL_S`` seconds. Cache writes are protected + by ``_CACHE_LOCK``; reads use a lock-free fast path (last-writer-wins on miss). + """ + global _CACHED_MODELS, _CACHE_EXPIRES_AT + + if time.time() < _CACHE_EXPIRES_AT: + return _CACHED_MODELS + + host = os.environ.get("DATABRICKS_HOST", "").rstrip("/") + token = os.environ.get("DATABRICKS_TOKEN") or os.environ.get( + "DATABRICKS_ACCESS_TOKEN" + ) + if not host or not token: + return [] + + credentials = DatabricksCredentials( + host=normalize_host(host), + get_token=lambda: token, # type: ignore[return-value] + auth_method="env", + ) + try: + models = list_foundation_models(credentials) + with _CACHE_LOCK: + _CACHED_MODELS = models + _CACHE_EXPIRES_AT = time.time() + _CACHE_TTL_S + return models + except Exception as exc: + logger.warning("databricks_discovery_failed", extra={"error": str(exc)}) + return [] diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py new file mode 100644 index 0000000000..0b27265d71 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -0,0 +1,352 @@ +"""DatabricksLLM — native Databricks AI Gateway provider for the OpenHands V1 SDK. + +Subclasses LLM and overrides the transport layer to talk to the AI Gateway +``/ai-gateway/`` directly over httpx. + +Usage (via factory — preferred): + from openhands.sdk import create_llm + llm = create_llm( + "databricks/databricks-claude-opus-4-6", + databricks_host="https://adb-1234.cloud.databricks.com", + api_key=SecretStr("dapi..."), + ) + +The workspace URL (``databricks_host``) is the canonical configured host. +The SDK derives the AI Gateway base from it +(``/ai-gateway/``) for every FM invocation. + +``databricks_ai_gateway_host`` is an optional override for deployments +with a dedicated gateway hostname (``.ai-gateway.cloud.databricks.com``); +when set, FM invocations route through it directly. Discovery, auth, and +metadata probes always go to ``databricks_host`` regardless. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import PrivateAttr, SecretStr, field_validator, model_validator + +from openhands.sdk.llm.llm import LLM +from openhands.sdk.llm.providers.databricks.auth import ( + DatabricksCredentials, + resolve_credentials, +) +from openhands.sdk.llm.providers.databricks.client import DatabricksFMAPIClient +from openhands.sdk.llm.providers.databricks.models import ( + ProviderFamily, + StoredU2MTokens, + detect_family, +) +from openhands.sdk.llm.providers.databricks.utils import DatabricksTimeouts + +if TYPE_CHECKING: + from litellm.types.utils import ModelResponse + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Model capability tables +# --------------------------------------------------------------------------- + +# Context windows (input tokens) for known Databricks FMAPI models. +# Unknown models fall back to 128K. +# +# Values reflect what the AI Gateway endpoint accepts, not the raw upstream model +# limits — e.g. Anthropic Claude endpoints are gateway-capped at 200K even if +# the upstream contract supports more. Keep this table conservative. +DATABRICKS_CONTEXT_WINDOWS: dict[str, int] = { + # --- Databricks-hosted FM (OpenAI Chat family) --- + "databricks/databricks-dbrx-instruct": 32_768, + "databricks/databricks-meta-llama-3-1-70b-instruct": 128_000, + "databricks/databricks-meta-llama-3-1-405b-instruct": 128_000, + "databricks/databricks-meta-llama-3-3-70b-instruct": 128_000, + "databricks/databricks-meta-llama-4-maverick": 128_000, + "databricks/databricks-mixtral-8x7b-instruct": 32_768, + "databricks/databricks-gpt-oss-20b": 128_000, + "databricks/databricks-gpt-oss-120b": 128_000, + # --- Anthropic native (Claude series on gateway) --- + "databricks/databricks-claude-3-5-sonnet-2": 200_000, + "databricks/databricks-claude-3-7-sonnet": 200_000, + "databricks/databricks-claude-sonnet-4": 200_000, + "databricks/databricks-claude-sonnet-4-5": 200_000, + "databricks/databricks-claude-opus-4-6": 200_000, + "databricks/databricks-claude-haiku-4-5": 200_000, + # --- Google Gemini native --- + "databricks/databricks-gemini-2-5-flash": 1_048_576, + "databricks/databricks-gemini-2-5-pro": 1_048_576, + # --- OpenAI Responses (GPT-5 series) --- + "databricks/databricks-gpt-5": 400_000, + "databricks/databricks-gpt-5-2": 400_000, + "databricks/databricks-gpt-5-4": 400_000, + "databricks/databricks-gpt-5-4-mini": 400_000, + "databricks/databricks-gpt-5-4-nano": 400_000, +} + +# Maximum output tokens for known Databricks FMAPI models. +# Unknown models fall back to 16K. +# +# For reasoning-capable endpoints (gpt-5 series, gemini 2.5, gpt-oss), output +# tokens include internal thinking tokens — the budget must be generous enough +# that visible text actually fits. See ``databricks-ai-gateway-fm-apis`` skill. +DATABRICKS_MAX_OUTPUT: dict[str, int] = { + # --- OpenAI Chat family --- + "databricks/databricks-dbrx-instruct": 4_096, + "databricks/databricks-meta-llama-3-1-70b-instruct": 4_096, + "databricks/databricks-meta-llama-3-1-405b-instruct": 4_096, + "databricks/databricks-meta-llama-3-3-70b-instruct": 4_096, + "databricks/databricks-meta-llama-4-maverick": 8_192, + "databricks/databricks-mixtral-8x7b-instruct": 4_096, + "databricks/databricks-gpt-oss-20b": 16_384, # reasoning capacity + "databricks/databricks-gpt-oss-120b": 16_384, + # --- Anthropic --- + "databricks/databricks-claude-3-5-sonnet-2": 8_192, + "databricks/databricks-claude-3-7-sonnet": 8_192, + "databricks/databricks-claude-sonnet-4": 8_192, + "databricks/databricks-claude-sonnet-4-5": 64_000, + "databricks/databricks-claude-opus-4-6": 32_000, + "databricks/databricks-claude-haiku-4-5": 8_192, + # --- Gemini (budget includes thinking) --- + "databricks/databricks-gemini-2-5-flash": 65_536, + "databricks/databricks-gemini-2-5-pro": 65_536, + # --- OpenAI Responses (GPT-5) — generous default so reasoning fits --- + "databricks/databricks-gpt-5": 16_384, + "databricks/databricks-gpt-5-2": 16_384, + "databricks/databricks-gpt-5-4": 16_384, + "databricks/databricks-gpt-5-4-mini": 16_384, + "databricks/databricks-gpt-5-4-nano": 16_384, +} + + +# --------------------------------------------------------------------------- +# DatabricksLLM +# --------------------------------------------------------------------------- + +class DatabricksLLM(LLM): + """Native Databricks Foundation Model API provider. PWAF-compliant. + + Uses a direct httpx transport to the Databricks AI Gateway instead of + routing HTTP through litellm.completion. Supports OAuth U2M (browser + PKCE), OAuth M2M (client credentials), PAT, CLI profile, and the + databricks-sdk unified auth chain. + """ + + # Pydantic provider discriminator. Serialized by ``SerializeAsAny`` on + # ``AgentBase.llm`` / ``LLMSummarizingCondenser.llm``; read back by the + # ``_dispatch_to_provider_subclass`` wrap-validator on the base ``LLM`` + # class to route the payload to this subclass on load. + provider: Literal["databricks"] = "databricks" + + # --- Databricks-specific fields --- + + databricks_ai_gateway_host: str | None = None + """Optional AI Gateway override — host only, scheme + hostname[:port], no path. + + When set, all FM invocations route through this host instead of the + workspace URL. Use this for deployments with a dedicated gateway, e.g. + ``https://.ai-gateway.cloud.databricks.com``. + + Leave unset for the common single-URL deployment — the SDK then routes + invocations through ``/ai-gateway/``. + + Must start with ``https://`` and contain no path.""" + + databricks_host: str | None = None + """Workspace URL — the canonical Databricks endpoint. + + Used for: + + * FM invocations (default base, becomes ``/ai-gateway/``) + unless ``databricks_ai_gateway_host`` is set. + * Auth / token resolution (OAuth flows mint tokens here). + * Discovery and the opt-in metadata probe + (``GET /api/2.0/serving-endpoints/...``). + + Required for OAuth-based auth (``profile`` / ``m2m`` / ``u2m`` / + unified). For PAT auth it's optional only when + ``databricks_ai_gateway_host`` is also set (the gateway then has its + own URL and the workspace isn't needed).""" + + databricks_metadata_probe: bool = False + """When True, ``resolve_family`` issues + ``GET /api/2.0/serving-endpoints/{name}`` against ``databricks_host`` + to authoritatively determine the AI Gateway path family from the + server-side ``api_types``. Results cached in-process for 5 minutes + per endpoint. Default False (name-pattern resolution only).""" + + databricks_client_id: str | None = None + """M2M service principal application ID (OAuth client_credentials grant). + NOT the same as DATABRICKS_U2M_CLIENT_ID (browser OAuth app). + Set DATABRICKS_CLIENT_SECRET alongside this.""" + + databricks_client_secret: SecretStr | None = None + """M2M service principal OAuth secret. Paired with databricks_client_id.""" + + databricks_profile: str | None = None + """Databricks CLI profile name from ~/.databrickscfg. Requires databricks-sdk.""" + + databricks_ssl_verify: bool = True + """SSL/TLS verification. Set to path string for custom CA bundle.""" + + stored_u2m_tokens: StoredU2MTokens | None = None + """U2M OAuth tokens from browser login flow. Passed from app layer. + Highest-priority auth path (PWAF: OAuth primary).""" + + # --- Resilience knobs --- + + databricks_max_retries: int = 3 + databricks_connect_timeout_s: float = 10.0 + databricks_read_timeout_s: float = 120.0 + databricks_chunk_timeout_s: float = 30.0 + + # --- Private state (not serialized) --- + + _db_credentials: DatabricksCredentials = PrivateAttr() + _db_client: DatabricksFMAPIClient = PrivateAttr() + + # --------------------------------------------------------------------------- + # Validators + # --------------------------------------------------------------------------- + + @field_validator("databricks_host", mode="before") + @classmethod + def _validate_host(cls, v: str | None) -> str | None: + if v is not None and not v.startswith("https://"): + raise ValueError( + f"databricks_host must start with 'https://'. Got: {v!r}" + ) + return v + + @field_validator("databricks_ai_gateway_host", mode="before") + @classmethod + def _validate_ai_gateway_host(cls, v: str | None) -> str | None: + if v is None or v == "": + return None + if not v.startswith("https://"): + raise ValueError( + "databricks_ai_gateway_host must start with 'https://'. " + f"Got: {v!r}" + ) + from urllib.parse import urlsplit + parts = urlsplit(v) + # Allow trailing ``/ai-gateway`` so users who copy-paste the full + # gateway base URL aren't penalised; everything else must be host-only. + path = (parts.path or "").rstrip("/") + if path and path != "/ai-gateway": + raise ValueError( + "databricks_ai_gateway_host must be host-only (optionally " + "ending in '/ai-gateway'); the SDK appends the per-family " + f"route itself. Got path={parts.path!r} in {v!r}" + ) + if parts.query or parts.fragment: + raise ValueError( + "databricks_ai_gateway_host must not include a query string " + f"or fragment. Got: {v!r}" + ) + return v.rstrip("/") + + @model_validator(mode="after") + def _init_databricks(self) -> "DatabricksLLM": + if not (self.databricks_ai_gateway_host or self.databricks_host): + raise ValueError( + "databricks_host is required (or databricks_ai_gateway_host " + "as an override). FM invocations route through " + "/ai-gateway/ by default." + ) + self._db_credentials = resolve_credentials(self) + self._db_client = DatabricksFMAPIClient( + credentials=self._db_credentials, + ai_gateway_host=self.databricks_ai_gateway_host, + timeouts=DatabricksTimeouts( + connect_s=self.databricks_connect_timeout_s, + read_s=self.databricks_read_timeout_s, + chunk_s=self.databricks_chunk_timeout_s, + ), + max_retries=self.databricks_max_retries, + ssl_verify=self.databricks_ssl_verify, + metadata_probe=self.databricks_metadata_probe, + ) + return self + + # --------------------------------------------------------------------------- + # PWAF surfaces (observability, diagnostics, pickers) + # --------------------------------------------------------------------------- + + @property + def auth_method(self) -> str: + """Resolved auth strategy: ``pat`` | ``m2m`` | ``u2m`` | ``profile`` | ``unified`` | ``env``. + + Read-only — set by ``resolve_credentials()`` during construction. + Handy for log correlation and operator dashboards. + """ + return self._db_credentials.auth_method + + @property + def predicted_family(self) -> ProviderFamily: + """Provider family predicted by **name pattern only** (no HTTP call). + + Useful for picker UIs and validation — gives an immediate answer without + hitting the ``/api/2.0/serving-endpoints/{name}`` describe endpoint. For + the authoritative family used at request time, call + :meth:`resolve_family` (it performs a metadata probe with in-process + caching and falls back to this same prediction on error). + """ + return detect_family(self.model) + + def resolve_family(self) -> ProviderFamily: + """Provider family used at request time. + + Default (``databricks_metadata_probe=False``): pure name-pattern + resolution — same as :attr:`predicted_family`, no network call. + + Opt-in (``databricks_metadata_probe=True``): metadata-first with + name-pattern fallback. Triggers at most one + ``GET /api/2.0/serving-endpoints/{name}`` per endpoint per 5-minute + TTL window against the workspace URL. + """ + endpoint = self.model.removeprefix("databricks/") + return self._db_client.resolve_family(endpoint) + + # --------------------------------------------------------------------------- + # LLM overrides + # --------------------------------------------------------------------------- + + def _init_model_info_and_caps(self) -> None: + """Override: set context windows from Databricks capability tables.""" + self.max_input_tokens = DATABRICKS_CONTEXT_WINDOWS.get(self.model, 128_000) + self.max_output_tokens = DATABRICKS_MAX_OUTPUT.get(self.model, 16_384) + self._validate_context_window_size() + + def _get_litellm_api_key_value(self) -> str | None: + """Override: return a fresh Databricks token rather than the static api_key.""" + return self._db_credentials.get_token() + + def _transport_call( + self, + *, + messages: list[dict[str, Any]], + enable_streaming: bool = False, + on_token=None, + **kwargs, + ) -> "ModelResponse": + """Override: call the Databricks FMAPI directly via httpx.""" + model_name = self.model.removeprefix("databricks/") + logger.debug( + "databricks_transport_call", + extra={ + "endpoint": model_name, + "auth_method": self._db_credentials.auth_method, + "predicted_family": detect_family(self.model).value, + # Authoritative family is resolved inside the client (with cache); + # we don't re-probe here to avoid a second call path. + "streaming": enable_streaming, + }, + ) + return self._db_client.chat_completion( + model=model_name, + messages=messages, + stream=enable_streaming, + on_token=on_token, + **kwargs, + ) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py new file mode 100644 index 0000000000..539fca3b75 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py @@ -0,0 +1,287 @@ +"""Databricks AI Gateway routing config + shared Pydantic models. + +All FM traffic is routed through **Databricks AI Gateway** under the +``/ai-gateway/`` URL prefix. The gateway is reachable two ways: + +* Reverse-proxied under the workspace host + (``https:///ai-gateway/...``) — this is what + :func:`AIGatewayPaths.normalize_base` produces when the user only + configures the workspace URL. +* Dedicated AI-Gateway hostname + (``https://.ai-gateway.cloud.databricks.com/...``) — used + when the customer has a separate gateway endpoint (PrivateLink, Front + Door). In that case the path templates are appended directly without + the ``/ai-gateway`` prefix. + +This module carries: + +1. :class:`StoredU2MTokens` — OAuth token container shared with the app layer. +2. :class:`ProviderFamily` — dispatch key for which provider-native contract + the target endpoint speaks (OpenAI Chat, Anthropic Messages, Google Gemini + ``generateContent``, OpenAI Responses). +3. :class:`AIGatewayPaths` — path templates for each family's AI Gateway route. +4. :func:`detect_family` — name-pattern router (fast path, no HTTP call). +5. :func:`pick_family_from_api_types` — metadata router (authoritative; uses the + ``foundation_model.api_types`` / ``external_model.provider`` signals returned + by ``GET /api/2.0/serving-endpoints/{name}``). + +The two routers mirror the `databricks-ai-gateway-fm-apis` skill exactly — +keep them in sync when the skill's routing table changes. + +``StoredU2MTokens`` is defined **exactly once** here; `auth.py` and the +OpenHands app layer import from this module (no duplicate definitions). +""" + +from __future__ import annotations + +import re +from enum import Enum +from typing import Iterable + +from pydantic import BaseModel, Field + + +# --------------------------------------------------------------------------- +# OAuth tokens (shared container) +# --------------------------------------------------------------------------- + + +class StoredU2MTokens(BaseModel): + """OAuth tokens stored in the OpenHands user session after browser login. + + Passed from the app layer (after ``/auth/databricks/callback``) to + ``resolve_credentials()``. The provider module never initiates the browser + flow — it only manages token refresh. + """ + + access_token: str + refresh_token: str + expires_at: float # Unix epoch seconds + client_id: str # DATABRICKS_U2M_CLIENT_ID — required for token refresh + host: str # Workspace host — fallback when databricks_host is unset + + +# --------------------------------------------------------------------------- +# Provider family — drives AI Gateway payload format + URL path +# --------------------------------------------------------------------------- + + +class ProviderFamily(str, Enum): + """Which provider-native API format the AI Gateway endpoint speaks. + + Routing (in priority order — first match wins): + + ===================== ========================== ============================== + Family Name pattern Metadata ``api_types`` entry + ===================== ========================== ============================== + :attr:`ANTHROPIC` ``*claude*`` ``anthropic/v1/messages`` + :attr:`GEMINI` ``*gemini*`` ``gemini/v1/generateContent`` + :attr:`OPENAI_RESPONSES` ``databricks-gpt-5*`` ``openai/v1/responses`` + :attr:`OPENAI` *(default)* ``mlflow/v1/chat/completions`` + ===================== ========================== ============================== + + :attr:`OPENAI` is the **always-safe default** — every ``task=llm/v1/chat`` + endpoint accepts OpenAI-chat payloads at ``/{endpoint}/invocations`` and + returns OpenAI ``ChatCompletion`` responses. The other families are opt-in + and only used when there's positive evidence (metadata or name match) that + the endpoint speaks the native contract. + """ + + OPENAI = "openai" # OpenAI Chat — universal default + OPENAI_RESPONSES = "openai_responses" # OpenAI Responses — GPT-5 series only + ANTHROPIC = "anthropic" # Anthropic Messages — Claude models + GEMINI = "gemini" # Google Gemini generateContent + + +# --------------------------------------------------------------------------- +# AI Gateway path templates +# --------------------------------------------------------------------------- + + +class AIGatewayPaths(BaseModel): + """Path templates appended to the AI Gateway base for each native API. + + All four templates have been verified against a live Databricks workspace: + + * :attr:`openai` — OpenAI Chat Completions, mlflow flavor, universal + default. Endpoint name is carried in the body. + * :attr:`openai_responses` — OpenAI Responses API for the GPT-5 series. + Endpoint name is in the body. + * :attr:`anthropic` — Anthropic Messages API, native flavor for Claude + models. Endpoint name is in the body. + * :attr:`gemini` — Google Gemini ``generateContent`` native path. The + endpoint name is part of the URL. + + Templates intentionally start at the AI Gateway base (without the + ``/ai-gateway`` prefix). :meth:`url` calls :meth:`normalize_base` first + to produce the right base URL given the configured host: + + * Workspace URL (``adb-*.cloud.databricks.com``) → + ``/ai-gateway`` (the gateway is reverse-proxied under the + workspace control plane). + * Dedicated gateway URL (``*.ai-gateway.*``) → host as-is, the gateway + hostname is itself the base. + + Each template can be overridden for deployments with non-standard path + layouts. ``{endpoint}`` is substituted with the bare endpoint name + (after stripping the ``databricks/`` prefix). + """ + + openai: str = Field( + default="/mlflow/v1/chat/completions", + description="OpenAI Chat Completions (mlflow flavor; universal default).", + ) + openai_responses: str = Field( + default="/openai/v1/responses", + description="OpenAI Responses API. Endpoint name is in the body.", + ) + anthropic: str = Field( + default="/anthropic/v1/messages", + description="Anthropic Messages API. Endpoint name is in the body.", + ) + gemini: str = Field( + default="/gemini/v1beta/models/{endpoint}:generateContent", + description="Google Gemini generateContent native path.", + ) + + @staticmethod + def normalize_base(host: str) -> str: + """Return the AI Gateway base URL for a configured host. + + - Hosts whose netloc matches ``*.ai-gateway.*`` are dedicated AI + Gateway endpoints; the host itself is the base — return as-is + (after stripping any trailing slash). + - Hosts that already end with ``/ai-gateway`` are returned as-is. + - Anything else is treated as a workspace URL with the gateway + reverse-proxied; ``/ai-gateway`` is appended. + """ + h = host.rstrip("/") + # Crude netloc check; ``http(s):///...`` and bare ``netloc`` + # both work here without pulling in urllib for one substring match. + scheme_split = h.split("://", 1) + netloc = scheme_split[1].split("/", 1)[0] if len(scheme_split) == 2 else h + if ".ai-gateway." in netloc: + return h + if h.endswith("/ai-gateway"): + return h + return h + "/ai-gateway" + + def url(self, host: str, family: ProviderFamily, endpoint: str) -> str: + """Build the fully-qualified URL for a ``(family, endpoint)`` pair. + + ``host`` may be a workspace URL or a dedicated AI Gateway hostname; + :meth:`normalize_base` figures out which prefix to apply. + """ + tmpl = getattr(self, family.value) + return self.normalize_base(host) + tmpl.format(endpoint=endpoint) + + +# --------------------------------------------------------------------------- +# Routers +# --------------------------------------------------------------------------- + +# Strip these from the model id before matching. Keeps support for +# "databricks/databricks-claude-sonnet-4-5", "databricks-claude-...", etc. +_MODEL_PREFIXES = ("databricks/", "databricks-") + + +def _bare_name(model: str) -> str: + name = model.lower().strip() + for prefix in _MODEL_PREFIXES: + if name.startswith(prefix): + name = name[len(prefix):] + return name + + +def detect_family(model: str) -> ProviderFamily: + """Name-pattern router (fast path, no extra API call). + + Mirrors the ``databricks-ai-gateway-fm-apis`` skill's ``route_by_name``. + + Priority (first match wins): + + 1. ``*claude*`` → :attr:`ProviderFamily.ANTHROPIC` + 2. ``*gemini*`` → :attr:`ProviderFamily.GEMINI` + 3. ``gpt-*`` → :attr:`ProviderFamily.OPENAI_RESPONSES` + The leading digit requirement naturally excludes ``gpt-oss-*`` + (starts with ``gpt-o``) without needing an explicit blocklist. + Confirmed live against the full GPT-5 product line (April 2026): + ``gpt-5``, ``gpt-5-1``, ``gpt-5-1-codex-{max,mini}``, + ``gpt-5-2``, ``gpt-5-2-codex``, ``gpt-5-3-codex``, ``gpt-5-4``, + ``gpt-5-4-{mini,nano}``, ``gpt-5-mini``, ``gpt-5-nano``. + Any future ``gpt-6``, ``gpt-7``, … variants automatically inherit + this rule — Databricks routes all numbered GPT generations through + the OpenAI Responses API (``/openai/v1/responses``). + 4. Everything else → :attr:`ProviderFamily.OPENAI` (universal + MLflow Chat Completions — safe default for ``gpt-oss``, Llama, + Mistral, DBRX, Qwen, …) + """ + name = _bare_name(model) + if "claude" in name: + return ProviderFamily.ANTHROPIC + if "gemini" in name: + return ProviderFamily.GEMINI + # ``_bare_name`` strips both ``databricks/`` and ``databricks-`` prefixes. + # ``re.match`` anchors at the start only — ``gpt-\d`` requires a digit + # immediately after the dash, so ``gpt-oss-*`` falls through cleanly. + if re.match(r"gpt-\d", name): + return ProviderFamily.OPENAI_RESPONSES + return ProviderFamily.OPENAI + + +# ``api_types`` strings exposed by ``GET /api/2.0/serving-endpoints/{name}`` +# (``config.served_entities[0].foundation_model.api_types``). +_API_TYPE_TO_FAMILY: dict[str, ProviderFamily] = { + "anthropic/v1/messages": ProviderFamily.ANTHROPIC, + "gemini/v1/generateContent": ProviderFamily.GEMINI, + "openai/v1/responses": ProviderFamily.OPENAI_RESPONSES, + # mlflow/v1/chat/completions is the universal fallback → OPENAI, + # handled by priority ordering below. +} + +# ``external_model.provider`` values (external endpoints don't have +# ``foundation_model``; they expose a single upstream provider instead). +_EXTERNAL_PROVIDER_TO_FAMILY: dict[str, ProviderFamily] = { + "anthropic": ProviderFamily.ANTHROPIC, + "bedrock-anthropic": ProviderFamily.ANTHROPIC, + "google": ProviderFamily.GEMINI, + "gemini": ProviderFamily.GEMINI, + # OpenAI & azure-openai endpoints speak Chat Completions on /invocations; + # map to OPENAI so we take the default path. + "openai": ProviderFamily.OPENAI, + "azure-openai": ProviderFamily.OPENAI, +} + +# Priority order when an endpoint exposes multiple ``api_types`` — prefer the +# most specific native API, fall back to OpenAI Chat. Reverse this list if you +# want "stay on OpenAI Chat unless explicitly overridden" behaviour. +_API_TYPE_PRIORITY: tuple[str, ...] = ( + "anthropic/v1/messages", + "gemini/v1/generateContent", + "openai/v1/responses", +) + + +def pick_family_from_api_types( + api_types: Iterable[str] | None, + external_provider: str | None = None, +) -> ProviderFamily: + """Metadata-first router (authoritative — no name-based guessing). + + ``api_types`` comes from ``foundation_model.api_types`` on a foundation-model + endpoint. ``external_provider`` comes from ``external_model.provider`` on an + external-model endpoint. Exactly one of them is populated for a given + endpoint; passing both is fine (foundation signals win). + + Returns :attr:`ProviderFamily.OPENAI` when no native signal is present — + this is the always-safe default for any ``task=llm/v1/chat`` endpoint. + """ + present = set(api_types or ()) + for key in _API_TYPE_PRIORITY: + if key in present: + return _API_TYPE_TO_FAMILY[key] + if external_provider: + return _EXTERNAL_PROVIDER_TO_FAMILY.get( + external_provider.lower(), ProviderFamily.OPENAI + ) + return ProviderFamily.OPENAI diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/native.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/native.py new file mode 100644 index 0000000000..7091c690f9 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/native.py @@ -0,0 +1,570 @@ +"""Minimal native-API adapters for Databricks AI Gateway. + +Everything in the OpenHands agent loop is OpenAI Chat Completions. This module +has one small adapter per non-default provider family — each ~30 LOC — that: + +1. Converts an OpenAI-chat ``messages`` list + generation kwargs to the + native request body for that family (``to_native``). +2. Converts the native response back to a minimal OpenAI ``ChatCompletion`` + dict that ``client._parse_response`` / ``litellm.ModelResponse`` accept + unchanged (``from_native``). + +Streaming is **not** adapted here — streaming stays on the universal OpenAI +Chat SSE path (``/invocations``) in ``client.py``. Native-API streaming +(Anthropic / Gemini / Responses) is a follow-up, documented in the skill. + +Out-of-scope by design (documented, not silently stripped): + +* Anthropic prompt caching, tool-use block preservation beyond plain text. +* Gemini multi-modal ``parts`` (inlineData, fileData). +* Responses ``custom`` / ``apply_patch`` / ``mcp`` tool types. + +For these, use the native provider SDK directly against the Databricks +gateway ``base_url`` — see the companion skill's "Using provider SDKs +directly" section. +""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from openhands.sdk.llm.providers.databricks.models import ProviderFamily + + +# --------------------------------------------------------------------------- +# Public dispatch +# --------------------------------------------------------------------------- + + +def to_native( + family: ProviderFamily, + model: str, + messages: list[dict], + **kwargs: Any, +) -> dict: + """Build the native request body for the given family. + + ``model`` is the bare endpoint name (no ``databricks/`` prefix). + Unknown kwargs are passed through where the native API supports them + and dropped where it doesn't. + """ + if family is ProviderFamily.ANTHROPIC: + return _to_anthropic(model, messages, **kwargs) + if family is ProviderFamily.GEMINI: + return _to_gemini(model, messages, **kwargs) + if family is ProviderFamily.OPENAI_RESPONSES: + return _to_responses(model, messages, **kwargs) + return _to_openai_chat(model, messages, **kwargs) + + +def from_native( + family: ProviderFamily, + model: str, + data: dict, +) -> dict: + """Normalize a native response ``data`` to an OpenAI ChatCompletion dict.""" + if family is ProviderFamily.ANTHROPIC: + return _from_anthropic(model, data) + if family is ProviderFamily.GEMINI: + return _from_gemini(model, data) + if family is ProviderFamily.OPENAI_RESPONSES: + return _from_responses(model, data) + return _from_openai_chat(model, data) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _chat_completion( + model: str, + content: str, + *, + finish_reason: str = "stop", + usage: dict | None = None, + tool_calls: list[dict] | None = None, + response_id: str | None = None, +) -> dict: + """Build a minimal OpenAI-chat ``ChatCompletion`` dict.""" + msg: dict[str, Any] = {"role": "assistant", "content": content} + if tool_calls: + msg["tool_calls"] = tool_calls + return { + "id": response_id or f"chatcmpl-{uuid.uuid4().hex[:24]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "message": msg, "finish_reason": finish_reason}], + "usage": usage or {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + +def _flatten_content(content: Any) -> str: + """Flatten OpenAI-style content (str | list of {type,text}) to plain str. + + Handles the reasoning-model shape where ``choices[0].message.content`` is + a list of blocks (``{"type":"reasoning",...}``, ``{"type":"text","text":...}``). + Skips ``reasoning`` blocks; concatenates ``text`` blocks. + """ + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for blk in content: + if not isinstance(blk, dict): + continue + t = blk.get("type") + if t == "text" and isinstance(blk.get("text"), str): + parts.append(blk["text"]) + # "reasoning" blocks are intentionally skipped. + return "".join(parts) + return "" + + +_GENERIC_KWARGS = {"temperature", "top_p", "stop"} + + +# --------------------------------------------------------------------------- +# OpenAI Chat (default) +# --------------------------------------------------------------------------- + + +def _to_openai_chat(model: str, messages: list[dict], **kwargs: Any) -> dict: + # The AI Gateway ``/mlflow/v1/chat/completions`` endpoint reads the + # target endpoint name from the body, not the URL — every other family + # already does this, so the OpenAI Chat path now matches. + body: dict[str, Any] = {"model": model, "messages": messages} + if "max_tokens" in kwargs and kwargs["max_tokens"] is not None: + body["max_tokens"] = kwargs["max_tokens"] + if kwargs.get("tools"): + body["tools"] = kwargs["tools"] + if kwargs.get("tool_choice") is not None: + body["tool_choice"] = kwargs["tool_choice"] + for k in _GENERIC_KWARGS: + if kwargs.get(k) is not None: + body[k] = kwargs[k] + if kwargs.get("stream"): + body["stream"] = True + return body + + +def _from_openai_chat(model: str, data: dict) -> dict: + # Gateway already returns ChatCompletion shape; only normalize reasoning + # models' list-of-blocks content back to a string so downstream consumers + # don't need to special-case it. + choices = data.get("choices") or [] + if choices: + msg = choices[0].get("message") or {} + if isinstance(msg.get("content"), list): + msg["content"] = _flatten_content(msg["content"]) + return data + + +# --------------------------------------------------------------------------- +# Anthropic Messages +# --------------------------------------------------------------------------- + + +def _to_anthropic(model: str, messages: list[dict], **kwargs: Any) -> dict: + """OpenAI messages → Anthropic Messages body. + + System messages become the top-level ``system`` string (Anthropic doesn't + support ``role=system`` inside ``messages``). Tool calls beyond plain text + are not converted here — use the Anthropic SDK against the gateway + ``base_url`` if you need Anthropic-native tool_use blocks. + """ + system_parts: list[str] = [] + conv: list[dict] = [] + for m in messages: + role = m.get("role") + content = _flatten_content(m.get("content")) + if role == "system": + if content: + system_parts.append(content) + continue + if role in ("user", "assistant") and content: + conv.append({"role": role, "content": content}) + body: dict[str, Any] = { + "model": model, + "messages": conv, + # max_tokens is REQUIRED by Anthropic Messages; provide a safe default. + "max_tokens": int(kwargs.get("max_tokens") or 1024), + } + if system_parts: + body["system"] = "\n\n".join(system_parts) + if kwargs.get("temperature") is not None: + body["temperature"] = kwargs["temperature"] + if kwargs.get("top_p") is not None: + body["top_p"] = kwargs["top_p"] + if kwargs.get("stop"): + stop = kwargs["stop"] + body["stop_sequences"] = [stop] if isinstance(stop, str) else list(stop) + return body + + +_ANTHROPIC_STOP_MAP = { + "end_turn": "stop", + "stop_sequence": "stop", + "max_tokens": "length", + "tool_use": "tool_calls", +} + + +def _from_anthropic(model: str, data: dict) -> dict: + text_parts: list[str] = [] + for blk in data.get("content") or []: + if isinstance(blk, dict) and blk.get("type") == "text": + text_parts.append(blk.get("text", "")) + u = data.get("usage") or {} + usage = { + "prompt_tokens": u.get("input_tokens", 0), + "completion_tokens": u.get("output_tokens", 0), + "total_tokens": (u.get("input_tokens", 0) + u.get("output_tokens", 0)), + } + return _chat_completion( + model, + "".join(text_parts), + finish_reason=_ANTHROPIC_STOP_MAP.get(data.get("stop_reason", ""), "stop"), + usage=usage, + response_id=data.get("id"), + ) + + +# --------------------------------------------------------------------------- +# Google Gemini generateContent +# --------------------------------------------------------------------------- + + +def _to_gemini(model: str, messages: list[dict], **kwargs: Any) -> dict: + """OpenAI messages → Gemini ``generateContent`` body. + + ``role=system`` maps to ``systemInstruction``. OpenAI ``assistant`` becomes + Gemini ``model``. ``maxOutputTokens`` defaults to 1024 — Gemini budgets + *thinking + output* against this, so anything below ~256 can return empty. + """ + system_parts: list[str] = [] + contents: list[dict] = [] + for m in messages: + role = m.get("role") + text = _flatten_content(m.get("content")) + if role == "system": + if text: + system_parts.append(text) + continue + g_role = "model" if role == "assistant" else "user" + if text: + contents.append({"role": g_role, "parts": [{"text": text}]}) + gen_config: dict[str, Any] = { + "maxOutputTokens": int(kwargs.get("max_tokens") or 1024), + } + if kwargs.get("temperature") is not None: + gen_config["temperature"] = kwargs["temperature"] + if kwargs.get("top_p") is not None: + gen_config["topP"] = kwargs["top_p"] + if kwargs.get("stop"): + stop = kwargs["stop"] + gen_config["stopSequences"] = [stop] if isinstance(stop, str) else list(stop) + body: dict[str, Any] = {"contents": contents, "generationConfig": gen_config} + if system_parts: + body["systemInstruction"] = {"parts": [{"text": "\n\n".join(system_parts)}]} + return body + + +_GEMINI_FINISH_MAP = { + "STOP": "stop", + "MAX_TOKENS": "length", + "SAFETY": "content_filter", + "RECITATION": "content_filter", + "OTHER": "stop", +} + + +def _from_gemini(model: str, data: dict) -> dict: + cands = data.get("candidates") or [] + text = "" + finish = "stop" + if cands: + cand = cands[0] + parts = ((cand.get("content") or {}).get("parts")) or [] + text = "".join(p.get("text", "") for p in parts if isinstance(p, dict)) + finish = _GEMINI_FINISH_MAP.get(cand.get("finishReason", ""), "stop") + um = data.get("usageMetadata") or {} + usage = { + "prompt_tokens": um.get("promptTokenCount", 0), + "completion_tokens": um.get("candidatesTokenCount", 0), + "total_tokens": um.get("totalTokenCount", 0), + } + return _chat_completion( + model, + text, + finish_reason=finish, + usage=usage, + response_id=data.get("responseId"), + ) + + +# --------------------------------------------------------------------------- +# OpenAI Responses (GPT-5 series) +# --------------------------------------------------------------------------- + +# Pay-per-token FM endpoints reject these — they're documented for the +# hosted OpenAI Responses API but not supported via the gateway. +_RESPONSES_DROP = { + "background", + "store", + "previous_response_id", + "service_tier", + # ``max_completion_tokens`` is the Chat-Completions name; the upstream + # LLM/litellm path emits it for OpenAI-flavoured calls. Responses uses + # ``max_output_tokens`` (set explicitly above), so drop the chat-style + # alias to avoid the API's ``unsupported_parameter`` 400. + "max_completion_tokens", + # GPT-5 reasoning models routed through ``/openai/v1/responses`` reject + # both ``temperature`` and ``top_p`` ("Unsupported parameter…"). The + # SDK still respects user intent by carrying the values into kwargs, + # but the adapter drops them before they hit the wire so a single + # ``temperature=0.0`` default doesn't break every GPT-5 call. + "temperature", + "top_p", +} + + +def _to_responses_input(messages: list[dict]) -> list[dict]: + """Translate Chat-Completions messages into Responses API ``input`` items. + + The Responses API uses a flat item list instead of a ``messages`` array. + This function handles all message types that appear in a multi-turn + OpenHands conversation: + + Chat Completions role → Responses input item(s) + ───────────────────────────────────────────────────────────────────────── + user / system → {"role":"user/system","content":[{"type":"input_text","text":...}]} + assistant (text only) → {"role":"assistant","content":[{"type":"output_text","text":...}]} + assistant (tool_calls) → one {"type":"function_call","call_id":...} item per tool call, + followed by an output_text item if there is also text content + tool (result) → {"type":"function_call_output","call_id":...,"output":...} + ───────────────────────────────────────────────────────────────────────── + + Without the tool_calls and tool-result translations, multi-turn + conversations where GPT-5 calls a tool fail on the second turn because + the Responses API has no record of the function_call in the history. + """ + role_to_part_type = { + "user": "input_text", + "system": "input_text", + "developer": "input_text", + } + out: list[dict] = [] + for msg in messages: + if not isinstance(msg, dict): + out.append(msg) + continue + role = msg.get("role") + content = msg.get("content") + + # ── tool result ───────────────────────────────────────────────────── + if role == "tool": + out.append({ + "type": "function_call_output", + "call_id": msg.get("tool_call_id", ""), + "output": str(_flatten_content(content)) if content else "", + }) + continue + + # ── assistant with tool_calls (possibly also with text) ───────────── + if role == "assistant" and msg.get("tool_calls"): + for tc in msg.get("tool_calls") or []: + if not isinstance(tc, dict): + continue + fn = tc.get("function") or {} + out.append({ + "type": "function_call", + "call_id": tc.get("id", ""), + "name": fn.get("name", ""), + "arguments": fn.get("arguments", "{}"), + }) + # If the assistant message also contained text, emit it too. + text = _flatten_content(content) if content else "" + if text: + out.append({ + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + }) + continue + + # ── text-only assistant messages ───────────────────────────────────── + if role == "assistant": + text = _flatten_content(content) if content else "" + if text: + out.append({ + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + }) + continue + + # ── user / system / developer ──────────────────────────────────────── + part_type = role_to_part_type.get(role) + if part_type is None or content is None: + # Unknown role — pass through unchanged so the gateway surfaces + # a clear error rather than silently mangling the message. + out.append(msg) + continue + + if isinstance(content, str): + translated_content: list[dict] = [ + {"type": part_type, "text": content} + ] + elif isinstance(content, list): + translated_content = [] + for c in content: + if isinstance(c, dict) and c.get("type") == "text": + translated_content.append( + {"type": part_type, "text": c.get("text", "")} + ) + else: + translated_content.append(c) + else: + out.append(msg) + continue + + new_msg = dict(msg) + new_msg["content"] = translated_content + out.append(new_msg) + return out + + +def _chat_tools_to_responses(tools: list[dict]) -> list[dict]: + """Convert Chat Completions tool format to Responses API format. + + Chat Completions: + {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}} + + Responses API: + {"type": "function", "name": "...", "description": "...", "parameters": {...}} + + The Responses API requires ``name`` at the top level of each tool object. + Passing the nested Chat Completions format causes a 400 "Missing required + parameter: 'tools[0].name'" error from the gateway. + """ + out: list[dict] = [] + for tool in tools: + if not isinstance(tool, dict): + out.append(tool) + continue + fn = tool.get("function") + if tool.get("type") == "function" and isinstance(fn, dict): + # Unwrap the function wrapper — Responses API is flat. + converted: dict[str, Any] = {"type": "function"} + converted.update(fn) + out.append(converted) + else: + # Non-function tool or already in Responses format — pass through. + out.append(tool) + return out + + +def _to_responses(model: str, messages: list[dict], **kwargs: Any) -> dict: + """OpenAI messages → OpenAI Responses body. + + Responses uses ``input`` (not ``messages``), ``max_output_tokens`` (not + ``max_tokens``), and a different content-part vocabulary + (``input_text`` / ``output_text`` instead of ``text``). Default budget + is 1024 since GPT-5 spends tokens on reasoning before producing + visible output. + """ + body: dict[str, Any] = { + "model": model, + "input": _to_responses_input(messages), + "max_output_tokens": int(kwargs.get("max_tokens") or 1024), + } + # NB: ``temperature`` / ``top_p`` are intentionally NOT forwarded — + # see ``_RESPONSES_DROP`` above. GPT-5 reasoning models reject them. + if kwargs.get("tools"): + body["tools"] = _chat_tools_to_responses(kwargs["tools"]) + if kwargs.get("tool_choice") is not None: + body["tool_choice"] = kwargs["tool_choice"] + # Forward other allowed kwargs; silently drop known-unsupported ones. + for k, v in kwargs.items(): + if k in _RESPONSES_DROP or v is None: + continue + if k in body or k in {"max_tokens", "tools", "tool_choice", "temperature", + "top_p", "stream", "stop"}: + continue + body[k] = v + return body + + +def _from_responses(model: str, data: dict) -> dict: + """Convert Responses API output to OpenAI Chat Completions shape. + + Responses returns an array of output items: + * ``{"type":"message","content":[...]}`` → assistant text + * ``{"type":"function_call","name":...}`` → tool_calls + * ``{"type":"reasoning",...}`` → skip (internal thinking) + + Tool calls are converted to the Chat Completions ``tool_calls`` array so + OpenHands can dispatch them unchanged. Without this conversion GPT-5's tool + invocations are silently dropped, causing the "response did not include a + function call" loop. + """ + text_parts: list[str] = [] + tool_calls: list[dict] = [] + + for item in data.get("output") or []: + if not isinstance(item, dict): + continue + + item_type = item.get("type") + + if item_type == "message": + for c in item.get("content") or []: + if isinstance(c, dict) and c.get("type") in ("output_text", "text"): + text_parts.append(c.get("text", "")) + + elif item_type == "function_call": + # Responses API function_call shape: + # {"type": "function_call", "id": "fc_...", "call_id": "call_...", + # "name": "terminal", "arguments": "{\"command\":\"ls\"}"} + # → Chat Completions tool_calls shape: + # {"id": "call_...", "type": "function", + # "function": {"name": "terminal", "arguments": "..."}} + call_id = item.get("call_id") or item.get("id") or f"call_{uuid.uuid4().hex[:8]}" + tool_calls.append({ + "id": call_id, + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": item.get("arguments", "{}"), + }, + }) + # "reasoning" items are intentionally skipped. + + # Fallback: some responses expose aggregated text at `output_text`. + if not text_parts and not tool_calls and isinstance(data.get("output_text"), str): + text_parts.append(data["output_text"]) + + u = data.get("usage") or {} + usage = { + "prompt_tokens": u.get("input_tokens", 0), + "completion_tokens": u.get("output_tokens", 0), + "total_tokens": u.get("total_tokens", + u.get("input_tokens", 0) + u.get("output_tokens", 0)), + } + finish = "tool_calls" if tool_calls else ( + "stop" if data.get("status") == "completed" else "stop" + ) + + msg: dict[str, Any] = {"role": "assistant", "content": "".join(text_parts) or None} + if tool_calls: + msg["tool_calls"] = tool_calls + + return { + "id": data.get("id", f"resp-{uuid.uuid4().hex[:8]}"), + "object": "chat.completion", + "model": model, + "choices": [{"index": 0, "message": msg, "finish_reason": finish}], + "usage": usage, + } diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py new file mode 100644 index 0000000000..1bb4d48a52 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py @@ -0,0 +1,181 @@ +"""Settings → ``create_llm(...)`` kwargs bridge for the Databricks provider. + +This is the single code path both the OpenHands backend and the OpenHands-CLI +go through when turning user settings (env vars, DB rows, TUI form state) into +kwargs for :func:`openhands.sdk.create_llm`. + +Keeping it in the SDK prevents silent drift: when a new field is added to +``DatabricksLLM``, the contract test in ``test_settings_bridge.py`` fails until +the bridge is extended — which forces a conscious decision about whether the +new field should be exposed in settings UIs. + +Usage: + + from openhands.sdk import create_llm + from openhands.sdk.llm.providers.databricks.settings_bridge import ( + kwargs_from_settings, + ) + + kwargs = kwargs_from_settings(user, usage_id="agent") + llm = create_llm(**kwargs) + +The ``settings`` argument is deliberately duck-typed (Protocol, not a concrete +class). Any object exposing a subset of the attribute names below works: +pydantic models (``UserInfo``, ``CliSettings``, ``LLMEnvOverrides``), dataclasses, +``SimpleNamespace``, or plain ``dict``-wrappers. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import SecretStr + +from openhands.sdk.llm.providers.databricks.models import StoredU2MTokens + + +# Fields the bridge recognizes. Every public, user-settable field on +# ``DatabricksLLM`` (and the subset of base ``LLM`` fields that UIs expose) +# must appear here. Enforced by +# ``test_bridge_covers_all_databricks_llm_public_fields``. +_BRIDGE_FIELDS: tuple[str, ...] = ( + # --- Base LLM fields commonly set from UI --- + "model", + "api_key", + "base_url", + "timeout", + "max_input_tokens", + # --- Databricks-specific --- + "databricks_host", + "databricks_ai_gateway_host", + "databricks_metadata_probe", + "databricks_client_id", + "databricks_client_secret", + "databricks_profile", + "databricks_ssl_verify", + "databricks_max_retries", + "databricks_connect_timeout_s", + "databricks_read_timeout_s", + "databricks_chunk_timeout_s", + "stored_u2m_tokens", +) + +# Fields present on ``DatabricksLLM`` that are deliberately NOT bridged from +# settings — either internal pydantic discriminators or private state. +_NOT_BRIDGED: frozenset[str] = frozenset( + { + "provider", # Literal discriminator, always "databricks" + } +) + +_SECRET_FIELDS: frozenset[str] = frozenset( + { + "api_key", + "databricks_client_secret", + } +) + + +#: Useful when the settings object uses a different attribute name for a +#: bridged field, e.g. OpenHands' ``UserInfo`` uses ``llm_model`` / +#: ``llm_api_key`` / ``llm_base_url``. The bridge first tries the canonical +#: attribute name, then each alias in order. +UserInfoAliases: dict[str, tuple[str, ...]] = { + "model": ("llm_model",), + "api_key": ("llm_api_key",), + "base_url": ("llm_base_url",), +} + + +def kwargs_from_settings( + settings: Any, + *, + usage_id: str, + model_override: str | None = None, + base_url_fallback: str | None = None, + extras: dict[str, Any] | None = None, + aliases: dict[str, tuple[str, ...]] | None = None, +) -> dict[str, Any]: + """Build a kwargs dict ready for ``openhands.sdk.create_llm(**kwargs)``. + + Behavior: + + * Attributes are read via ``getattr`` — missing attrs are skipped (so the + same bridge works for partial settings objects). + * ``None`` and empty-string values are dropped so pydantic defaults apply. + * Secret fields (``api_key``, ``databricks_client_secret``) are coerced + to :class:`pydantic.SecretStr` if supplied as bare strings. + * ``stored_u2m_tokens`` accepts both :class:`StoredU2MTokens` instances + and plain dicts (validated with ``model_validate``; invalid dicts are + silently dropped). + * ``model_override`` wins over ``settings.model`` when both are set. + * ``base_url_fallback`` is only applied when *neither* ``base_url`` nor + ``databricks_host`` is present (preserves existing callers' + host-vs-base-url disambiguation). + * ``extras`` are merged last and win over everything else — use this for + per-request overrides like session U2M tokens. + * ``usage_id`` is always set; it's the one field not read from settings. + + Args: + settings: Any object exposing a subset of the bridged field names. + usage_id: Per-call usage id (``"agent"``, ``"condenser"``, ...). + model_override: Replaces ``settings.model`` when non-None. + base_url_fallback: Applied only when neither ``base_url`` nor + ``databricks_host`` is populated. + extras: Last-write-wins overrides. + aliases: Optional map of canonical field → fallback attribute names. + Tried in order after the canonical name itself. Convenient for + settings shapes like OpenHands' ``UserInfo`` that prefix fields + with ``llm_`` — pass :data:`UserInfoAliases`. + + Returns: + A dict safe to splat into :func:`openhands.sdk.create_llm`. + """ + kwargs: dict[str, Any] = {"usage_id": usage_id} + aliases = aliases or {} + + for field in _BRIDGE_FIELDS: + val = getattr(settings, field, None) + if val is None or val == "": + for alias in aliases.get(field, ()): + val = getattr(settings, alias, None) + if val not in (None, ""): + break + if val is None or val == "": + continue + if field in _SECRET_FIELDS and not isinstance(val, SecretStr): + val = SecretStr(str(val)) + if field == "stored_u2m_tokens" and isinstance(val, dict): + try: + val = StoredU2MTokens.model_validate(val) + except Exception: + continue + kwargs[field] = val + + if model_override is not None: + kwargs["model"] = model_override + + if ( + "base_url" not in kwargs + and "databricks_host" not in kwargs + and base_url_fallback + ): + kwargs["base_url"] = base_url_fallback + + if extras: + for k, v in extras.items(): + if v is None: + continue + if k in _SECRET_FIELDS and not isinstance(v, SecretStr): + v = SecretStr(str(v)) + kwargs[k] = v + + return kwargs + + +__all__ = [ + "kwargs_from_settings", + "UserInfoAliases", + "_BRIDGE_FIELDS", + "_NOT_BRIDGED", +] diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py new file mode 100644 index 0000000000..e77cacd7f4 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py @@ -0,0 +1,269 @@ +"""Databricks FMAPI resilience utilities. + +Provides: + USER_AGENT — PWAF-required constant; set once, applied to ALL Databricks HTTP calls. + fetch_with_retry — synchronous retry loop with exponential back-off + Retry-After. + Helper functions: _log_retry, _raise_non_retryable, _raise_mapped, compute_backoff, + normalize_host, map_databricks_error, validate_databricks_config. +""" + +from __future__ import annotations + +import importlib.metadata +import logging +import random +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import httpx +from litellm.exceptions import ( + APIConnectionError, + AuthenticationError, + BadRequestError, + RateLimitError, + ServiceUnavailableError, +) + +if TYPE_CHECKING: + from openhands.sdk.llm.providers.databricks.auth import AuthStrategy + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# PWAF: User-Agent constant +# Must be set once at module load time and applied to ALL Databricks HTTP calls. +# Never re-imported per request. Never user-configurable. +# --------------------------------------------------------------------------- +def _get_version() -> str: + try: + return importlib.metadata.version("openhands-sdk") + except importlib.metadata.PackageNotFoundError: + return "unknown" + + +USER_AGENT: str = f"openhands_oss/{_get_version()}" +"""PWAF User-Agent for the OpenHands OSS Databricks connector. + +Format (per Partner AI Dev Kit / PWAF telemetry skill): + _/ + +- ISV: openhands +- Product: oss (this is the OSS OpenHands product; non-OSS distributions + may ship their own build with a different product string) +- Version: resolved from the installed `openhands-sdk` package metadata. + +Applied to every Databricks HTTP call (AI Gateway `/invocations`, OAuth +token endpoint, serving-endpoints discovery). Never exposed as a user +config knob — connector-level constant per PWAF rules. +""" + + +# --------------------------------------------------------------------------- +# Timeout configuration +# --------------------------------------------------------------------------- +@dataclass +class DatabricksTimeouts: + connect_s: float = 10.0 # TCP + TLS; fail fast on unreachable host + read_s: float = 120.0 # Non-streaming: full response wait + chunk_s: float = 30.0 # Streaming: per-chunk idle timeout (resets per chunk) + pool_s: float = 5.0 # Wait for connection from pool + + +# --------------------------------------------------------------------------- +# Retry tables +# --------------------------------------------------------------------------- +RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({429, 500, 502, 503, 504}) +NON_RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({400, 401, 403, 404, 422}) +RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ( + httpx.ConnectError, + httpx.ReadTimeout, + httpx.RemoteProtocolError, +) + +STATUS_TO_LITELLM: dict[int, type[Exception]] = { + 429: RateLimitError, + 500: APIConnectionError, + 502: ServiceUnavailableError, + 503: ServiceUnavailableError, + 504: ServiceUnavailableError, + 400: BadRequestError, + 401: AuthenticationError, + 403: AuthenticationError, + 404: BadRequestError, + 422: BadRequestError, +} + +# Hard cap on Retry-After header value to prevent runaway sleeps from misbehaving proxies. +_RETRY_AFTER_MAX_S: float = 300.0 + + +# --------------------------------------------------------------------------- +# Retry helper functions (P0-2: were called but not defined in P3 plan) +# --------------------------------------------------------------------------- + +def _log_retry( + attempt: int, + max_retries: int, + status_code: int, + wait_s: float, + url: str, + response_headers: httpx.Headers, +) -> None: + """Log a retry event. Never logs credential values.""" + logger.warning( + "databricks_fmapi_retry", + extra={ + "attempt": attempt + 1, + "max_retries": max_retries, + "status_code": status_code, + "wait_s": round(wait_s, 2), + "request_id": response_headers.get("x-request-id"), + "url": url, + # Intentionally NOT logging: Authorization header, token, or secret + }, + ) + + +def _raise_non_retryable(response: httpx.Response) -> None: + """Raise the appropriate LiteLLM exception for a non-retryable status code. + + Called immediately (no sleep) for 400/401/403/404/422. + """ + try: + body = response.json() + except Exception: + body = {} + msg = map_databricks_error(response.status_code, body) + exc_class = STATUS_TO_LITELLM.get(response.status_code, BadRequestError) + raise exc_class(msg, model="", llm_provider="databricks") + + +def _raise_mapped(response: httpx.Response) -> None: + """Raise LiteLLM exception for a retryable status after all retries are exhausted.""" + try: + body = response.json() + except Exception: + body = {} + msg = map_databricks_error(response.status_code, body) + exc_class = STATUS_TO_LITELLM.get(response.status_code, APIConnectionError) + raise exc_class(msg, model="", llm_provider="databricks") + + +# --------------------------------------------------------------------------- +# Backoff and retry loop +# --------------------------------------------------------------------------- + +def compute_backoff(attempt: int, retry_after: str | None = None) -> float: + """Compute sleep duration for a retry attempt. + + Retry-After header wins but is capped at _RETRY_AFTER_MAX_S to prevent + runaway sleeps from misbehaving proxies. Falls back to full-jitter + exponential backoff: sleep in [0, min(60, 1 * 2^attempt)]. + """ + if retry_after: + return min(float(retry_after), _RETRY_AFTER_MAX_S) + return min(60.0, 1.0 * (2**attempt)) * random.uniform(0, 1) + + +def fetch_with_retry( + client: httpx.Client, + url: str, + headers: dict, + json: dict, + max_retries: int = 3, +) -> httpx.Response: + """Synchronous retry loop for FMAPI POST calls. + + Uses time.sleep (NOT asyncio.sleep) — _transport_call is always synchronous. + On exhaustion of retries, raises the mapped LiteLLM exception. + """ + last_exc: Exception | None = None + for attempt in range(max_retries + 1): + try: + response = client.post(url, headers=headers, json=json) + if response.status_code in RETRYABLE_STATUS_CODES and attempt < max_retries: + wait = compute_backoff( + attempt, response.headers.get("Retry-After") + ) + _log_retry(attempt, max_retries, response.status_code, wait, url, response.headers) + time.sleep(wait) + continue + if response.status_code in NON_RETRYABLE_STATUS_CODES: + _raise_non_retryable(response) + if response.status_code in RETRYABLE_STATUS_CODES: + # Exhausted retries on a retryable status code + _raise_mapped(response) + return response + except RETRYABLE_EXCEPTIONS as exc: + last_exc = exc + if attempt == max_retries: + raise APIConnectionError( + str(exc), model="", llm_provider="databricks" + ) from exc + time.sleep(compute_backoff(attempt)) + + # Unreachable; satisfies type checker + raise APIConnectionError( + f"Retry loop exhausted: {last_exc}", model="", llm_provider="databricks" + ) + + +# --------------------------------------------------------------------------- +# Miscellaneous helpers +# --------------------------------------------------------------------------- + +def normalize_host(host: str) -> str: + """Ensure host has https:// scheme and no trailing slash.""" + host = host.strip().rstrip("/") + if not host.startswith("https://"): + host = f"https://{host}" + return host + + +def map_databricks_error(status: int, body: dict) -> str: + """Extract human-readable error message from FMAPI error response body.""" + msg = ( + body.get("message") + or body.get("error_description") + or body.get("error") + or "Unknown error" + ) + return f"[{status}] {msg}" + + +def validate_databricks_config( + host: str | None, + strategy: "AuthStrategy", + **creds: object, +) -> None: + """Pre-flight validation — raises ValueError with actionable messages. + + Called during DatabricksLLM._init_databricks() so configuration errors + surface at construction time rather than at first inference call. + """ + if not host: + raise ValueError( + "Databricks host is required. Set databricks_host= or base_url= " + "to your workspace URL (e.g. https://adb-xxx.azuredatabricks.net)" + ) + if not host.startswith("https://"): + raise ValueError( + f"databricks_host must start with 'https://'. Got: {host!r}" + ) + + # Import AuthStrategy here to avoid circular import at module level + from openhands.sdk.llm.providers.databricks.auth import AuthStrategy as _AS + + if strategy == _AS.U2M and not creds.get("stored_tokens"): + raise ValueError( + "U2M auth requires stored OAuth tokens. Complete the browser login flow " + "at /auth/databricks/initiate first." + ) + if strategy == _AS.M2M: + if not creds.get("client_id") or not creds.get("client_secret"): + raise ValueError( + "M2M auth requires DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET " + "(service principal credentials). Note: these are DIFFERENT from " + "DATABRICKS_U2M_CLIENT_ID used for browser OAuth login." + ) diff --git a/openhands-sdk/openhands/sdk/llm/utils/model_features.py b/openhands-sdk/openhands/sdk/llm/utils/model_features.py index a0ec76a632..7df9ca0d94 100644 --- a/openhands-sdk/openhands/sdk/llm/utils/model_features.py +++ b/openhands-sdk/openhands/sdk/llm/utils/model_features.py @@ -193,6 +193,25 @@ def _supports_reasoning_effort(model: str | None) -> bool: def get_features(model: str) -> ModelFeatures: """Get model features.""" + # Databricks FMAPI models: return a fixed feature set. + # FMAPI does not support Anthropic prompt caching, extended thinking, or + # the Responses API — even for Claude-based endpoints. Standard tool calling + # (OpenAI wire format) and stop words are supported. + # This early return prevents Databricks Claude model names (which contain + # "claude-3-7-sonnet" etc.) from incorrectly matching PROMPT_CACHE_MODELS + # and other per-provider pattern lists. + if (model or "").startswith("databricks/"): + return ModelFeatures( + supports_reasoning_effort=False, + supports_extended_thinking=False, + supports_prompt_cache=False, + supports_stop_words=True, + supports_responses_api=False, + force_string_serializer=False, + send_reasoning_content=False, + supports_prompt_cache_retention=False, + ) + return ModelFeatures( supports_reasoning_effort=_supports_reasoning_effort(model), supports_extended_thinking=model_matches(model, EXTENDED_THINKING_MODELS), diff --git a/openhands-sdk/pyproject.toml b/openhands-sdk/pyproject.toml index 5b1cd987cd..139ce02e3d 100644 --- a/openhands-sdk/pyproject.toml +++ b/openhands-sdk/pyproject.toml @@ -32,6 +32,9 @@ Documentation = "https://docs.openhands.dev/sdk" [project.optional-dependencies] boto3 = ["boto3>=1.35.0"] +# Optional dependency for PROFILE and UNIFIED auth strategies only. +# U2M (browser PKCE), M2M (client credentials), and PAT have zero extra dependencies. +databricks = ["databricks-sdk>=0.20.0"] [build-system] requires = ["setuptools>=61.0", "wheel"] diff --git a/tests/sdk/llm/providers/__init__.py b/tests/sdk/llm/providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sdk/llm/providers/databricks/__init__.py b/tests/sdk/llm/providers/databricks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sdk/llm/providers/databricks/conftest.py b/tests/sdk/llm/providers/databricks/conftest.py new file mode 100644 index 0000000000..d1fd734c7e --- /dev/null +++ b/tests/sdk/llm/providers/databricks/conftest.py @@ -0,0 +1,82 @@ +"""Shared fixtures for the Databricks provider test suite. + +This conftest exists to make the suite deterministic and fast regardless of the +developer machine's local Databricks state. It does two things on every test: + +1. Scrubs ``DATABRICKS_*`` environment variables so tests cannot accidentally + pick up credentials or a host from the developer's shell. +2. Replaces ``databricks.sdk.WorkspaceClient`` with a MagicMock so any code path + that reaches PROFILE or UNIFIED auth resolution does not attempt a real + network call or OAuth browser flow (which is what caused the multi-minute + test hang when ``~/.databrickscfg`` contained a U2M profile). + +Individual tests that need to exercise the real ``WorkspaceClient`` constructor +or a specific env var can override these fixtures locally with ``monkeypatch``. +""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock + +import pytest + + +_DATABRICKS_ENV_VARS: tuple[str, ...] = ( + "DATABRICKS_HOST", + "DATABRICKS_TOKEN", + "DATABRICKS_ACCESS_TOKEN", + "DATABRICKS_CLIENT_ID", + "DATABRICKS_CLIENT_SECRET", + "DATABRICKS_U2M_CLIENT_ID", + "DATABRICKS_CONFIG_PROFILE", + "DATABRICKS_CONFIG_FILE", + "DATABRICKS_AUTH_TYPE", + "DATABRICKS_CLUSTER_ID", + "DATABRICKS_WAREHOUSE_ID", +) + + +@pytest.fixture(autouse=True) +def _scrub_databricks_env(monkeypatch: pytest.MonkeyPatch) -> None: + """Remove every DATABRICKS_* env var for the duration of a test. + + Prevents credential leakage from the developer's shell into tests and stops + ``resolve_credentials`` from falling through to UNIFIED auth, which would + construct a real ``WorkspaceClient``. + """ + for name in _DATABRICKS_ENV_VARS: + monkeypatch.delenv(name, raising=False) + + +@pytest.fixture(autouse=True) +def _mock_workspace_client(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + """Replace ``databricks.sdk.WorkspaceClient`` with a MagicMock. + + Safe no-op when ``databricks-sdk`` is not installed. When it is installed, + this prevents any accidental real constructor call (which can trigger an + OAuth browser flow or hang on token refresh) from reaching the network. + + Returns the mock class so tests that need to assert on constructor calls + can request the fixture by name. + """ + mock_cls = MagicMock(name="WorkspaceClient") + mock_instance = MagicMock(name="WorkspaceClient_instance") + mock_instance.config.authenticate.return_value = { + "Authorization": "Bearer mock-unified-token" + } + mock_cls.return_value = mock_instance + + # Patch the already-imported module if present; otherwise inject a shim so + # ``from databricks.sdk import WorkspaceClient`` inside auth.py resolves to + # our mock regardless of whether the real package is installed. + if "databricks.sdk" in sys.modules: + monkeypatch.setattr( + "databricks.sdk.WorkspaceClient", mock_cls, raising=False + ) + else: + sdk_mod = MagicMock(name="databricks.sdk") + sdk_mod.WorkspaceClient = mock_cls + monkeypatch.setitem(sys.modules, "databricks.sdk", sdk_mod) + + return mock_cls diff --git a/tests/sdk/llm/providers/databricks/test_auth.py b/tests/sdk/llm/providers/databricks/test_auth.py new file mode 100644 index 0000000000..f7c4bdba0d --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_auth.py @@ -0,0 +1,339 @@ +"""Tests for Databricks FMAPI authentication strategies. + +Covers: M2MTokenProvider (double-checked locking, proactive refresh, scope=all-apis), +PAT path, U2M priority precedence, host resolution order, resolve_credentials dispatch. +""" + +from __future__ import annotations + +import time +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from pydantic import SecretStr + +from openhands.sdk.llm.providers.databricks.auth import ( + AuthStrategy, + DatabricksCredentials, + M2MTokenProvider, + _resolve_m2m, + _resolve_u2m, + resolve_credentials, +) +from openhands.sdk.llm.providers.databricks.models import StoredU2MTokens +from openhands.sdk.llm.providers.databricks.utils import USER_AGENT + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" +_M2M_TOKEN_URL = f"{_HOST}/oidc/v1/token" + + +def _make_m2m_response(token: str = "m2m-token", expires_in: int = 3600) -> httpx.Response: + # httpx requires `request` on Response for raise_for_status() (used in production code). + req = httpx.Request("POST", _M2M_TOKEN_URL) + return httpx.Response( + 200, json={"access_token": token, "expires_in": expires_in}, request=req + ) + + +def _make_refresh_response(token: str = "new-access-token", expires_in: int = 3600) -> httpx.Response: + req = httpx.Request("POST", _M2M_TOKEN_URL) + return httpx.Response( + 200, json={"access_token": token, "expires_in": expires_in}, request=req + ) + + +# --------------------------------------------------------------------------- +# M2MTokenProvider +# --------------------------------------------------------------------------- + +class TestM2MTokenProvider: + def test_constructor_accepts_host_client_id_secret(self) -> None: + """P0-3: constructor signature must accept host, client_id, client_secret.""" + provider = M2MTokenProvider( + host=_HOST, + client_id="client-id", + client_secret="client-secret", + ) + assert provider._host == _HOST + assert provider._client_id == "client-id" + assert provider._client_secret == "client-secret" + assert provider._token is None + assert provider._expires_at == 0.0 + + def test_get_token_fetches_on_first_call(self) -> None: + """get_token() must call _fetch_new_token on first call (no cached token).""" + provider = M2MTokenProvider(_HOST, "cid", "csecret") + with patch.object(provider, "_fetch_new_token", return_value=("tok-1", time.time() + 7200)) as mock_fetch: + token = provider.get_token() + assert token == "tok-1" + mock_fetch.assert_called_once() + + def test_get_token_uses_cached_token_when_fresh(self) -> None: + """get_token() must return cached token without re-fetching when > 5min remaining.""" + provider = M2MTokenProvider(_HOST, "cid", "csecret") + provider._token = "cached-tok" + provider._expires_at = time.time() + 3600 # 1h remaining + with patch.object(provider, "_fetch_new_token") as mock_fetch: + token = provider.get_token() + assert token == "cached-tok" + mock_fetch.assert_not_called() + + def test_get_token_refreshes_when_near_expiry(self) -> None: + """get_token() must refresh when < 5min (300s) remaining (proactive refresh).""" + provider = M2MTokenProvider(_HOST, "cid", "csecret") + provider._token = "expiring-tok" + provider._expires_at = time.time() + 100 # 100s remaining < 300s threshold + with patch.object(provider, "_fetch_new_token", return_value=("fresh-tok", time.time() + 7200)): + token = provider.get_token() + assert token == "fresh-tok" + + def test_fetch_new_token_sends_scope_all_apis(self) -> None: + """_fetch_new_token must include scope=all-apis in the token request.""" + provider = M2MTokenProvider(_HOST, "test-cid", "test-secret") + captured_data: dict = {} + + def mock_post(url, data=None, headers=None, timeout=None): + captured_data.update(data or {}) + return _make_m2m_response() + + with patch("httpx.post", side_effect=mock_post): + provider._fetch_new_token() + + assert captured_data.get("scope") == "all-apis" + assert captured_data.get("grant_type") == "client_credentials" + assert captured_data.get("client_id") == "test-cid" + assert captured_data.get("client_secret") == "test-secret" + + def test_fetch_new_token_sends_user_agent(self) -> None: + """_fetch_new_token must include PWAF User-Agent header.""" + provider = M2MTokenProvider(_HOST, "cid", "secret") + captured_headers: dict = {} + + def mock_post(url, data=None, headers=None, timeout=None): + captured_headers.update(headers or {}) + return _make_m2m_response() + + with patch("httpx.post", side_effect=mock_post): + provider._fetch_new_token() + + assert captured_headers.get("User-Agent") == USER_AGENT + + def test_fetch_new_token_raises_on_http_error(self) -> None: + """_fetch_new_token must propagate HTTP errors from the token endpoint.""" + provider = M2MTokenProvider(_HOST, "cid", "secret") + req = httpx.Request("POST", _M2M_TOKEN_URL) + error_resp = httpx.Response(401, json={"message": "Unauthorized"}, request=req) + + with patch("httpx.post", return_value=error_resp): + with pytest.raises(httpx.HTTPStatusError): + provider._fetch_new_token() + + +# --------------------------------------------------------------------------- +# _resolve_u2m +# --------------------------------------------------------------------------- + +def _make_stored_tokens( + access_token: str = "u2m-access", + refresh_token: str = "u2m-refresh", + expires_at: float | None = None, + client_id: str = "u2m-cid", + host: str = _HOST, +) -> StoredU2MTokens: + return StoredU2MTokens( + access_token=access_token, + refresh_token=refresh_token, + expires_at=expires_at or (time.time() + 3600), + client_id=client_id, + host=host, + ) + + +def test_u2m_resolve_returns_credentials_with_u2m_method() -> None: + stored = _make_stored_tokens() + creds = _resolve_u2m(_HOST, stored) + assert isinstance(creds, DatabricksCredentials) + assert creds.auth_method == "u2m" + assert creds.host == _HOST + + +def test_u2m_get_token_returns_current_token_when_fresh() -> None: + """U2M: get_token() returns the stored access token without HTTP when still fresh.""" + stored = _make_stored_tokens(access_token="fresh-token", expires_at=time.time() + 3600) + creds = _resolve_u2m(_HOST, stored) + with patch("httpx.post") as mock_post: + token = creds.get_token() + assert token == "fresh-token" + mock_post.assert_not_called() + + +def test_u2m_get_token_refreshes_when_near_expiry() -> None: + """U2M: get_token() calls token endpoint when token is near expiry.""" + stored = _make_stored_tokens(access_token="old-token", expires_at=time.time() + 100) + creds = _resolve_u2m(_HOST, stored) + + def mock_post(url, data=None, headers=None, timeout=None): + return _make_refresh_response(token="refreshed-token") + + with patch("httpx.post", side_effect=mock_post): + token = creds.get_token() + + assert token == "refreshed-token" + + +def test_u2m_refresh_uses_no_client_secret() -> None: + """U2M PKCE refresh must NOT send client_secret (public client).""" + stored = _make_stored_tokens(expires_at=time.time() + 100) + creds = _resolve_u2m(_HOST, stored) + captured_data: dict = {} + + def mock_post(url, data=None, headers=None, timeout=None): + captured_data.update(data or {}) + return _make_refresh_response() + + with patch("httpx.post", side_effect=mock_post): + creds.get_token() + + assert "client_secret" not in captured_data + assert captured_data.get("grant_type") == "refresh_token" + + +def test_u2m_refresh_failure_raises_auth_error() -> None: + """U2M refresh HTTP error → AuthenticationError with re-auth guidance.""" + from litellm.exceptions import AuthenticationError + + stored = _make_stored_tokens(expires_at=time.time() + 100) + creds = _resolve_u2m(_HOST, stored) + + with patch("httpx.post", return_value=httpx.Response(401, json={"error": "invalid_grant"})): + with pytest.raises(AuthenticationError, match="Re-authenticate"): + creds.get_token() + + +# --------------------------------------------------------------------------- +# resolve_credentials — priority chain +# --------------------------------------------------------------------------- + +def _make_mock_llm( + databricks_host: str = _HOST, + api_key: SecretStr | None = None, + stored_u2m_tokens: StoredU2MTokens | None = None, + databricks_client_id: str | None = None, + databricks_client_secret: SecretStr | None = None, + databricks_profile: str | None = None, + base_url: str | None = None, +) -> MagicMock: + """Return a MagicMock shaped like DatabricksLLM for testing resolve_credentials.""" + llm = MagicMock() + llm.databricks_host = databricks_host + llm.base_url = base_url + llm.api_key = api_key + llm.stored_u2m_tokens = stored_u2m_tokens + llm.databricks_client_id = databricks_client_id + llm.databricks_client_secret = databricks_client_secret + llm.databricks_profile = databricks_profile + return llm + + +def test_resolve_credentials_u2m_wins_over_all() -> None: + """U2M stored tokens take highest priority.""" + stored = _make_stored_tokens() + llm = _make_mock_llm( + stored_u2m_tokens=stored, + api_key=SecretStr("pat-token"), + databricks_client_id="m2m-cid", + databricks_client_secret=SecretStr("m2m-secret"), + ) + creds = resolve_credentials(llm) + assert creds.auth_method == "u2m" + + +def test_resolve_credentials_pat_path() -> None: + """PAT is used when no U2M or M2M credentials are present.""" + llm = _make_mock_llm(api_key=SecretStr("dapi-test")) + creds = resolve_credentials(llm) + assert creds.auth_method == "pat" + assert creds.get_token() == "dapi-test" + + +def test_resolve_credentials_m2m_over_pat() -> None: + """M2M takes priority over PAT when both are present.""" + llm = _make_mock_llm( + api_key=SecretStr("dapi-pat"), + databricks_client_id="m2m-cid", + databricks_client_secret=SecretStr("m2m-secret"), + ) + with patch( + "openhands.sdk.llm.providers.databricks.auth.M2MTokenProvider._fetch_new_token", + return_value=("m2m-token", time.time() + 3600), + ): + creds = resolve_credentials(llm) + assert creds.auth_method == "m2m" + + +def test_resolve_credentials_pat_does_not_require_host() -> None: + """PAT auth must succeed without a workspace host (token goes to AI Gateway).""" + llm = _make_mock_llm(databricks_host=None, base_url=None, api_key=SecretStr("tok")) + creds = resolve_credentials(llm) + assert creds.auth_method == "pat" + assert creds.host == "" + + +def test_resolve_credentials_unified_requires_host() -> None: + """Unified auth (no api_key, no profile, no creds) needs the workspace host.""" + llm = _make_mock_llm(databricks_host=None, base_url=None) + with pytest.raises(ValueError, match="databricks_host is required"): + resolve_credentials(llm) + + +def test_resolve_credentials_host_from_base_url() -> None: + """Host falls back to base_url if databricks_host is not set.""" + llm = _make_mock_llm(databricks_host=None, base_url=_HOST, api_key=SecretStr("tok")) + creds = resolve_credentials(llm) + assert creds.host == _HOST + + +def test_resolve_credentials_host_from_stored_tokens() -> None: + """Host falls back to stored_u2m_tokens.host as last resort.""" + stored = _make_stored_tokens(host=_HOST) + llm = _make_mock_llm(databricks_host=None, base_url=None, stored_u2m_tokens=stored) + creds = resolve_credentials(llm) + assert creds.host == _HOST + + +def test_resolve_credentials_profile_raises_without_sdk() -> None: + """PROFILE strategy raises ImportError with install hint if databricks-sdk absent. + + The import check is deferred to get_token() so that saving settings succeeds + even without the package installed; the error surfaces at first API call. + """ + llm = _make_mock_llm(databricks_profile="my-profile") + with patch.dict("sys.modules", {"databricks": None, "databricks.sdk": None}): + # resolve_credentials succeeds (returns a DatabricksCredentials object) + creds = resolve_credentials(llm) + assert creds.auth_method == "profile" + # The ImportError surfaces when the token is actually requested + with pytest.raises((ImportError, Exception)): + creds.get_token() + + +def test_resolve_credentials_unified_raises_without_sdk() -> None: + """UNIFIED strategy raises ImportError with install hint if databricks-sdk absent. + + The import check is deferred to get_token() so that saving settings succeeds + even without the package installed; the error surfaces at first API call. + """ + llm = _make_mock_llm() # no api_key, no profile → falls through to unified + with patch.dict("sys.modules", {"databricks": None, "databricks.sdk": None}): + # resolve_credentials succeeds + creds = resolve_credentials(llm) + assert creds.auth_method == "unified" + # The ImportError surfaces when the token is actually requested + with pytest.raises((ImportError, Exception)): + creds.get_token() diff --git a/tests/sdk/llm/providers/databricks/test_client.py b/tests/sdk/llm/providers/databricks/test_client.py new file mode 100644 index 0000000000..fe20ccd957 --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_client.py @@ -0,0 +1,382 @@ +"""Tests for DatabricksFMAPIClient. + +Covers: User-Agent header presence (PWAF), _parse_response, _build_stream_response, +streaming accumulation, __del__ cleanup (P1-1), and URL construction. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from openhands.sdk.llm.providers.databricks.auth import DatabricksCredentials +from openhands.sdk.llm.providers.databricks.client import DatabricksFMAPIClient +from openhands.sdk.llm.providers.databricks.models import ProviderFamily +from openhands.sdk.llm.providers.databricks.utils import DatabricksTimeouts, USER_AGENT + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" + + +def _make_credentials(token: str = "test-token") -> DatabricksCredentials: + return DatabricksCredentials( + host=_HOST, + get_token=lambda: token, + auth_method="pat", + ) + + +_GATEWAY = _HOST # tests use the same URL for both surfaces by default + + +def _make_client( + credentials: DatabricksCredentials | None = None, + max_retries: int = 0, + metadata_probe: bool = False, + ai_gateway_host: str = _GATEWAY, +) -> DatabricksFMAPIClient: + creds = credentials or _make_credentials() + return DatabricksFMAPIClient( + credentials=creds, + ai_gateway_host=ai_gateway_host, + timeouts=DatabricksTimeouts(), + max_retries=max_retries, + metadata_probe=metadata_probe, + ) + + +def test_client_requires_some_host() -> None: + """Constructor must reject the case where neither ai_gateway_host nor + credentials.host is provided — there's no URL to route invocations to.""" + creds_no_host = DatabricksCredentials( + host="", get_token=lambda: "tok", auth_method="pat" + ) + with pytest.raises(ValueError, match="must be provided"): + DatabricksFMAPIClient( + credentials=creds_no_host, + ai_gateway_host=None, + timeouts=DatabricksTimeouts(), + ) + + +def test_client_defaults_gateway_host_to_credentials_host() -> None: + """When no ai_gateway_host override is given, the workspace host + (credentials.host) becomes the gateway base.""" + client = DatabricksFMAPIClient( + credentials=_make_credentials(), + ai_gateway_host=None, + timeouts=DatabricksTimeouts(), + ) + assert client._ai_gateway_host == _HOST + + +def _make_success_response(model: str = "test-model") -> dict: + return { + "id": "chatcmpl-test", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello!"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + "model": model, + } + + +# --------------------------------------------------------------------------- +# _make_headers — PWAF User-Agent +# --------------------------------------------------------------------------- + +def test_make_headers_includes_user_agent() -> None: + """PWAF: every request must carry the correct User-Agent header.""" + client = _make_client() + headers = client._make_headers(ProviderFamily.OPENAI) + assert headers["User-Agent"] == USER_AGENT + + +def test_make_headers_includes_authorization() -> None: + """Authorization header must use Bearer scheme with the current token.""" + client = _make_client(_make_credentials(token="dapi-abc123")) + headers = client._make_headers(ProviderFamily.OPENAI) + assert headers["Authorization"] == "Bearer dapi-abc123" + + +def test_make_headers_includes_content_type() -> None: + client = _make_client() + headers = client._make_headers(ProviderFamily.OPENAI) + assert headers["Content-Type"] == "application/json" + + +def test_make_headers_openai_family_has_no_anthropic_version() -> None: + """OPENAI / GEMINI / RESPONSES families must not set ``anthropic-version``.""" + client = _make_client() + for f in (ProviderFamily.OPENAI, ProviderFamily.GEMINI, ProviderFamily.OPENAI_RESPONSES): + headers = client._make_headers(f) + assert "anthropic-version" not in headers, ( + f"family={f} must not carry Anthropic header" + ) + + +def test_make_headers_anthropic_family_sets_anthropic_version() -> None: + """Anthropic native API requires the ``anthropic-version`` header.""" + client = _make_client() + headers = client._make_headers(ProviderFamily.ANTHROPIC) + assert "anthropic-version" in headers + assert headers["anthropic-version"], "anthropic-version must be non-empty" + + +# --------------------------------------------------------------------------- +# _parse_response (P0-1) +# --------------------------------------------------------------------------- + +def test_parse_response_maps_to_model_response() -> None: + """_parse_response should return a litellm ModelResponse from FMAPI JSON.""" + client = _make_client() + body = _make_success_response() + resp = httpx.Response(200, json=body) + + result = client._parse_response(resp, family=ProviderFamily.OPENAI, model="my-model") + + assert result.choices is not None + assert len(result.choices) > 0 + + +def test_parse_response_preserves_id() -> None: + client = _make_client() + body = _make_success_response() + body["id"] = "chatcmpl-unique-id" + resp = httpx.Response(200, json=body) + + result = client._parse_response(resp, family=ProviderFamily.OPENAI, model="m") + assert result.id == "chatcmpl-unique-id" + + +def test_parse_response_handles_malformed_body_gracefully() -> None: + """_parse_response must not crash on unexpected FMAPI shape (fallback path).""" + client = _make_client() + resp = httpx.Response(200, json={"unexpected": "field"}) + result = client._parse_response( + resp, family=ProviderFamily.OPENAI, model="fallback-model" + ) + assert result is not None + + +# --------------------------------------------------------------------------- +# _build_stream_response (P0-1) +# --------------------------------------------------------------------------- + +def test_build_stream_response_assembles_content() -> None: + """Streaming: accumulated content must appear in the single returned ModelResponse.""" + client = _make_client() + result = client._build_stream_response( + content="The answer is 42.", + response_id="stream-resp-1", + model="databricks/test-model", + ) + choices = result.choices + assert choices is not None and len(choices) > 0 + msg = choices[0].get("message") or getattr(choices[0], "message", None) + # Extract content regardless of dict or object + content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", None) + assert content == "The answer is 42." + + +def test_build_stream_response_uses_fallback_id_when_empty() -> None: + client = _make_client() + result = client._build_stream_response( + content="hi", response_id="", model="m" + ) + assert result.id is not None + assert result.id != "" + + +def test_build_stream_response_sets_finish_reason_stop() -> None: + client = _make_client() + result = client._build_stream_response(content="done", response_id="rid", model="m") + choice = result.choices[0] + finish_reason = ( + choice.get("finish_reason") if isinstance(choice, dict) + else getattr(choice, "finish_reason", None) + ) + assert finish_reason == "stop" + + +# --------------------------------------------------------------------------- +# URL construction +# --------------------------------------------------------------------------- + +def test_chat_completion_builds_correct_url_workspace_host() -> None: + """Workspace host → gateway is reverse-proxied at /ai-gateway.""" + client = _make_client(max_retries=0) + captured_url: list[str] = [] + captured_body: list[dict] = [] + + def mock_post(url, headers=None, json=None, **_kw): + captured_url.append(url) + captured_body.append(json or {}) + return httpx.Response(200, json=_make_success_response()) + + with patch.object(client._http, "post", side_effect=mock_post): + client.chat_completion( + model="databricks-meta-llama-3-3-70b-instruct", + messages=[{"role": "user", "content": "hello"}], + ) + + assert captured_url == [f"{_HOST}/ai-gateway/mlflow/v1/chat/completions"] + # The mlflow path no longer carries the endpoint in the URL — it must be + # in the request body so the gateway knows which model to route to. + assert captured_body[0].get("model") == "databricks-meta-llama-3-3-70b-instruct" + + +def test_chat_completion_uses_dedicated_gateway_host_when_set() -> None: + """Dedicated *.ai-gateway.* host is used as-is (no /ai-gateway prefix).""" + dedicated = "https://9999999999999999.ai-gateway.cloud.databricks.com" + client = _make_client(max_retries=0, ai_gateway_host=dedicated) + captured_url: list[str] = [] + + def mock_post(url, headers=None, json=None, **_kw): + captured_url.append(url) + return httpx.Response( + 200, + json={ + "id": "msg_x", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hi"}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 1, "output_tokens": 1}, + }, + ) + + with patch.object(client._http, "post", side_effect=mock_post): + client.chat_completion( + model="databricks-claude-opus-4-6", + messages=[{"role": "user", "content": "hi"}], + ) + + assert captured_url == [f"{dedicated}/anthropic/v1/messages"] + + +# --------------------------------------------------------------------------- +# resolve_family — metadata-first routing with name-pattern fallback +# --------------------------------------------------------------------------- + +def test_resolve_family_uses_metadata_when_available() -> None: + """When the describe call returns ``api_types``, that wins over the name.""" + client = _make_client(max_retries=0, metadata_probe=True) + + # Model name screams "openai chat" but metadata says anthropic. + meta_response = httpx.Response( + 200, + json={ + "config": { + "served_entities": [{ + "foundation_model": {"api_types": ["anthropic/v1/messages"]}, + }], + }, + }, + ) + with patch.object(client._http, "get", return_value=meta_response): + family = client.resolve_family("my-custom-endpoint") + + assert family is ProviderFamily.ANTHROPIC, ( + "metadata api_types must authoritatively override the name pattern" + ) + + +def test_resolve_family_falls_back_to_name_when_metadata_fails() -> None: + """If the describe call errors, we must fall back to ``detect_family``.""" + client = _make_client(max_retries=0, metadata_probe=True) + + with patch.object( + client._http, "get", + side_effect=httpx.HTTPError("boom"), + ): + family = client.resolve_family("databricks-claude-sonnet-4-5") + + assert family is ProviderFamily.ANTHROPIC, ( + "name-pattern fallback must match *claude* → ANTHROPIC" + ) + + +def test_resolve_family_caches_positive_hit() -> None: + """A successful metadata resolve should not trigger a second describe call.""" + client = _make_client(max_retries=0, metadata_probe=True) + + meta_response = httpx.Response( + 200, + json={"config": {"served_entities": [ + {"foundation_model": {"api_types": ["gemini/v1/generateContent"]}}, + ]}}, + ) + with patch.object(client._http, "get", return_value=meta_response) as mock_get: + first = client.resolve_family("databricks-gemini-x") + second = client.resolve_family("databricks-gemini-x") + + assert first is second is ProviderFamily.GEMINI + assert mock_get.call_count == 1, "second resolve must be served from cache" + + +def test_resolve_family_default_skips_metadata_probe() -> None: + """Default (metadata_probe=False) must NOT hit the workspace URL. + + The whole point of the connector is to send FM traffic to the AI Gateway + host; the workspace URL is only for auth/discovery and must be touched + 'only as required'. With the default config, resolve_family must rely on + detect_family(model) and never issue a metadata GET. + """ + client = _make_client(max_retries=0) # metadata_probe defaults to False + + with patch.object(client._http, "get") as mock_get: + family_claude = client.resolve_family("databricks-claude-opus-4-6") + family_gemini = client.resolve_family("databricks-gemini-2-5-flash") + family_gpt5 = client.resolve_family("databricks-gpt-5-4-mini") + family_chat = client.resolve_family("databricks-meta-llama-3-3-70b-instruct") + + # Name-pattern detection must give the right family for each. + assert family_claude is ProviderFamily.ANTHROPIC + assert family_gemini is ProviderFamily.GEMINI + assert family_gpt5 is ProviderFamily.OPENAI_RESPONSES + assert family_chat is ProviderFamily.OPENAI + + # Critical assertion: NO workspace metadata GET on the FM hot path. + assert mock_get.call_count == 0, ( + "Default config must not issue any GET against the workspace URL " + "(the workspace must only be hit 'as required' — not on every chat)." + ) + + +# --------------------------------------------------------------------------- +# __del__ cleanup (P1-1) +# --------------------------------------------------------------------------- + +def test_del_closes_http_client() -> None: + """__del__ must close the singleton httpx.Client to release connections.""" + client = _make_client() + with patch.object(client._http, "close") as mock_close: + client.__del__() + mock_close.assert_called_once() + + +def test_del_is_idempotent_on_exception() -> None: + """__del__ must not raise even if the client is already closed.""" + client = _make_client() + client._http.close() # close manually first + client.__del__() # should not raise + + +def test_explicit_close_works() -> None: + """close() must close the singleton httpx.Client.""" + client = _make_client() + with patch.object(client._http, "close") as mock_close: + client.close() + mock_close.assert_called_once() diff --git a/tests/sdk/llm/providers/databricks/test_discovery.py b/tests/sdk/llm/providers/databricks/test_discovery.py new file mode 100644 index 0000000000..ffadede650 --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_discovery.py @@ -0,0 +1,573 @@ +"""Tests for Databricks FMAPI model discovery. + +Covers: filter logic (endpoint_type, task, state.ready), User-Agent header (PWAF), +TTL cache (hit/miss/expiry), error handling (returns [] silently), and +list_models_from_env env-var handling. +""" + +from __future__ import annotations + +import time +from unittest.mock import patch + +import httpx +import pytest + +import openhands.sdk.llm.providers.databricks.discovery as discovery_module +from openhands.sdk.llm.providers.databricks.auth import DatabricksCredentials +from openhands.sdk.llm.providers.databricks.discovery import ( + CURATED_DATABRICKS_MODELS, + DiscoveredEndpoint, + ModelPickerEntry, + get_picker_entries, + list_chat_endpoints, + list_foundation_models, + list_models_from_env, +) +from openhands.sdk.llm.providers.databricks.models import ProviderFamily +from openhands.sdk.llm.providers.databricks.utils import USER_AGENT + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" +_DISCOVERY_URL = f"{_HOST}/api/2.0/serving-endpoints" + + +def _discovery_response(status: int, body: dict) -> httpx.Response: + """Build a Response httpx can run raise_for_status() on (needs bound request).""" + req = httpx.Request("GET", _DISCOVERY_URL) + return httpx.Response(status, json=body, request=req) + + +def _make_credentials(token: str = "test-token") -> DatabricksCredentials: + return DatabricksCredentials( + host=_HOST, + get_token=lambda: token, + auth_method="pat", + ) + + +def _make_endpoints_payload(endpoints: list[dict]) -> dict: + return {"endpoints": endpoints} + + +def _fmapi_ep(name: str, ready: bool = True) -> dict: + return { + "name": name, + "endpoint_type": "FOUNDATION_MODEL_API", + "task": "llm/v1/chat", + "state": {"ready": "READY" if ready else "NOT_READY"}, + } + + +def _non_fmapi_ep(name: str) -> dict: + return { + "name": name, + "endpoint_type": "CUSTOM_MODEL", + "task": "llm/v1/chat", + "state": {"ready": "READY"}, + } + + +def _embedding_ep(name: str) -> dict: + return { + "name": name, + "endpoint_type": "FOUNDATION_MODEL_API", + "task": "llm/v1/embeddings", # not chat + "state": {"ready": "READY"}, + } + + +@pytest.fixture(autouse=True) +def reset_discovery_cache() -> None: + """Reset module-level cache between tests to prevent interference.""" + discovery_module._CACHED_MODELS = [] + discovery_module._CACHE_EXPIRES_AT = 0.0 + yield + discovery_module._CACHED_MODELS = [] + discovery_module._CACHE_EXPIRES_AT = 0.0 + + +# --------------------------------------------------------------------------- +# list_foundation_models — filter logic +# --------------------------------------------------------------------------- + +def test_list_foundation_models_returns_fmapi_chat_endpoints() -> None: + """Only FOUNDATION_MODEL_API + llm/v1/chat + READY endpoints are returned.""" + payload = _make_endpoints_payload([ + _fmapi_ep("databricks-meta-llama-3-3-70b-instruct"), + _fmapi_ep("databricks-dbrx-instruct"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/databricks-meta-llama-3-3-70b-instruct" in models + assert "databricks/databricks-dbrx-instruct" in models + assert len(models) == 2 + + +def test_list_foundation_models_excludes_non_fmapi_endpoints() -> None: + """CUSTOM_MODEL endpoints must not be returned.""" + payload = _make_endpoints_payload([ + _fmapi_ep("llama-model"), + _non_fmapi_ep("custom-agent-endpoint"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/llama-model" in models + assert "databricks/custom-agent-endpoint" not in models + assert len(models) == 1 + + +def test_list_foundation_models_excludes_embedding_endpoints() -> None: + """llm/v1/embeddings task must not be included (only llm/v1/chat).""" + payload = _make_endpoints_payload([ + _fmapi_ep("chat-model"), + _embedding_ep("embed-model"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/chat-model" in models + assert "databricks/embed-model" not in models + + +def test_list_foundation_models_excludes_not_ready_endpoints() -> None: + """Endpoints with state.ready != 'READY' must be excluded.""" + payload = _make_endpoints_payload([ + _fmapi_ep("ready-model", ready=True), + _fmapi_ep("loading-model", ready=False), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/ready-model" in models + assert "databricks/loading-model" not in models + + +def test_list_foundation_models_empty_workspace() -> None: + """Empty endpoints list returns empty list.""" + with patch("httpx.get", return_value=_discovery_response(200, {"endpoints": []})): + models = list_foundation_models(_make_credentials()) + assert models == [] + + +# --------------------------------------------------------------------------- +# PWAF: User-Agent header on discovery calls +# --------------------------------------------------------------------------- + +def test_list_foundation_models_sends_user_agent() -> None: + """PWAF: User-Agent header must be present on the discovery GET request.""" + captured_headers: dict = {} + + def mock_get(url, headers=None, timeout=None): + captured_headers.update(headers or {}) + return _discovery_response(200, {"endpoints": []}) + + with patch("httpx.get", side_effect=mock_get): + list_foundation_models(_make_credentials()) + + assert captured_headers.get("User-Agent") == USER_AGENT + + +def test_list_foundation_models_sends_authorization() -> None: + """Authorization header must be present with Bearer scheme.""" + captured_headers: dict = {} + + def mock_get(url, headers=None, timeout=None): + captured_headers.update(headers or {}) + return _discovery_response(200, {"endpoints": []}) + + with patch("httpx.get", side_effect=mock_get): + list_foundation_models(_make_credentials(token="test-tok")) + + assert captured_headers.get("Authorization") == "Bearer test-tok" + + +# --------------------------------------------------------------------------- +# list_models_from_env — TTL cache +# --------------------------------------------------------------------------- + +def test_list_models_from_env_returns_empty_without_env_vars(monkeypatch) -> None: + """Returns [] when DATABRICKS_HOST or DATABRICKS_TOKEN are not set.""" + monkeypatch.delenv("DATABRICKS_HOST", raising=False) + monkeypatch.delenv("DATABRICKS_TOKEN", raising=False) + monkeypatch.delenv("DATABRICKS_ACCESS_TOKEN", raising=False) + + models = list_models_from_env() + assert models == [] + + +def test_list_models_from_env_uses_cache_on_second_call(monkeypatch) -> None: + """Second call within TTL must use cache, not make a new HTTP request.""" + monkeypatch.setenv("DATABRICKS_HOST", _HOST) + monkeypatch.setenv("DATABRICKS_TOKEN", "test-tok") + + with patch( + "httpx.get", + return_value=_discovery_response(200, {"endpoints": [_fmapi_ep("m1")]}), + ) as mock_get: + first = list_models_from_env() + second = list_models_from_env() + + # Should only call httpx.get once; second call uses cache + assert mock_get.call_count == 1 + assert first == second + assert "databricks/m1" in first + + +def test_list_models_from_env_refreshes_after_ttl_expiry(monkeypatch) -> None: + """After TTL expires, a new HTTP call should be made.""" + monkeypatch.setenv("DATABRICKS_HOST", _HOST) + monkeypatch.setenv("DATABRICKS_TOKEN", "tok") + + # Pre-populate cache with an expired timestamp + discovery_module._CACHED_MODELS = ["databricks/old-model"] + discovery_module._CACHE_EXPIRES_AT = time.time() - 1 # already expired + + with patch( + "httpx.get", + return_value=_discovery_response(200, {"endpoints": [_fmapi_ep("new-model")]}), + ): + models = list_models_from_env() + + assert "databricks/new-model" in models + assert "databricks/old-model" not in models + + +def test_list_models_from_env_returns_empty_on_http_error(monkeypatch) -> None: + """HTTP errors must be swallowed and return [] (never raise).""" + monkeypatch.setenv("DATABRICKS_HOST", _HOST) + monkeypatch.setenv("DATABRICKS_TOKEN", "tok") + + with patch("httpx.get", side_effect=httpx.ConnectError("connection refused")): + models = list_models_from_env() + + assert models == [] + + +def test_list_models_from_env_model_names_prefixed(monkeypatch) -> None: + """All returned model names must be prefixed with 'databricks/'.""" + monkeypatch.setenv("DATABRICKS_HOST", _HOST) + monkeypatch.setenv("DATABRICKS_TOKEN", "tok") + + endpoints = [ + _fmapi_ep("databricks-meta-llama-3-3-70b-instruct"), + _fmapi_ep("databricks-dbrx-instruct"), + ] + with patch( + "httpx.get", + return_value=_discovery_response(200, {"endpoints": endpoints}), + ): + models = list_models_from_env() + + for m in models: + assert m.startswith("databricks/"), f"Model {m!r} missing 'databricks/' prefix" + + +# --------------------------------------------------------------------------- +# External-model inclusion (AI Gateway parity with FM endpoints) +# --------------------------------------------------------------------------- + +def _external_ep(name: str, ready: bool = True) -> dict: + """EXTERNAL_MODEL chat endpoint — e.g. customer-configured gpt-5 / gemini proxy.""" + return { + "name": name, + "endpoint_type": "EXTERNAL_MODEL", + "task": "llm/v1/chat", + "state": {"ready": "READY" if ready else "NOT_READY"}, + } + + +def test_list_foundation_models_includes_external_model_endpoints() -> None: + """EXTERNAL_MODEL chat endpoints must be returned alongside FOUNDATION_MODEL_API. + + External-model endpoints (e.g. customer-configured gpt-5 / gemini proxies) + are AI-Gateway-shaped and routed through the same native-API paths. + """ + payload = _make_endpoints_payload([ + _fmapi_ep("databricks-llama-3-3-70b"), + _external_ep("my-gpt5-proxy"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/databricks-llama-3-3-70b" in models + assert "databricks/my-gpt5-proxy" in models, ( + "EXTERNAL_MODEL endpoints must be discoverable via list_foundation_models" + ) + assert len(models) == 2 + + +def test_list_foundation_models_excludes_custom_model_endpoints() -> None: + """CUSTOM_MODEL endpoints are still excluded — payload shape not guaranteed.""" + payload = _make_endpoints_payload([ + _fmapi_ep("fm-endpoint"), + _external_ep("external-endpoint"), + _non_fmapi_ep("custom-agent"), # endpoint_type=CUSTOM_MODEL + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + models = list_foundation_models(_make_credentials()) + + assert "databricks/fm-endpoint" in models + assert "databricks/external-endpoint" in models + assert "databricks/custom-agent" not in models + + +# --------------------------------------------------------------------------- +# list_chat_endpoints — structured output +# --------------------------------------------------------------------------- + +def test_list_chat_endpoints_returns_dataclass_records() -> None: + """Structured API returns DiscoveredEndpoint records with metadata intact.""" + payload = _make_endpoints_payload([ + { + "name": "fm-llama", + "endpoint_type": "FOUNDATION_MODEL_API", + "task": "llm/v1/chat", + "state": {"ready": "READY"}, + "creator": "alice@example.com", + }, + { + "name": "ext-gpt5", + "endpoint_type": "EXTERNAL_MODEL", + "task": "llm/v1/chat", + "state": {"ready": "READY"}, + "creator": "bob@example.com", + }, + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + eps = list_chat_endpoints(_make_credentials()) + + assert len(eps) == 2 + assert all(isinstance(e, DiscoveredEndpoint) for e in eps) + + by_name = {e.name: e for e in eps} + assert by_name["fm-llama"].endpoint_type == "FOUNDATION_MODEL_API" + assert by_name["fm-llama"].qualified_name == "databricks/fm-llama" + assert by_name["fm-llama"].ready is True + assert by_name["fm-llama"].creator == "alice@example.com" + + assert by_name["ext-gpt5"].endpoint_type == "EXTERNAL_MODEL" + assert by_name["ext-gpt5"].qualified_name == "databricks/ext-gpt5" + + +def test_list_chat_endpoints_include_not_ready_opt_in() -> None: + """Not-ready endpoints are excluded by default, included on opt-in.""" + payload = _make_endpoints_payload([ + _fmapi_ep("ready-one", ready=True), + _fmapi_ep("loading-one", ready=False), + ]) + + with patch("httpx.get", return_value=_discovery_response(200, payload)): + default = list_chat_endpoints(_make_credentials()) + assert [e.name for e in default] == ["ready-one"] + + with patch("httpx.get", return_value=_discovery_response(200, payload)): + with_loading = list_chat_endpoints(_make_credentials(), include_not_ready=True) + names = sorted(e.name for e in with_loading) + assert names == ["loading-one", "ready-one"] + loading = next(e for e in with_loading if e.name == "loading-one") + assert loading.ready is False + + +def test_list_chat_endpoints_skips_unnamed_rows() -> None: + """Endpoints missing a ``name`` field must be silently skipped.""" + payload = _make_endpoints_payload([ + { + "endpoint_type": "FOUNDATION_MODEL_API", + "task": "llm/v1/chat", + "state": {"ready": "READY"}, + }, + _fmapi_ep("good-one"), + ]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + eps = list_chat_endpoints(_make_credentials()) + assert [e.name for e in eps] == ["good-one"] + + +# --------------------------------------------------------------------------- +# Two-tier picker: CURATED_DATABRICKS_MODELS + get_picker_entries +# --------------------------------------------------------------------------- + + +def test_curated_list_is_claude_gpt_gemini_only() -> None: + """Curated tier-1 set covers only the three native-API families we target. + + Llama / DBRX / legacy endpoints must *not* appear in the curated list — + they surface automatically via discovery only if the workspace has them. + """ + families = {e.family for e in CURATED_DATABRICKS_MODELS} + assert families == { + ProviderFamily.ANTHROPIC, + ProviderFamily.OPENAI, + ProviderFamily.OPENAI_RESPONSES, + ProviderFamily.GEMINI, + } + + names = [e.name for e in CURATED_DATABRICKS_MODELS] + assert all(n.startswith("databricks-") for n in names) + forbidden = ("llama", "dbrx", "mixtral", "qwen", "deepseek") + for n in names: + for token in forbidden: + assert token not in n.lower(), ( + f"{n!r} leaked a non-curated family — curated tier is Claude/GPT/Gemini only" + ) + + +def test_curated_list_has_one_recommended_per_family() -> None: + """Exactly one ``recommended`` entry per family — the fast-and-good default.""" + recs_by_family: dict[ProviderFamily, list[str]] = {} + for e in CURATED_DATABRICKS_MODELS: + if e.recommended: + recs_by_family.setdefault(e.family, []).append(e.name) + + # GPT-5 (Responses) gets the recommended OpenAI slot because it's the + # gold-path OpenAI native API on Databricks. Plain OPENAI (gpt-oss) is + # listed but not recommended. + assert set(recs_by_family) == { + ProviderFamily.ANTHROPIC, + ProviderFamily.OPENAI_RESPONSES, + ProviderFamily.GEMINI, + } + for family, picks in recs_by_family.items(): + assert len(picks) == 1, f"{family} has >1 recommended pick: {picks}" + + +def test_curated_entries_have_qualified_name_prefix() -> None: + """Every curated entry must use the ``databricks/`` prefix — that's what + ``create_llm`` sees and routes on.""" + for e in CURATED_DATABRICKS_MODELS: + assert e.qualified_name == f"databricks/{e.name}" + assert e.source == "curated" + assert e.ready is True + assert e.endpoint_type == "FOUNDATION_MODEL_API" + + +def test_get_picker_entries_returns_curated_without_credentials() -> None: + """No creds → pure curated tier, no HTTP call.""" + with patch("httpx.get") as mock_get: + entries = get_picker_entries(credentials=None) + mock_get.assert_not_called() + + assert len(entries) == len(CURATED_DATABRICKS_MODELS) + assert all(isinstance(e, ModelPickerEntry) for e in entries) + assert all(e.source == "curated" for e in entries) + + +def test_get_picker_entries_sort_order_recommended_first() -> None: + """Recommended picks come first; then family (alpha), then name.""" + entries = get_picker_entries(credentials=None) + + # 1. All recommended come before any non-recommended. + first_non_rec = next( + (i for i, e in enumerate(entries) if not e.recommended), len(entries) + ) + rec_section = entries[:first_non_rec] + rest_section = entries[first_non_rec:] + assert all(e.recommended for e in rec_section) + assert all(not e.recommended for e in rest_section) + + # 2. Within each section, family alpha-sorted. + def _family_order(section: list[ModelPickerEntry]) -> list[str]: + return [e.family.value for e in section] + + for section in (rec_section, rest_section): + fams = _family_order(section) + assert fams == sorted(fams), f"Family ordering broken in section: {fams}" + + +def test_get_picker_entries_merges_discovered_on_top_of_curated() -> None: + """Discovery adds non-curated endpoints; curated entries get live signals.""" + # One endpoint overlaps curated (claude-sonnet-4-5), two are new. + payload = _make_endpoints_payload([ + _fmapi_ep("databricks-claude-sonnet-4-5"), # overlaps curated + _fmapi_ep("databricks-meta-llama-4-maverick"), # discovered only + _external_ep("customer-private-gpt"), # external-model discovered only + ]) + + with patch("httpx.get", return_value=_discovery_response(200, payload)): + entries = get_picker_entries(credentials=_make_credentials()) + + by_qn = {e.qualified_name: e for e in entries} + + # Overlap: curated entry kept recommended + its opinionated family, but + # gained "curated+discovered" source and the live endpoint_type/ready. + overlap = by_qn["databricks/databricks-claude-sonnet-4-5"] + assert overlap.source == "curated+discovered" + assert overlap.family is ProviderFamily.ANTHROPIC + assert overlap.recommended is True + assert overlap.endpoint_type == "FOUNDATION_MODEL_API" + assert overlap.ready is True + + # Discovered-only FMAPI entry — family inferred via detect_family. + llama = by_qn["databricks/databricks-meta-llama-4-maverick"] + assert llama.source == "discovered" + assert llama.family is ProviderFamily.OPENAI # llama → OpenAI Chat default + assert llama.recommended is False + assert llama.endpoint_type == "FOUNDATION_MODEL_API" + + # External-model endpoint shows up as "discovered" with its live type. + ext = by_qn["databricks/customer-private-gpt"] + assert ext.source == "discovered" + assert ext.endpoint_type == "EXTERNAL_MODEL" + + # Count: curated ∪ discovered, deduped by qualified_name. + expected = {e.qualified_name for e in CURATED_DATABRICKS_MODELS} | { + "databricks/databricks-meta-llama-4-maverick", + "databricks/customer-private-gpt", + } + assert {e.qualified_name for e in entries} == expected + + +def test_get_picker_entries_swallows_discovery_errors() -> None: + """If discovery blows up, curated tier is still returned intact.""" + with patch("httpx.get", side_effect=RuntimeError("workspace down")): + entries = get_picker_entries(credentials=_make_credentials()) + + # Curated list is returned as-is (no "curated+discovered" upgrades). + assert {e.qualified_name for e in entries} == { + c.qualified_name for c in CURATED_DATABRICKS_MODELS + } + assert all(e.source == "curated" for e in entries) + + +def test_get_picker_entries_include_curated_false_returns_only_discovered() -> None: + """Opt out of curated to get a pure live list — useful for admin UIs.""" + payload = _make_endpoints_payload([_fmapi_ep("live-only-endpoint")]) + with patch("httpx.get", return_value=_discovery_response(200, payload)): + entries = get_picker_entries( + credentials=_make_credentials(), + include_curated=False, + ) + assert [e.qualified_name for e in entries] == ["databricks/live-only-endpoint"] + assert entries[0].source == "discovered" + + +def test_get_picker_entries_include_discovered_false_skips_http() -> None: + """Opt out of discovery even with creds — no HTTP fired.""" + with patch("httpx.get") as mock_get: + entries = get_picker_entries( + credentials=_make_credentials(), + include_discovered=False, + ) + mock_get.assert_not_called() + assert len(entries) == len(CURATED_DATABRICKS_MODELS) + + +def test_get_picker_entries_user_agent_propagates_through_discovery() -> None: + """PWAF: every Databricks HTTP call (including discovery triggered by the + picker) must carry the ``openhands_oss/`` User-Agent.""" + payload = _make_endpoints_payload([_fmapi_ep("anything")]) + with patch( + "httpx.get", return_value=_discovery_response(200, payload) + ) as mock_get: + get_picker_entries(credentials=_make_credentials()) + + assert mock_get.called + _, kwargs = mock_get.call_args + assert kwargs["headers"]["User-Agent"] == USER_AGENT diff --git a/tests/sdk/llm/providers/databricks/test_llm.py b/tests/sdk/llm/providers/databricks/test_llm.py new file mode 100644 index 0000000000..2e55b79b21 --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_llm.py @@ -0,0 +1,441 @@ +"""Tests for DatabricksLLM and the create_llm factory. + +Covers: PAT construction, provider discriminator (P0-6), context window lookup, +max_output_tokens lookup, unknown model fallback, _transport_call prefix stripping, +Pydantic round-trip serialization, and create_llm factory routing. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import SecretStr, ValidationError + +from openhands.sdk.llm.providers.databricks.llm import ( + DATABRICKS_CONTEXT_WINDOWS, + DATABRICKS_MAX_OUTPUT, + DatabricksLLM, +) +from openhands.sdk.llm.providers.databricks.models import ( + ProviderFamily, + StoredU2MTokens, +) + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" +_MODEL_LLAMA = "databricks/databricks-meta-llama-3-3-70b-instruct" +_MODEL_DBRX = "databricks/databricks-dbrx-instruct" +_MODEL_CLAUDE = "databricks/databricks-claude-3-7-sonnet" +_MODEL_UNKNOWN = "databricks/my-custom-finetuned-model" + + +# --------------------------------------------------------------------------- +# Helper: minimal PAT-auth DatabricksLLM (no HTTP calls during construction) +# --------------------------------------------------------------------------- + +def _make_llm( + model: str = _MODEL_LLAMA, + host: str = _HOST, + token: str = "dapi-test", + ai_gateway_host: str | None = None, + **kwargs, +) -> DatabricksLLM: + """Build a DatabricksLLM with PAT auth for tests. + + The new architecture only requires *one* of databricks_host or + databricks_ai_gateway_host. We always set databricks_host (the canonical + workspace URL) and forward ai_gateway_host only when explicitly given, + so tests reflect the typical single-URL deployment by default. + """ + return DatabricksLLM( + model=model, + databricks_host=host, + databricks_ai_gateway_host=ai_gateway_host, + api_key=SecretStr(token), + usage_id="test-databricks-llm", + **kwargs, + ) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + +def test_construction_succeeds_with_pat_auth() -> None: + """DatabricksLLM must construct without HTTP calls when using PAT.""" + llm = _make_llm() + assert llm.model == _MODEL_LLAMA + assert llm.databricks_host == _HOST + + +def test_construction_exposes_db_client() -> None: + """_db_client private attr must be set after construction.""" + llm = _make_llm() + from openhands.sdk.llm.providers.databricks.client import DatabricksFMAPIClient + assert isinstance(llm._db_client, DatabricksFMAPIClient) + + +def test_construction_exposes_db_credentials() -> None: + """_db_credentials private attr must be set after construction.""" + from openhands.sdk.llm.providers.databricks.auth import DatabricksCredentials + llm = _make_llm() + assert isinstance(llm._db_credentials, DatabricksCredentials) + assert llm._db_credentials.auth_method == "pat" + + +def test_construction_rejects_non_https_host() -> None: + """databricks_host without https:// must raise ValueError at construction.""" + with pytest.raises(ValueError, match="https://"): + DatabricksLLM( + model=_MODEL_LLAMA, + databricks_host="http://adb-123.azuredatabricks.net", + databricks_ai_gateway_host=_HOST, + api_key=SecretStr("tok"), + usage_id="t", + ) + + +def test_construction_rejects_no_host_at_all() -> None: + """At least one of databricks_host / databricks_ai_gateway_host is required.""" + with pytest.raises(ValueError, match="databricks_host is required"): + DatabricksLLM( + model=_MODEL_LLAMA, + api_key=SecretStr("tok"), + usage_id="t", + ) + + +def test_construction_pat_with_only_workspace_host() -> None: + """PAT auth on a single-URL deployment: workspace URL is enough; the SDK + derives the AI Gateway base from it (/ai-gateway/).""" + llm = DatabricksLLM( + model=_MODEL_LLAMA, + databricks_host=_HOST, + api_key=SecretStr("tok"), + usage_id="t", + ) + assert llm.databricks_host == _HOST + assert llm.databricks_ai_gateway_host is None + assert llm._db_credentials.auth_method == "pat" + # Internally the FM client must end up routing through the workspace host. + assert llm._db_client._ai_gateway_host == _HOST + + +def test_construction_pat_with_dedicated_gateway_host_only() -> None: + """PAT auth on a dedicated gateway: the gateway host is sufficient.""" + dedicated = "https://9999999999999999.ai-gateway.cloud.databricks.com" + llm = DatabricksLLM( + model=_MODEL_LLAMA, + databricks_ai_gateway_host=dedicated, + api_key=SecretStr("tok"), + usage_id="t", + ) + assert llm.databricks_host is None + assert llm.databricks_ai_gateway_host == dedicated + assert llm._db_client._ai_gateway_host == dedicated + + +# --------------------------------------------------------------------------- +# P0-6: provider discriminator +# --------------------------------------------------------------------------- + +def test_provider_field_is_databricks() -> None: + """provider field must be exactly 'databricks' for Pydantic discriminator.""" + llm = _make_llm() + assert llm.provider == "databricks" + + +def test_provider_field_is_literal() -> None: + """provider must not be overridable by user-supplied kwargs.""" + # Attempt to construct with a different provider value should either be + # silently overridden to "databricks" or raise — it must never produce a + # DatabricksLLM with provider != "databricks". + llm = _make_llm() + assert llm.provider == "databricks" + + +# --------------------------------------------------------------------------- +# Context windows +# --------------------------------------------------------------------------- + +def test_context_window_llama_70b() -> None: + llm = _make_llm(model=_MODEL_LLAMA) + assert llm.max_input_tokens == DATABRICKS_CONTEXT_WINDOWS[_MODEL_LLAMA] + + +def test_context_window_dbrx() -> None: + llm = _make_llm(model=_MODEL_DBRX) + assert llm.max_input_tokens == DATABRICKS_CONTEXT_WINDOWS[_MODEL_DBRX] + assert llm.max_input_tokens == 32_768 + + +def test_context_window_claude() -> None: + """Claude-based Databricks models have 200K context window.""" + llm = _make_llm(model=_MODEL_CLAUDE) + assert llm.max_input_tokens == 200_000 + + +def test_context_window_unknown_model_fallback() -> None: + """Unknown Databricks models fall back to 128K context window.""" + llm = _make_llm(model=_MODEL_UNKNOWN) + assert llm.max_input_tokens == 128_000 + + +def test_max_output_tokens_dbrx() -> None: + llm = _make_llm(model=_MODEL_DBRX) + assert llm.max_output_tokens == DATABRICKS_MAX_OUTPUT[_MODEL_DBRX] + assert llm.max_output_tokens == 4_096 + + +def test_max_output_tokens_claude() -> None: + llm = _make_llm(model=_MODEL_CLAUDE) + assert llm.max_output_tokens == 8_192 + + +def test_max_output_tokens_unknown_model_fallback() -> None: + """Unknown models fall back to 16K max output.""" + llm = _make_llm(model=_MODEL_UNKNOWN) + assert llm.max_output_tokens == 16_384 + + +# --------------------------------------------------------------------------- +# _transport_call — prefix stripping +# --------------------------------------------------------------------------- + +def test_transport_call_strips_databricks_prefix() -> None: + """_transport_call must strip 'databricks/' prefix before calling FMAPI.""" + llm = _make_llm() + captured_model: list[str] = [] + + def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs): + captured_model.append(model) + return MagicMock() + + with patch.object(llm._db_client, "chat_completion", side_effect=mock_chat_completion): + llm._transport_call(messages=[{"role": "user", "content": "hi"}]) + + assert len(captured_model) == 1 + # Should be bare endpoint name, not prefixed + assert not captured_model[0].startswith("databricks/") + assert captured_model[0] == "databricks-meta-llama-3-3-70b-instruct" + + +def test_transport_call_passes_through_streaming_flag() -> None: + llm = _make_llm() + + def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs): + return MagicMock() + + with patch.object(llm._db_client, "chat_completion", side_effect=mock_chat_completion) as mock_cc: + llm._transport_call(messages=[], enable_streaming=True) + + call_kwargs = mock_cc.call_args + assert call_kwargs.kwargs.get("stream") is True + + +# --------------------------------------------------------------------------- +# Resilience knob passthrough +# --------------------------------------------------------------------------- + +def test_custom_timeouts_propagate_to_client() -> None: + """Resilience knobs must be forwarded to DatabricksFMAPIClient.""" + llm = _make_llm( + databricks_connect_timeout_s=5.0, + databricks_read_timeout_s=60.0, + databricks_max_retries=5, + ) + assert llm._db_client._timeouts.connect_s == 5.0 + assert llm._db_client._timeouts.read_s == 60.0 + assert llm._db_client._max_retries == 5 + + +# --------------------------------------------------------------------------- +# Pydantic round-trip serialization (P0-6) +# --------------------------------------------------------------------------- + +def test_pydantic_roundtrip_preserves_provider_field() -> None: + """model_dump / model_validate round-trip must preserve provider='databricks'.""" + llm = _make_llm() + data = llm.model_dump() + assert data["provider"] == "databricks" + + +def test_pydantic_json_roundtrip_preserves_provider_field() -> None: + """JSON serialization must include provider field for deserialization dispatch.""" + llm = _make_llm() + json_str = llm.model_dump_json() + assert '"provider":"databricks"' in json_str or '"provider": "databricks"' in json_str + + +# --------------------------------------------------------------------------- +# create_llm factory +# --------------------------------------------------------------------------- + +def test_create_llm_routes_databricks_prefix() -> None: + """create_llm must return DatabricksLLM for 'databricks/' prefixed models.""" + from openhands.sdk import create_llm + + llm = create_llm( + model=_MODEL_LLAMA, + databricks_host=_HOST, + api_key=SecretStr("dapi-tok"), + usage_id="factory-test", + ) + assert isinstance(llm, DatabricksLLM) + + +def test_create_llm_routes_non_databricks_to_base_llm() -> None: + """create_llm must return the base LLM for non-Databricks models.""" + from openhands.sdk import create_llm + from openhands.sdk.llm import LLM + + llm = create_llm(model="claude-sonnet-4-20250514", usage_id="base-test") + assert type(llm) is LLM + assert not isinstance(llm, DatabricksLLM) + + +def test_create_llm_empty_model_raises_validation_error() -> None: + """create_llm('') delegates to base LLM, which rejects an empty model name.""" + from openhands.sdk import create_llm + + with pytest.raises(ValidationError, match="model must be specified"): + create_llm(model="", usage_id="empty-test") + + +# --------------------------------------------------------------------------- +# PWAF surfaces: auth_method / predicted_family / resolve_family +# --------------------------------------------------------------------------- + +def test_auth_method_property_reflects_pat_construction() -> None: + """auth_method must mirror the strategy resolved at construction time.""" + llm = _make_llm() + assert llm.auth_method == "pat" + + +@pytest.mark.parametrize( + "model,expected", + [ + ("databricks/databricks-meta-llama-3-3-70b-instruct", ProviderFamily.OPENAI), + ("databricks/databricks-dbrx-instruct", ProviderFamily.OPENAI), + ("databricks/databricks-gpt-oss-120b", ProviderFamily.OPENAI), + ("databricks/databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC), + ("databricks/databricks-claude-opus-4-6", ProviderFamily.ANTHROPIC), + ("databricks/databricks-gemini-2-5-flash", ProviderFamily.GEMINI), + ("databricks/databricks-gpt-5-4", ProviderFamily.OPENAI_RESPONSES), + ("databricks/databricks-gpt-5-4-mini", ProviderFamily.OPENAI_RESPONSES), + ], +) +def test_predicted_family_no_http_call(model: str, expected: ProviderFamily) -> None: + """predicted_family must be pure-compute (no metadata probe).""" + llm = _make_llm(model=model) + # If this does HTTP, it will fail because adb-123 is unreachable from tests. + # detect_family is sync + pure, so this must succeed without mocking. + assert llm.predicted_family is expected + + +def test_resolve_family_delegates_to_client_and_caches() -> None: + """resolve_family delegates to DatabricksFMAPIClient.resolve_family and caches. + + Metadata probing is opt-in (`databricks_metadata_probe=True`); without + that flag `resolve_family` is name-pattern only and never hits the wire. + """ + llm = _make_llm( + model="databricks/databricks-claude-sonnet-4-5", + databricks_metadata_probe=True, + ) + + import httpx + meta_response = httpx.Response(200, json={ + "config": {"served_entities": [ + {"foundation_model": {"api_types": ["anthropic/v1/messages"]}}, + ]}, + }) + with patch.object(llm._db_client._http, "get", return_value=meta_response) as mg: + f1 = llm.resolve_family() + f2 = llm.resolve_family() + assert f1 is f2 is ProviderFamily.ANTHROPIC + assert mg.call_count == 1, "second resolve must be cache-served" + + +def test_resolve_family_default_skips_metadata_probe() -> None: + """Default resolve_family is name-pattern only; no workspace GET.""" + llm = _make_llm(model="databricks/databricks-claude-sonnet-4-5") + + with patch.object(llm._db_client._http, "get") as mg: + family = llm.resolve_family() + assert family is ProviderFamily.ANTHROPIC + assert mg.call_count == 0, "default path must not hit workspace metadata" + + +def test_transport_call_logs_family(caplog) -> None: + """_transport_call must emit a debug log with predicted_family + auth_method.""" + import logging + + llm = _make_llm(model="databricks/databricks-claude-sonnet-4-5") + + def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs): + return MagicMock() + + with caplog.at_level(logging.DEBUG, logger="openhands.sdk.llm.providers.databricks.llm"): + with patch.object(llm._db_client, "chat_completion", side_effect=mock_chat_completion): + llm._transport_call(messages=[{"role": "user", "content": "hi"}]) + + log_records = [r for r in caplog.records if r.message == "databricks_transport_call"] + assert log_records, "expected a databricks_transport_call debug record" + record = log_records[-1] + assert record.__dict__.get("predicted_family") == "anthropic" + assert record.__dict__.get("auth_method") == "pat" + + +# --------------------------------------------------------------------------- +# Expanded model capability tables (current-gen FM / external models) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "model,expected_ctx", + [ + # Anthropic family — 200K across Claude generations on Databricks + ("databricks/databricks-claude-sonnet-4-5", 200_000), + ("databricks/databricks-claude-opus-4-6", 200_000), + ("databricks/databricks-claude-haiku-4-5", 200_000), + # Gemini — 1M token context + ("databricks/databricks-gemini-2-5-flash", 1_048_576), + ("databricks/databricks-gemini-2-5-pro", 1_048_576), + # GPT-5 Responses family — 400K context + ("databricks/databricks-gpt-5-4", 400_000), + ("databricks/databricks-gpt-5-4-mini", 400_000), + # gpt-oss — 128K + ("databricks/databricks-gpt-oss-120b", 128_000), + ], +) +def test_current_gen_context_windows(model: str, expected_ctx: int) -> None: + """Current-generation models must have correct context windows in the table.""" + llm = _make_llm(model=model) + assert llm.max_input_tokens == expected_ctx, ( + f"{model}: expected ctx={expected_ctx}, got {llm.max_input_tokens}" + ) + + +@pytest.mark.parametrize( + "model,min_output", + [ + ("databricks/databricks-claude-sonnet-4-5", 64_000), + ("databricks/databricks-gemini-2-5-flash", 16_384), # reasoning budget + ("databricks/databricks-gpt-5-4", 16_384), + ("databricks/databricks-gpt-oss-120b", 16_384), + ], +) +def test_reasoning_models_have_generous_output_budget( + model: str, min_output: int, +) -> None: + """Reasoning models' max_output_tokens must leave room for thinking + output.""" + llm = _make_llm(model=model) + assert llm.max_output_tokens >= min_output, ( + f"{model}: max_output_tokens too tight for reasoning — " + f"got {llm.max_output_tokens}, need >= {min_output}" + ) diff --git a/tests/sdk/llm/providers/databricks/test_models.py b/tests/sdk/llm/providers/databricks/test_models.py new file mode 100644 index 0000000000..d3a76905d0 --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_models.py @@ -0,0 +1,286 @@ +"""Tests for Databricks AI Gateway routing primitives. + +Covers: + * ``ProviderFamily`` enum shape + * ``detect_family`` — name-pattern routing (fast path, no HTTP) + * ``pick_family_from_api_types`` — metadata routing (authoritative) + * ``AIGatewayPaths.url`` — URL construction per family + +These primitives are the hot path of the whole connector; they must never +regress. The live E2E tests already exercise them end-to-end; this file locks +in the contract at the unit level so refactors can't silently flip routing. +""" + +from __future__ import annotations + +import pytest + +from openhands.sdk.llm.providers.databricks.models import ( + AIGatewayPaths, + ProviderFamily, + detect_family, + pick_family_from_api_types, +) + + +# --------------------------------------------------------------------------- +# ProviderFamily enum shape +# --------------------------------------------------------------------------- + +def test_provider_family_enum_values() -> None: + """Exactly four families must be exposed (OpenAI Chat / Responses, Anthropic, Gemini).""" + assert {f.value for f in ProviderFamily} == { + "openai", "openai_responses", "anthropic", "gemini", + } + + +def test_provider_family_openai_is_default_fallback() -> None: + """``OPENAI`` is the universal fallback — must be present and import-stable.""" + assert ProviderFamily.OPENAI.value == "openai" + + +# --------------------------------------------------------------------------- +# detect_family — name-pattern routing +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "model,expected", + [ + # Anthropic — substring match + ("databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC), + ("databricks/databricks-claude-opus-4", ProviderFamily.ANTHROPIC), + ("claude-haiku", ProviderFamily.ANTHROPIC), + ("my-custom-claude-proxy", ProviderFamily.ANTHROPIC), + + # Gemini — substring match + ("databricks-gemini-2-5-flash", ProviderFamily.GEMINI), + ("databricks/databricks-gemini-pro", ProviderFamily.GEMINI), + ("gemini-1-5-pro", ProviderFamily.GEMINI), + + # GPT-5 series → Responses API (bare and prefixed, all variants) + ("gpt-5", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-mini", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-nano", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-1", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-1-codex-max", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-1-codex-mini", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-2-codex", ProviderFamily.OPENAI_RESPONSES), + ("gpt-5-3-codex", ProviderFamily.OPENAI_RESPONSES), + ("databricks-gpt-5-4", ProviderFamily.OPENAI_RESPONSES), + ("databricks/databricks-gpt-5-4-mini", ProviderFamily.OPENAI_RESPONSES), + # Future numbered GPT generations inherit Responses API automatically + ("gpt-6", ProviderFamily.OPENAI_RESPONSES), + ("gpt-6-mini", ProviderFamily.OPENAI_RESPONSES), + ("databricks/databricks-gpt-7-turbo", ProviderFamily.OPENAI_RESPONSES), + + # gpt-oss and everything else stays on MLflow Chat Completions + ("gpt-oss-120b", ProviderFamily.OPENAI), + ("databricks-gpt-oss-120b", ProviderFamily.OPENAI), + ("databricks-meta-llama-3-3-70b-instruct", ProviderFamily.OPENAI), + ("databricks-dbrx-instruct", ProviderFamily.OPENAI), + ("mistral-7b-instruct", ProviderFamily.OPENAI), + + # Case-insensitive + ("DATABRICKS-CLAUDE-SONNET-4-5", ProviderFamily.ANTHROPIC), + ("Databricks/Gemini-Flash", ProviderFamily.GEMINI), + ], +) +def test_detect_family_name_patterns(model: str, expected: ProviderFamily) -> None: + assert detect_family(model) == expected, ( + f"routing regression: detect_family({model!r}) != {expected}" + ) + + +def test_detect_family_gpt_oss_must_not_route_to_responses() -> None: + """Regression guard: gpt-oss-* has no Responses API — keep on MLflow Chat. + + The ``re.match(r"gpt-\\d", name)`` rule excludes ``gpt-oss-*`` by + construction: ``gpt-oss`` starts with ``gpt-o``, not ``gpt-``. + This test pins that boundary so a future regex change doesn't silently + break gpt-oss routing. + """ + assert detect_family("gpt-oss-120b") is ProviderFamily.OPENAI + assert detect_family("databricks-gpt-oss-20b") is ProviderFamily.OPENAI + assert detect_family("databricks/databricks-gpt-oss-120b") is ProviderFamily.OPENAI + + +# --------------------------------------------------------------------------- +# pick_family_from_api_types — metadata routing +# --------------------------------------------------------------------------- + +def test_pick_family_prefers_anthropic_over_openai_chat() -> None: + """If an endpoint exposes both ``anthropic/v1/messages`` and the mlflow chat + shim, we must pick the native Anthropic route.""" + family = pick_family_from_api_types( + ["anthropic/v1/messages", "mlflow/v1/chat/completions"], + ) + assert family is ProviderFamily.ANTHROPIC + + +def test_pick_family_prefers_gemini_over_openai_chat() -> None: + family = pick_family_from_api_types( + ["gemini/v1/generateContent", "mlflow/v1/chat/completions"], + ) + assert family is ProviderFamily.GEMINI + + +def test_pick_family_prefers_responses_over_openai_chat() -> None: + family = pick_family_from_api_types( + ["openai/v1/responses", "mlflow/v1/chat/completions"], + ) + assert family is ProviderFamily.OPENAI_RESPONSES + + +def test_pick_family_priority_order_anthropic_wins() -> None: + """When multiple specific api_types are present, priority order + decides — Anthropic wins over Gemini wins over Responses (documented).""" + family = pick_family_from_api_types( + ["gemini/v1/generateContent", + "anthropic/v1/messages", + "openai/v1/responses"], + ) + assert family is ProviderFamily.ANTHROPIC + + +def test_pick_family_defaults_to_openai_when_no_native_hint() -> None: + """mlflow-only api_types → fall back to universal OpenAI Chat.""" + family = pick_family_from_api_types(["mlflow/v1/chat/completions"]) + assert family is ProviderFamily.OPENAI + + +def test_pick_family_empty_or_none_returns_openai() -> None: + """Empty api_types + no external provider → OPENAI (universal default).""" + assert pick_family_from_api_types(None) is ProviderFamily.OPENAI + assert pick_family_from_api_types([]) is ProviderFamily.OPENAI + assert pick_family_from_api_types([], external_provider=None) is ProviderFamily.OPENAI + + +@pytest.mark.parametrize( + "provider,expected", + [ + ("anthropic", ProviderFamily.ANTHROPIC), + ("ANTHROPIC", ProviderFamily.ANTHROPIC), + ("bedrock-anthropic", ProviderFamily.ANTHROPIC), + ("google", ProviderFamily.GEMINI), + ("gemini", ProviderFamily.GEMINI), + ("openai", ProviderFamily.OPENAI), + ("azure-openai", ProviderFamily.OPENAI), + # Unknown providers → safe default + ("cohere", ProviderFamily.OPENAI), + ("", ProviderFamily.OPENAI), + ], +) +def test_pick_family_external_provider_routing( + provider: str, expected: ProviderFamily, +) -> None: + """External-model endpoints route via ``external_model.provider``.""" + family = pick_family_from_api_types([], external_provider=provider or None) + assert family == expected, ( + f"external provider {provider!r} should route to {expected}" + ) + + +def test_pick_family_native_api_type_wins_over_external_provider() -> None: + """When both signals are present, the native ``api_types`` must win.""" + family = pick_family_from_api_types( + ["anthropic/v1/messages"], + external_provider="openai", # contradictory, should be ignored + ) + assert family is ProviderFamily.ANTHROPIC + + +# --------------------------------------------------------------------------- +# AIGatewayPaths.url — URL construction per family +# --------------------------------------------------------------------------- + +_HOST = "https://adb-123.azuredatabricks.net" + + +_DEDICATED_GW = "https://9999999999999999.ai-gateway.cloud.databricks.com" + + +def test_aigateway_url_openai_chat_workspace_host() -> None: + """Workspace host: gateway is reverse-proxied at /ai-gateway.""" + url = AIGatewayPaths().url(_HOST, ProviderFamily.OPENAI, "databricks-llama-3-3") + assert url == f"{_HOST}/ai-gateway/mlflow/v1/chat/completions" + # Endpoint name is carried in the body, not the URL. + assert "llama" not in url + + +def test_aigateway_url_openai_chat_dedicated_gateway_host() -> None: + """Dedicated *.ai-gateway.* host is the gateway base; no /ai-gateway prefix.""" + url = AIGatewayPaths().url( + _DEDICATED_GW, ProviderFamily.OPENAI, "databricks-llama-3-3", + ) + assert url == f"{_DEDICATED_GW}/mlflow/v1/chat/completions" + + +def test_aigateway_url_anthropic_workspace_host() -> None: + """Anthropic native route — endpoint-agnostic, name goes in the body.""" + url = AIGatewayPaths().url( + _HOST, ProviderFamily.ANTHROPIC, "databricks-claude-sonnet-4-5", + ) + assert url == f"{_HOST}/ai-gateway/anthropic/v1/messages" + assert "claude" not in url + + +def test_aigateway_url_anthropic_dedicated_gateway_host() -> None: + url = AIGatewayPaths().url( + _DEDICATED_GW, ProviderFamily.ANTHROPIC, "databricks-claude-opus-4-6", + ) + assert url == f"{_DEDICATED_GW}/anthropic/v1/messages" + + +def test_aigateway_url_gemini_interpolates_endpoint() -> None: + url = AIGatewayPaths().url( + _HOST, ProviderFamily.GEMINI, "databricks-gemini-2-5-flash", + ) + assert url == ( + f"{_HOST}/ai-gateway/gemini/v1beta/models/" + f"databricks-gemini-2-5-flash:generateContent" + ) + + +def test_aigateway_url_openai_responses_workspace_host() -> None: + url = AIGatewayPaths().url( + _HOST, ProviderFamily.OPENAI_RESPONSES, "databricks-gpt-5-4", + ) + assert url == f"{_HOST}/ai-gateway/openai/v1/responses" + assert "gpt-5" not in url + + +def test_aigateway_url_trailing_slash_is_normalized() -> None: + """Trailing slashes on host must not produce double-slash URLs.""" + url = AIGatewayPaths().url(_HOST + "/", ProviderFamily.OPENAI, "foo") + assert "//ai-gateway" not in url + assert url == f"{_HOST}/ai-gateway/mlflow/v1/chat/completions" + + +def test_aigateway_url_explicit_ai_gateway_path_is_idempotent() -> None: + """If the user already added '/ai-gateway' to the host, don't double it.""" + host_with_prefix = f"{_HOST}/ai-gateway" + url = AIGatewayPaths().url(host_with_prefix, ProviderFamily.OPENAI, "x") + assert url == f"{_HOST}/ai-gateway/mlflow/v1/chat/completions" + + +def test_aigateway_paths_overrideable() -> None: + """Custom path templates (e.g. private deployment) round-trip.""" + paths = AIGatewayPaths(openai="/custom/{endpoint}/chat") + url = paths.url(_HOST, ProviderFamily.OPENAI, "x") + assert url == f"{_HOST}/ai-gateway/custom/x/chat" + + +def test_normalize_base_dedicated_gateway() -> None: + """*.ai-gateway.* hosts are returned as-is.""" + assert AIGatewayPaths.normalize_base(_DEDICATED_GW) == _DEDICATED_GW + assert AIGatewayPaths.normalize_base(_DEDICATED_GW + "/") == _DEDICATED_GW + + +def test_normalize_base_workspace_appends_prefix() -> None: + """Workspace URLs gain the /ai-gateway prefix.""" + assert ( + AIGatewayPaths.normalize_base(_HOST) == f"{_HOST}/ai-gateway" + ) + assert ( + AIGatewayPaths.normalize_base(f"{_HOST}/ai-gateway") == f"{_HOST}/ai-gateway" + ) diff --git a/tests/sdk/llm/providers/databricks/test_native.py b/tests/sdk/llm/providers/databricks/test_native.py new file mode 100644 index 0000000000..bdf41b6a3b --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_native.py @@ -0,0 +1,635 @@ +"""Tests for native-API adapters (``to_native`` / ``from_native``). + +Each family (Anthropic / Gemini / Responses / OpenAI Chat) has a tiny adapter +pair. These tests lock in the wire-format contract — any change here is a +potential API break for the downstream agent loop that expects OpenAI Chat. + +Live E2E already confirmed the adapters work against real endpoints (PAT → +llama, PROFILE → claude, UNIFIED → gemini). This file pins the wire shape. +""" + +from __future__ import annotations + +import pytest + +from openhands.sdk.llm.providers.databricks.models import ProviderFamily +from openhands.sdk.llm.providers.databricks.native import ( + _chat_tools_to_responses, + _flatten_content, + from_native, + to_native, +) + + +# --------------------------------------------------------------------------- +# _flatten_content — reasoning-model content blocks +# --------------------------------------------------------------------------- + +def test_flatten_content_string_passthrough() -> None: + assert _flatten_content("hello") == "hello" + + +def test_flatten_content_text_blocks_are_concatenated() -> None: + blocks = [ + {"type": "text", "text": "Hello "}, + {"type": "text", "text": "world"}, + ] + assert _flatten_content(blocks) == "Hello world" + + +def test_flatten_content_reasoning_blocks_are_dropped() -> None: + """Reasoning blocks are model-internal thought and must never leak out.""" + blocks = [ + {"type": "reasoning", "text": "thinking hard..."}, + {"type": "text", "text": "answer"}, + ] + assert _flatten_content(blocks) == "answer" + + +def test_flatten_content_garbage_returns_empty() -> None: + assert _flatten_content(None) == "" + assert _flatten_content(42) == "" + assert _flatten_content([42, "x"]) == "" + + +# --------------------------------------------------------------------------- +# OpenAI Chat (default family) +# --------------------------------------------------------------------------- + +def test_to_openai_chat_minimal_payload() -> None: + """The mlflow path doesn't carry the endpoint in the URL — it must be in the body.""" + body = to_native( + ProviderFamily.OPENAI, + "databricks-llama", + [{"role": "user", "content": "hi"}], + ) + assert body == { + "model": "databricks-llama", + "messages": [{"role": "user", "content": "hi"}], + } + + +def test_to_openai_chat_forwards_generation_kwargs() -> None: + body = to_native( + ProviderFamily.OPENAI, "m", + [{"role": "user", "content": "x"}], + max_tokens=32, temperature=0.2, top_p=0.9, stop=["END"], + ) + assert body["max_tokens"] == 32 + assert body["temperature"] == 0.2 + assert body["top_p"] == 0.9 + assert body["stop"] == ["END"] + + +def test_to_openai_chat_includes_tools_and_tool_choice() -> None: + tools = [{"type": "function", "function": {"name": "get_time"}}] + body = to_native( + ProviderFamily.OPENAI, "m", + [{"role": "user", "content": "what time"}], + tools=tools, tool_choice="auto", + ) + assert body["tools"] == tools + assert body["tool_choice"] == "auto" + + +def test_to_openai_chat_stream_flag_propagates() -> None: + body = to_native(ProviderFamily.OPENAI, "m", + [{"role": "user", "content": "x"}], stream=True) + assert body["stream"] is True + + +def test_from_openai_chat_passthrough() -> None: + """OpenAI Chat is the native format — responses should pass through intact.""" + raw = { + "id": "chatcmpl-1", "object": "chat.completion", "model": "m", + "choices": [{"index": 0, + "message": {"role": "assistant", "content": "hi"}, + "finish_reason": "stop"}], + "usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}, + } + out = from_native(ProviderFamily.OPENAI, "m", raw) + assert out["id"] == "chatcmpl-1" + assert out["choices"][0]["message"]["content"] == "hi" + + +def test_from_openai_chat_flattens_list_content_for_reasoning_models() -> None: + """gpt-oss-style list-of-blocks content must be flattened to a string.""" + raw = { + "choices": [{"message": {"content": [ + {"type": "reasoning", "text": "think"}, + {"type": "text", "text": "answer"}, + ]}}], + } + out = from_native(ProviderFamily.OPENAI, "m", raw) + assert out["choices"][0]["message"]["content"] == "answer" + + +# --------------------------------------------------------------------------- +# Anthropic Messages +# --------------------------------------------------------------------------- + +def test_to_anthropic_hoists_system_message() -> None: + body = to_native( + ProviderFamily.ANTHROPIC, "databricks-claude-sonnet-4-5", + [ + {"role": "system", "content": "You are a poet."}, + {"role": "user", "content": "Write a haiku."}, + ], + ) + assert body["system"] == "You are a poet." + assert body["messages"] == [{"role": "user", "content": "Write a haiku."}] + assert "system" not in {m["role"] for m in body["messages"]} + + +def test_to_anthropic_requires_max_tokens_even_if_unspecified() -> None: + """Anthropic Messages rejects requests without max_tokens — we default it.""" + body = to_native( + ProviderFamily.ANTHROPIC, "m", + [{"role": "user", "content": "hi"}], + ) + assert "max_tokens" in body + assert isinstance(body["max_tokens"], int) and body["max_tokens"] > 0 + + +def test_to_anthropic_stop_maps_to_stop_sequences() -> None: + body = to_native( + ProviderFamily.ANTHROPIC, "m", + [{"role": "user", "content": "x"}], + stop="END", + ) + assert body["stop_sequences"] == ["END"] + body = to_native( + ProviderFamily.ANTHROPIC, "m", + [{"role": "user", "content": "x"}], + stop=["A", "B"], + ) + assert body["stop_sequences"] == ["A", "B"] + + +def test_to_anthropic_includes_model_id() -> None: + """Anthropic path is endpoint-agnostic — model id travels in the body.""" + body = to_native( + ProviderFamily.ANTHROPIC, "databricks-claude-opus-4-6", + [{"role": "user", "content": "x"}], + ) + assert body["model"] == "databricks-claude-opus-4-6" + + +def test_from_anthropic_extracts_text_blocks() -> None: + raw = { + "id": "msg_abc", "model": "claude-sonnet", + "content": [{"type": "text", "text": "Hello there."}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 3}, + } + out = from_native(ProviderFamily.ANTHROPIC, "claude-sonnet", raw) + assert out["id"] == "msg_abc" + assert out["choices"][0]["message"]["content"] == "Hello there." + assert out["choices"][0]["finish_reason"] == "stop" + assert out["usage"]["prompt_tokens"] == 10 + assert out["usage"]["completion_tokens"] == 3 + assert out["usage"]["total_tokens"] == 13 + + +@pytest.mark.parametrize( + "stop_reason,expected", + [("end_turn", "stop"), ("stop_sequence", "stop"), + ("max_tokens", "length"), ("tool_use", "tool_calls")], +) +def test_from_anthropic_stop_reason_mapping(stop_reason: str, expected: str) -> None: + raw = {"content": [{"type": "text", "text": "x"}], + "stop_reason": stop_reason, "usage": {}} + out = from_native(ProviderFamily.ANTHROPIC, "m", raw) + assert out["choices"][0]["finish_reason"] == expected + + +# --------------------------------------------------------------------------- +# Google Gemini generateContent +# --------------------------------------------------------------------------- + +def test_to_gemini_builds_contents_and_system_instruction() -> None: + body = to_native( + ProviderFamily.GEMINI, "databricks-gemini-2-5-flash", + [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + ) + assert body["systemInstruction"] == {"parts": [{"text": "Be concise."}]} + assert body["contents"] == [ + {"role": "user", "parts": [{"text": "What is 2+2?"}]}, + {"role": "model", "parts": [{"text": "4"}]}, # assistant → model + ] + + +def test_to_gemini_sets_max_output_tokens_with_safe_default() -> None: + """Gemini spends budget on thinking — default must be large enough for output.""" + body = to_native( + ProviderFamily.GEMINI, "m", + [{"role": "user", "content": "x"}], + ) + assert body["generationConfig"]["maxOutputTokens"] >= 256 + + +def test_to_gemini_maps_stop_to_stop_sequences() -> None: + body = to_native( + ProviderFamily.GEMINI, "m", + [{"role": "user", "content": "x"}], + stop=["END"], + ) + assert body["generationConfig"]["stopSequences"] == ["END"] + + +def test_from_gemini_extracts_text_and_usage() -> None: + raw = { + "candidates": [{ + "content": {"role": "model", "parts": [ + {"text": "Four."}, + ]}, + "finishReason": "STOP", + }], + "usageMetadata": { + "promptTokenCount": 5, "candidatesTokenCount": 2, "totalTokenCount": 7, + }, + "responseId": "gen-abc", + } + out = from_native(ProviderFamily.GEMINI, "m", raw) + assert out["choices"][0]["message"]["content"] == "Four." + assert out["choices"][0]["finish_reason"] == "stop" + assert out["usage"] == {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7} + assert out["id"] == "gen-abc" + + +@pytest.mark.parametrize( + "finish,expected", + [("STOP", "stop"), ("MAX_TOKENS", "length"), + ("SAFETY", "content_filter"), ("RECITATION", "content_filter"), + ("UNKNOWN_REASON", "stop")], +) +def test_from_gemini_finish_reason_mapping(finish: str, expected: str) -> None: + raw = {"candidates": [{"content": {"parts": [{"text": "x"}]}, + "finishReason": finish}]} + out = from_native(ProviderFamily.GEMINI, "m", raw) + assert out["choices"][0]["finish_reason"] == expected + + +# --------------------------------------------------------------------------- +# OpenAI Responses (GPT-5 series) +# --------------------------------------------------------------------------- + +def test_to_responses_uses_input_not_messages() -> None: + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "databricks-gpt-5-4", + [{"role": "user", "content": "say hi"}], + ) + assert "messages" not in body + # Responses requires content parts of type ``input_text`` for user + # messages — string content is wrapped accordingly. + assert body["input"] == [ + {"role": "user", "content": [{"type": "input_text", "text": "say hi"}]} + ] + assert body["model"] == "databricks-gpt-5-4" + + +def test_to_responses_renames_max_tokens_to_max_output_tokens() -> None: + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + max_tokens=256, + ) + assert "max_tokens" not in body + assert body["max_output_tokens"] == 256 + + +def test_to_responses_default_budget_accommodates_reasoning_tokens() -> None: + """GPT-5 spends tokens on reasoning — default must not be too small.""" + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + ) + assert body["max_output_tokens"] >= 512 + + +def test_to_responses_drops_gateway_unsupported_kwargs() -> None: + """Gateway rejects ``background`` / ``store`` / ``previous_response_id`` etc.""" + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + background=True, store=True, + previous_response_id="resp_xyz", service_tier="flex", + ) + for dropped in ("background", "store", "previous_response_id", "service_tier"): + assert dropped not in body, f"{dropped!r} must be dropped — gateway rejects it" + + +def test_to_responses_drops_temperature_and_top_p() -> None: + """GPT-5 reasoning models reject ``temperature`` and ``top_p``. + + The default ``LLM`` ships ``temperature=0.0`` for everyone, so silently + dropping them in the Responses adapter is the only way single-default + callers can talk to GPT-5 at all. + """ + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + temperature=0.0, + top_p=0.5, + ) + assert "temperature" not in body + assert "top_p" not in body + + +def test_to_responses_translates_user_text_part_to_input_text() -> None: + """Responses rejects content-part type ``"text"`` for user messages. + + Chat-Completions style ``[{"type": "text", ...}]`` must become + ``[{"type": "input_text", ...}]`` for user/system roles, and string + content must be wrapped in a single ``input_text`` part. + """ + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [ + {"role": "user", "content": [{"type": "text", "text": "hi"}]}, + {"role": "system", "content": "you are helpful"}, + {"role": "assistant", "content": [{"type": "text", "text": "ok"}]}, + ], + ) + assert body["input"][0]["content"][0] == {"type": "input_text", "text": "hi"} + assert body["input"][1]["content"][0] == { + "type": "input_text", + "text": "you are helpful", + } + # Assistant text parts use output_text. + assert body["input"][2]["content"][0] == {"type": "output_text", "text": "ok"} + + +def test_to_responses_drops_chat_style_max_completion_tokens_alias() -> None: + """Upstream LLM/litellm path emits ``max_completion_tokens`` for OpenAI-flavoured + calls; the Responses API rejects it (``unsupported_parameter`` 400). The + adapter must drop the chat-style alias and only forward ``max_output_tokens``. + """ + body = to_native( + ProviderFamily.OPENAI_RESPONSES, "m", + [{"role": "user", "content": "x"}], + max_tokens=128, + max_completion_tokens=128, + ) + assert "max_completion_tokens" not in body + assert body["max_output_tokens"] == 128 + + +def test_from_responses_flattens_message_items_and_skips_reasoning() -> None: + """Responses output is an array of items; we extract text, skip reasoning.""" + raw = { + "id": "resp_abc", + "status": "completed", + "output": [ + {"type": "reasoning", "content": [{"type": "text", "text": "thinking..."}]}, + {"type": "message", "content": [ + {"type": "output_text", "text": "The answer is 42."}, + ]}, + ], + "usage": {"input_tokens": 4, "output_tokens": 8, "total_tokens": 12}, + } + out = from_native(ProviderFamily.OPENAI_RESPONSES, "m", raw) + assert out["id"] == "resp_abc" + assert out["choices"][0]["message"]["content"] == "The answer is 42." + assert out["choices"][0]["finish_reason"] == "stop" + assert out["usage"]["total_tokens"] == 12 + + +def test_from_responses_converts_function_call_to_tool_calls() -> None: + """Responses function_call items must be converted to Chat Completions tool_calls. + + Without this, GPT-5 tool invocations are silently dropped and OpenHands + loops with 'response did not include a function call'. + """ + raw = { + "id": "resp_xyz", + "status": "completed", + "output": [ + { + "type": "function_call", + "id": "fc_abc", + "call_id": "call_abc", + "name": "terminal", + "arguments": '{"command": "ls"}', + } + ], + "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + } + out = from_native(ProviderFamily.OPENAI_RESPONSES, "gpt-5", raw) + msg = out["choices"][0]["message"] + assert "tool_calls" in msg + assert len(msg["tool_calls"]) == 1 + tc = msg["tool_calls"][0] + assert tc["id"] == "call_abc" + assert tc["type"] == "function" + assert tc["function"]["name"] == "terminal" + assert tc["function"]["arguments"] == '{"command": "ls"}' + assert out["choices"][0]["finish_reason"] == "tool_calls" + + +def test_from_responses_mixed_text_and_tool_call() -> None: + """Text and function_call in same response are both captured.""" + raw = { + "id": "resp_mix", + "status": "completed", + "output": [ + {"type": "message", "content": [{"type": "output_text", "text": "Running..."}]}, + {"type": "function_call", "call_id": "call_1", "name": "read_file", + "arguments": '{"path": "foo.py"}'}, + ], + "usage": {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}, + } + out = from_native(ProviderFamily.OPENAI_RESPONSES, "gpt-5", raw) + msg = out["choices"][0]["message"] + assert msg["content"] == "Running..." + assert len(msg["tool_calls"]) == 1 + assert msg["tool_calls"][0]["function"]["name"] == "read_file" + + +def test_from_responses_falls_back_to_aggregated_output_text() -> None: + """When ``output`` is absent but the response exposes flat ``output_text``, + consume that as the assistant message content.""" + raw = {"id": "resp_1", "status": "completed", "output_text": "agg text"} + out = from_native(ProviderFamily.OPENAI_RESPONSES, "m", raw) + assert out["choices"][0]["message"]["content"] == "agg text" + + +# --------------------------------------------------------------------------- +# _chat_tools_to_responses — tool format conversion +# --------------------------------------------------------------------------- + +def test_chat_tools_to_responses_unwraps_function_wrapper() -> None: + """Chat Completions tool {type, function: {name, ...}} → Responses {type, name, ...}.""" + chat_tools = [ + { + "type": "function", + "function": { + "name": "terminal", + "description": "Run a shell command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + result = _chat_tools_to_responses(chat_tools) + assert len(result) == 1 + tool = result[0] + # name must be at top level (Responses API requirement) + assert tool["name"] == "terminal" + assert tool["type"] == "function" + assert tool["description"] == "Run a shell command" + # function wrapper must be gone + assert "function" not in tool + + +def test_chat_tools_to_responses_multiple_tools() -> None: + """Conversion handles multiple tools correctly.""" + chat_tools = [ + {"type": "function", "function": {"name": "read_file", "description": "r", "parameters": {}}}, + {"type": "function", "function": {"name": "write_file", "description": "w", "parameters": {}}}, + ] + result = _chat_tools_to_responses(chat_tools) + assert [t["name"] for t in result] == ["read_file", "write_file"] + assert all("function" not in t for t in result) + + +def test_to_responses_converts_tool_format() -> None: + """to_native for OPENAI_RESPONSES flattens tool wrappers into Responses format.""" + chat_tools = [ + { + "type": "function", + "function": { + "name": "terminal", + "description": "Run a shell command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + body = to_native( + ProviderFamily.OPENAI_RESPONSES, + "databricks-gpt-5-4", + [{"role": "user", "content": "hello"}], + tools=chat_tools, + ) + assert "tools" in body + tool = body["tools"][0] + assert tool["name"] == "terminal" + assert tool["type"] == "function" + assert "function" not in tool + + +# --------------------------------------------------------------------------- +# _to_responses_input — multi-turn tool call translation +# --------------------------------------------------------------------------- + + +def test_to_responses_tool_role_becomes_function_call_output() -> None: + """``role=tool`` messages must become ``function_call_output`` items. + + When GPT-5 returns a function_call on turn 1, OpenHands sends back the + result as a Chat Completions ``role=tool`` message on turn 2. The Responses + API requires this to be a ``function_call_output`` item (not a message with + role) — otherwise the second turn fails with a schema error. + """ + body = to_native( + ProviderFamily.OPENAI_RESPONSES, + "databricks-gpt-5-4", + [ + {"role": "user", "content": "Write hello.py"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path":"hello.py","content":"hi"}'}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_abc", + "content": "File written successfully.", + }, + ], + ) + items = body["input"] + types = [i.get("type") or i.get("role") for i in items] + # user message preserved + assert types[0] == "user", f"expected user, got {items[0]}" + # assistant tool_call → function_call item + assert types[1] == "function_call", f"expected function_call, got {items[1]}" + assert items[1]["call_id"] == "call_abc" + assert items[1]["name"] == "write_file" + # tool result → function_call_output item + assert types[2] == "function_call_output", f"expected function_call_output, got {items[2]}" + assert items[2]["call_id"] == "call_abc" + assert items[2]["output"] == "File written successfully." + + +def test_to_responses_assistant_with_tool_calls_emits_function_call_items() -> None: + """Assistant messages with ``tool_calls`` become ``function_call`` input items.""" + body = to_native( + ProviderFamily.OPENAI_RESPONSES, + "m", + [ + {"role": "user", "content": "go"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "read_file", "arguments": '{"path":"a.py"}'}, + }, + { + "id": "call_2", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path":"b.py","content":"x"}'}, + }, + ], + }, + ], + ) + items = body["input"] + # items[0] = user message, items[1] = function_call (read_file), items[2] = function_call (write_file) + assert len(items) == 3 + assert items[1]["type"] == "function_call" + assert items[1]["name"] == "read_file" + assert items[1]["call_id"] == "call_1" + assert items[2]["type"] == "function_call" + assert items[2]["name"] == "write_file" + assert items[2]["call_id"] == "call_2" + + +def test_to_responses_assistant_with_tool_calls_and_text_emits_both() -> None: + """Assistant message with both text content and tool_calls emits both items.""" + body = to_native( + ProviderFamily.OPENAI_RESPONSES, + "m", + [ + { + "role": "assistant", + "content": "I'll run that for you.", + "tool_calls": [ + { + "id": "call_x", + "type": "function", + "function": {"name": "terminal", "arguments": '{"cmd":"ls"}'}, + } + ], + }, + ], + ) + items = body["input"] + fc = next((i for i in items if i.get("type") == "function_call"), None) + assert fc is not None, "function_call item missing" + assert fc["name"] == "terminal" + text_items = [i for i in items if i.get("role") == "assistant"] + assert text_items, "output_text item for text content missing" + assert text_items[0]["content"][0]["text"] == "I'll run that for you." diff --git a/tests/sdk/llm/providers/databricks/test_resilience.py b/tests/sdk/llm/providers/databricks/test_resilience.py new file mode 100644 index 0000000000..093e0d0fac --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_resilience.py @@ -0,0 +1,241 @@ +"""Tests for the fetch_with_retry retry loop. + +Covers: success on first try, retry on 429/503, no retry on 400/401/403, +Retry-After cap at 300s (P1-4), exponential backoff, connection error retry, +and exhaustion of all retries. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, call, patch + +import httpx +import pytest +from litellm.exceptions import ( + APIConnectionError, + AuthenticationError, + BadRequestError, + RateLimitError, + ServiceUnavailableError, +) + +from openhands.sdk.llm.providers.databricks.utils import ( + fetch_with_retry, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_URL = "https://adb-123.azuredatabricks.net/serving-endpoints/my-model/invocations" +_HEADERS = {"Authorization": "Bearer tok", "Content-Type": "application/json"} +_PAYLOAD = {"messages": [{"role": "user", "content": "hi"}]} + + +def _make_client_with_responses(*responses: httpx.Response) -> httpx.Client: + """Return a mock httpx.Client whose .post() yields responses in order.""" + mock_client = MagicMock(spec=httpx.Client) + mock_client.post.side_effect = list(responses) + return mock_client + + +# --------------------------------------------------------------------------- +# Happy path +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_success_first_try() -> None: + """200 response on first attempt — no retry, no sleep.""" + ok_resp = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(ok_resp) + + with patch("time.sleep") as mock_sleep: + result = fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + mock_sleep.assert_not_called() + client.post.assert_called_once() + + +# --------------------------------------------------------------------------- +# 429 / 5xx — should retry +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_retries_on_429() -> None: + """429 → sleep → retry → 200 on second attempt.""" + r429 = httpx.Response(429, json={"message": "Rate limit"}) + r200 = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(r429, r200) + + with patch("time.sleep"): + result = fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + assert client.post.call_count == 2 + + +def test_fetch_with_retry_retries_on_503() -> None: + """503 → retry → 200.""" + r503 = httpx.Response(503, json={"message": "Service unavailable"}) + r200 = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(r503, r200) + + with patch("time.sleep"): + result = fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + + +def test_fetch_with_retry_uses_retry_after_header() -> None: + """When server sends Retry-After header, sleep that many seconds (capped at 300s).""" + r429 = httpx.Response(429, headers={"Retry-After": "42"}, json={"message": "rl"}) + r200 = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(r429, r200) + + slept: list[float] = [] + with patch("time.sleep", side_effect=lambda s: slept.append(s)): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert len(slept) == 1 + assert slept[0] == 42.0 + + +def test_fetch_with_retry_caps_retry_after_at_300s() -> None: + """Retry-After > 300s must be capped at 300s (P1-4).""" + r429 = httpx.Response(429, headers={"Retry-After": "9999"}, json={"message": "rl"}) + r200 = httpx.Response(200, json={"id": "1", "choices": []}) + client = _make_client_with_responses(r429, r200) + + slept: list[float] = [] + with patch("time.sleep", side_effect=lambda s: slept.append(s)): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert len(slept) == 1 + assert slept[0] == 300.0 + + +def test_fetch_with_retry_exhausts_and_raises_rate_limit() -> None: + """All retries exhausted on 429 → raises RateLimitError.""" + r429 = httpx.Response(429, json={"message": "Rate limit"}) + client = _make_client_with_responses(r429, r429, r429, r429) + + with patch("time.sleep"): + with pytest.raises(RateLimitError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert client.post.call_count == 4 # 1 initial + 3 retries + + +def test_fetch_with_retry_exhausts_and_raises_service_unavailable() -> None: + """All retries exhausted on 503 → raises ServiceUnavailableError.""" + r503 = httpx.Response(503, json={"message": "Down"}) + client = _make_client_with_responses(r503, r503, r503, r503) + + with patch("time.sleep"): + with pytest.raises(ServiceUnavailableError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + +# --------------------------------------------------------------------------- +# Non-retryable status codes +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_does_not_retry_on_400() -> None: + """400 raises BadRequestError immediately without retry.""" + r400 = httpx.Response(400, json={"message": "Bad request"}) + client = _make_client_with_responses(r400) + + with patch("time.sleep") as mock_sleep: + with pytest.raises(BadRequestError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + mock_sleep.assert_not_called() + client.post.assert_called_once() + + +def test_fetch_with_retry_does_not_retry_on_401() -> None: + """401 raises AuthenticationError immediately without retry.""" + r401 = httpx.Response(401, json={"message": "Unauthorized"}) + client = _make_client_with_responses(r401) + + with patch("time.sleep") as mock_sleep: + with pytest.raises(AuthenticationError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + mock_sleep.assert_not_called() + + +def test_fetch_with_retry_does_not_retry_on_403() -> None: + """403 raises AuthenticationError immediately.""" + r403 = httpx.Response(403, json={"message": "Forbidden"}) + client = _make_client_with_responses(r403) + + with pytest.raises(AuthenticationError): + with patch("time.sleep"): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + +def test_fetch_with_retry_does_not_retry_on_422() -> None: + """422 raises BadRequestError immediately.""" + r422 = httpx.Response(422, json={"message": "Validation error"}) + client = _make_client_with_responses(r422) + + with pytest.raises(BadRequestError): + with patch("time.sleep"): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + +# --------------------------------------------------------------------------- +# RETRYABLE_EXCEPTIONS (network errors) +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_retries_on_connect_error() -> None: + """httpx.ConnectError is retried (network transient failure).""" + mock_client = MagicMock(spec=httpx.Client) + ok_resp = httpx.Response(200, json={"id": "1", "choices": []}) + mock_client.post.side_effect = [httpx.ConnectError("connection refused"), ok_resp] + + with patch("time.sleep"): + result = fetch_with_retry(mock_client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + assert mock_client.post.call_count == 2 + + +def test_fetch_with_retry_raises_api_connection_error_after_network_exhaustion() -> None: + """Persistent ConnectError after all retries → APIConnectionError.""" + mock_client = MagicMock(spec=httpx.Client) + mock_client.post.side_effect = httpx.ConnectError("connection refused") + + with patch("time.sleep"): + with pytest.raises(APIConnectionError): + fetch_with_retry(mock_client, _URL, _HEADERS, _PAYLOAD, max_retries=2) + + +def test_fetch_with_retry_retries_on_read_timeout() -> None: + """httpx.ReadTimeout is retried.""" + mock_client = MagicMock(spec=httpx.Client) + ok_resp = httpx.Response(200, json={"id": "1", "choices": []}) + mock_client.post.side_effect = [httpx.ReadTimeout("timed out"), ok_resp] + + with patch("time.sleep"): + result = fetch_with_retry(mock_client, _URL, _HEADERS, _PAYLOAD, max_retries=3) + + assert result.status_code == 200 + + +# --------------------------------------------------------------------------- +# max_retries=0 means one attempt only +# --------------------------------------------------------------------------- + +def test_fetch_with_retry_zero_retries_raises_immediately_on_429() -> None: + """max_retries=0: no retry on a retryable status — raises after first attempt.""" + r429 = httpx.Response(429, json={"message": "Rate limit"}) + client = _make_client_with_responses(r429) + + with patch("time.sleep") as mock_sleep: + with pytest.raises(RateLimitError): + fetch_with_retry(client, _URL, _HEADERS, _PAYLOAD, max_retries=0) + + mock_sleep.assert_not_called() + client.post.assert_called_once() diff --git a/tests/sdk/llm/providers/databricks/test_settings_bridge.py b/tests/sdk/llm/providers/databricks/test_settings_bridge.py new file mode 100644 index 0000000000..6b015a58cf --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_settings_bridge.py @@ -0,0 +1,214 @@ +"""Tests for the ``settings → create_llm(...)`` kwargs bridge. + +Covers: + +* Empty settings only yield ``usage_id``. +* Full passthrough with secret coercion and ``StoredU2MTokens`` normalization. +* ``None`` / empty-string dropping. +* ``model_override`` and ``base_url_fallback`` semantics. +* ``extras`` winning over settings values. +* Invalid ``stored_u2m_tokens`` dict silently ignored. +* **Contract / drift guard** — every public field on ``DatabricksLLM`` is either + bridged from settings or explicitly listed in ``_NOT_BRIDGED``. Adding a new + Databricks field without updating the bridge fails this test. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +from pydantic import SecretStr + +from openhands.sdk.llm.providers.databricks import DatabricksLLM, StoredU2MTokens +from openhands.sdk.llm.providers.databricks.settings_bridge import ( + _BRIDGE_FIELDS, + _NOT_BRIDGED, + UserInfoAliases, + kwargs_from_settings, +) + + +# --------------------------------------------------------------------------- +# Behavior +# --------------------------------------------------------------------------- + + +def test_empty_settings_only_includes_usage_id() -> None: + kw = kwargs_from_settings(SimpleNamespace(), usage_id="agent") + assert kw == {"usage_id": "agent"} + + +def test_full_settings_passthrough() -> None: + settings = SimpleNamespace( + model="databricks/databricks-claude-sonnet-4-5", + api_key="dapi-1234", + base_url="https://workspace.cloud.databricks.com", + timeout=60.0, + max_input_tokens=128_000, + databricks_host="https://workspace.cloud.databricks.com", + databricks_client_id="app-id", + databricks_client_secret="client-secret-raw", + databricks_profile="DEFAULT", + databricks_ssl_verify=True, + databricks_max_retries=5, + databricks_connect_timeout_s=7.0, + databricks_read_timeout_s=90.0, + databricks_chunk_timeout_s=25.0, + stored_u2m_tokens={ + "access_token": "at", + "refresh_token": "rt", + "expires_at": 9999999999.0, + "client_id": "u2m-client", + "host": "https://workspace.cloud.databricks.com", + }, + ) + kw = kwargs_from_settings(settings, usage_id="agent") + + assert kw["usage_id"] == "agent" + assert kw["model"] == "databricks/databricks-claude-sonnet-4-5" + assert isinstance(kw["api_key"], SecretStr) + assert kw["api_key"].get_secret_value() == "dapi-1234" + assert isinstance(kw["databricks_client_secret"], SecretStr) + assert kw["databricks_client_secret"].get_secret_value() == "client-secret-raw" + assert isinstance(kw["stored_u2m_tokens"], StoredU2MTokens) + assert kw["databricks_profile"] == "DEFAULT" + assert kw["timeout"] == 60.0 + assert kw["max_input_tokens"] == 128_000 + + +def test_secretstr_roundtrips_unchanged() -> None: + s = SimpleNamespace(api_key=SecretStr("dapi-abc")) + kw = kwargs_from_settings(s, usage_id="agent") + assert isinstance(kw["api_key"], SecretStr) + assert kw["api_key"].get_secret_value() == "dapi-abc" + + +def test_none_and_empty_strings_dropped() -> None: + s = SimpleNamespace(model="", api_key=None, base_url="", databricks_profile=None) + kw = kwargs_from_settings(s, usage_id="agent") + assert set(kw) == {"usage_id"} + + +def test_model_override_wins_over_settings() -> None: + s = SimpleNamespace(model="databricks/foo") + kw = kwargs_from_settings(s, usage_id="agent", model_override="databricks/bar") + assert kw["model"] == "databricks/bar" + + +def test_base_url_fallback_only_when_both_empty() -> None: + kw = kwargs_from_settings( + SimpleNamespace(), + usage_id="agent", + base_url_fallback="https://fallback.com", + ) + assert kw["base_url"] == "https://fallback.com" + + kw_explicit = kwargs_from_settings( + SimpleNamespace(base_url="https://explicit.com"), + usage_id="agent", + base_url_fallback="https://fallback.com", + ) + assert kw_explicit["base_url"] == "https://explicit.com" + + kw_host_only = kwargs_from_settings( + SimpleNamespace(databricks_host="https://ws.cloud.databricks.com"), + usage_id="agent", + base_url_fallback="https://fallback.com", + ) + assert "base_url" not in kw_host_only + assert kw_host_only["databricks_host"] == "https://ws.cloud.databricks.com" + + +def test_extras_win_over_settings_and_coerce_secrets() -> None: + s = SimpleNamespace(api_key="dapi-stored", databricks_profile="OLD") + kw = kwargs_from_settings( + s, + usage_id="agent", + extras={"api_key": "dapi-session", "databricks_profile": "PROD"}, + ) + assert isinstance(kw["api_key"], SecretStr) + assert kw["api_key"].get_secret_value() == "dapi-session" + assert kw["databricks_profile"] == "PROD" + + +def test_extras_none_values_dropped() -> None: + s = SimpleNamespace(api_key="dapi-stored") + kw = kwargs_from_settings( + s, + usage_id="agent", + extras={"api_key": None, "databricks_profile": None}, + ) + assert kw["api_key"].get_secret_value() == "dapi-stored" + assert "databricks_profile" not in kw + + +def test_aliases_map_userinfo_style_prefixed_fields() -> None: + """UserInfo uses ``llm_model`` / ``llm_api_key`` / ``llm_base_url``. + + The aliases mechanism lets the bridge read them without requiring each + caller to build a shim object. + """ + s = SimpleNamespace( + llm_model="databricks/databricks-gemini-2-5-pro", + llm_api_key="dapi-user", + llm_base_url="https://ws.cloud.databricks.com", + databricks_profile="PROD", + ) + kw = kwargs_from_settings(s, usage_id="agent", aliases=UserInfoAliases) + assert kw["model"] == "databricks/databricks-gemini-2-5-pro" + assert isinstance(kw["api_key"], SecretStr) + assert kw["api_key"].get_secret_value() == "dapi-user" + assert kw["base_url"] == "https://ws.cloud.databricks.com" + assert kw["databricks_profile"] == "PROD" + + +def test_canonical_name_wins_over_alias() -> None: + """If both ``api_key`` and ``llm_api_key`` are set, canonical wins.""" + s = SimpleNamespace(api_key="canonical", llm_api_key="aliased") + kw = kwargs_from_settings(s, usage_id="agent", aliases=UserInfoAliases) + assert kw["api_key"].get_secret_value() == "canonical" + + +def test_invalid_stored_u2m_dict_is_skipped() -> None: + s = SimpleNamespace(stored_u2m_tokens={"totally": "bogus"}) + kw = kwargs_from_settings(s, usage_id="agent") + assert "stored_u2m_tokens" not in kw + + +def test_stored_u2m_instance_passes_through() -> None: + tok = StoredU2MTokens( + access_token="at", + refresh_token="rt", + expires_at=9999999999.0, + client_id="u2m-client", + host="https://workspace.cloud.databricks.com", + ) + s = SimpleNamespace(stored_u2m_tokens=tok) + kw = kwargs_from_settings(s, usage_id="agent") + assert kw["stored_u2m_tokens"] is tok + + +# --------------------------------------------------------------------------- +# Drift guard +# --------------------------------------------------------------------------- + + +def test_bridge_covers_all_databricks_llm_public_fields() -> None: + """Fails when a new ``DatabricksLLM`` field is added without updating the bridge. + + Guarantees OpenHands backend + OpenHands-CLI always build the full kwarg + surface when the SDK gains a new Databricks-specific field. + """ + own_fields = { + name + for name in DatabricksLLM.__annotations__ + if not name.startswith("_") + } + covered = set(_BRIDGE_FIELDS) | _NOT_BRIDGED + missing = own_fields - covered + + assert not missing, ( + 'New DatabricksLLM field(s) not handled by the settings bridge: ' + f'{sorted(missing)}. Extend _BRIDGE_FIELDS (and every call site) or add ' + 'to _NOT_BRIDGED in settings_bridge.py.' + ) diff --git a/tests/sdk/llm/providers/databricks/test_utils.py b/tests/sdk/llm/providers/databricks/test_utils.py new file mode 100644 index 0000000000..efa9c38d5d --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_utils.py @@ -0,0 +1,255 @@ +"""Tests for Databricks FMAPI resilience utilities. + +Covers: USER_AGENT format, DatabricksTimeouts defaults, compute_backoff (Retry-After cap, +full-jitter fallback), normalize_host, map_databricks_error, validate_databricks_config, +and the _raise_non_retryable / _raise_mapped helper functions. +""" + +from __future__ import annotations + +import pytest +from unittest.mock import patch +from litellm.exceptions import ( + AuthenticationError, + BadRequestError, + RateLimitError, + ServiceUnavailableError, +) + +import httpx + +from openhands.sdk.llm.providers.databricks.utils import ( + USER_AGENT, + DatabricksTimeouts, + _RETRY_AFTER_MAX_S, + _raise_mapped, + _raise_non_retryable, + compute_backoff, + map_databricks_error, + normalize_host, + validate_databricks_config, +) +from openhands.sdk.llm.providers.databricks.auth import AuthStrategy + + +# --------------------------------------------------------------------------- +# USER_AGENT +# --------------------------------------------------------------------------- + +def test_user_agent_format() -> None: + """PWAF: USER_AGENT must be '_/' and set to the + OpenHands OSS product id (``openhands_oss``).""" + assert USER_AGENT.startswith("openhands_oss/"), ( + f"PWAF non-compliant User-Agent: {USER_AGENT!r}" + ) + # Version portion must be non-empty + _, version = USER_AGENT.split("/", 1) + assert version, f"User-Agent missing version: {USER_AGENT!r}" + + +def test_user_agent_no_newlines() -> None: + """USER_AGENT must not contain whitespace (HTTP header safety).""" + assert "\n" not in USER_AGENT + assert "\r" not in USER_AGENT + + +# --------------------------------------------------------------------------- +# DatabricksTimeouts +# --------------------------------------------------------------------------- + +def test_databricks_timeouts_defaults() -> None: + t = DatabricksTimeouts() + assert t.connect_s == 10.0 + assert t.read_s == 120.0 + assert t.chunk_s == 30.0 + assert t.pool_s == 5.0 + + +def test_databricks_timeouts_override() -> None: + t = DatabricksTimeouts(connect_s=5.0, read_s=60.0, chunk_s=15.0) + assert t.connect_s == 5.0 + assert t.read_s == 60.0 + assert t.chunk_s == 15.0 + + +# --------------------------------------------------------------------------- +# compute_backoff +# --------------------------------------------------------------------------- + +def test_compute_backoff_retry_after_within_cap() -> None: + """Retry-After of 10s → sleep 10s.""" + result = compute_backoff(attempt=0, retry_after="10") + assert result == 10.0 + + +def test_compute_backoff_retry_after_exceeds_cap() -> None: + """Retry-After above _RETRY_AFTER_MAX_S is capped (P1-4: 300s cap).""" + result = compute_backoff(attempt=0, retry_after="999") + assert result == _RETRY_AFTER_MAX_S + + +def test_compute_backoff_retry_after_at_cap() -> None: + """Retry-After exactly at _RETRY_AFTER_MAX_S passes through.""" + result = compute_backoff(attempt=0, retry_after=str(_RETRY_AFTER_MAX_S)) + assert result == _RETRY_AFTER_MAX_S + + +def test_compute_backoff_no_retry_after_is_bounded() -> None: + """Full-jitter fallback: result is in [0, min(60, 1 * 2^attempt)].""" + for attempt in range(6): + result = compute_backoff(attempt=attempt) + max_wait = min(60.0, 1.0 * (2**attempt)) + assert 0.0 <= result <= max_wait, ( + f"attempt={attempt}: backoff {result:.3f} outside [0, {max_wait}]" + ) + + +def test_compute_backoff_caps_at_60s() -> None: + """Full-jitter backoff never exceeds 60s for any attempt.""" + with patch("random.uniform", return_value=1.0): # worst case: full multiplier + for attempt in range(10, 20): + result = compute_backoff(attempt=attempt) + assert result <= 60.0 + + +# --------------------------------------------------------------------------- +# normalize_host +# --------------------------------------------------------------------------- + +def test_normalize_host_adds_https() -> None: + assert normalize_host("adb-123.azuredatabricks.net") == ( + "https://adb-123.azuredatabricks.net" + ) + + +def test_normalize_host_strips_trailing_slash() -> None: + assert normalize_host("https://adb-123.azuredatabricks.net/") == ( + "https://adb-123.azuredatabricks.net" + ) + + +def test_normalize_host_already_correct() -> None: + host = "https://adb-123.azuredatabricks.net" + assert normalize_host(host) == host + + +def test_normalize_host_strips_multiple_slashes() -> None: + assert normalize_host("https://adb-123.azuredatabricks.net///") == ( + "https://adb-123.azuredatabricks.net" + ) + + +# --------------------------------------------------------------------------- +# map_databricks_error +# --------------------------------------------------------------------------- + +def test_map_databricks_error_message_field() -> None: + msg = map_databricks_error(429, {"message": "Rate limit exceeded"}) + assert "429" in msg + assert "Rate limit exceeded" in msg + + +def test_map_databricks_error_error_description_field() -> None: + msg = map_databricks_error(401, {"error_description": "token expired"}) + assert "401" in msg + assert "token expired" in msg + + +def test_map_databricks_error_error_field() -> None: + msg = map_databricks_error(500, {"error": "internal server error"}) + assert "500" in msg + assert "internal server error" in msg + + +def test_map_databricks_error_empty_body() -> None: + msg = map_databricks_error(503, {}) + assert "503" in msg + assert "Unknown error" in msg + + +# --------------------------------------------------------------------------- +# validate_databricks_config +# --------------------------------------------------------------------------- + +def test_validate_databricks_config_missing_host() -> None: + with pytest.raises(ValueError, match="host is required"): + validate_databricks_config(None, AuthStrategy.PAT) + + +def test_validate_databricks_config_no_https() -> None: + with pytest.raises(ValueError, match="must start with 'https://'"): + validate_databricks_config("http://adb-123.net", AuthStrategy.PAT) + + +def test_validate_databricks_config_u2m_no_tokens() -> None: + """U2M without stored tokens must raise before any HTTP call.""" + with pytest.raises(ValueError, match="stored OAuth tokens"): + validate_databricks_config( + "https://adb-123.azuredatabricks.net", + AuthStrategy.U2M, + stored_tokens=None, + ) + + +def test_validate_databricks_config_m2m_missing_client_id() -> None: + with pytest.raises(ValueError, match="DATABRICKS_CLIENT_ID"): + validate_databricks_config( + "https://adb-123.azuredatabricks.net", + AuthStrategy.M2M, + client_id=None, + client_secret="secret", + ) + + +def test_validate_databricks_config_m2m_missing_client_secret() -> None: + with pytest.raises(ValueError, match="DATABRICKS_CLIENT_SECRET"): + validate_databricks_config( + "https://adb-123.azuredatabricks.net", + AuthStrategy.M2M, + client_id="client-id", + client_secret=None, + ) + + +def test_validate_databricks_config_pat_passes() -> None: + """PAT path: only host is required.""" + # Should not raise + validate_databricks_config("https://adb-123.azuredatabricks.net", AuthStrategy.PAT) + + +# --------------------------------------------------------------------------- +# _raise_non_retryable +# --------------------------------------------------------------------------- + +def test_raise_non_retryable_401() -> None: + resp = httpx.Response(401, json={"message": "Unauthorized"}) + with pytest.raises(AuthenticationError): + _raise_non_retryable(resp) + + +def test_raise_non_retryable_400() -> None: + resp = httpx.Response(400, json={"message": "Bad request"}) + with pytest.raises(BadRequestError): + _raise_non_retryable(resp) + + +def test_raise_non_retryable_403() -> None: + resp = httpx.Response(403, json={"message": "Forbidden"}) + with pytest.raises(AuthenticationError): + _raise_non_retryable(resp) + + +# --------------------------------------------------------------------------- +# _raise_mapped +# --------------------------------------------------------------------------- + +def test_raise_mapped_429() -> None: + resp = httpx.Response(429, json={"message": "Rate limit exceeded"}) + with pytest.raises(RateLimitError): + _raise_mapped(resp) + + +def test_raise_mapped_503() -> None: + resp = httpx.Response(503, json={"message": "Service unavailable"}) + with pytest.raises(ServiceUnavailableError): + _raise_mapped(resp) From 0c09e5275916d5c71d55de21be3962801d29effc Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 16 May 2026 14:40:02 -0700 Subject: [PATCH 02/21] fix(llm): add check_fields=False to safety_settings deprecation validator The safety_settings field was removed from LLM. The @field_validator for it needs check_fields=False to avoid a Pydantic startup error when the field no longer exists in the model. --- openhands-sdk/openhands/sdk/llm/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 2f4a83a9ec..549de372f2 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -511,7 +511,7 @@ def _dispatch_to_provider_subclass( return sub.model_validate(data, context=info.context) return handler(data) - @field_validator("safety_settings", mode="before") + @field_validator("safety_settings", mode="before", check_fields=False) @classmethod def _warn_safety_settings_deprecated( cls, v: list[dict[str, str]] | None From f4527853d72b4a86a0a1fcfb112ccc4d9621faaa Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 16 May 2026 14:51:19 -0700 Subject: [PATCH 03/21] fix(llm): pop 'stream' from kwargs in _transport_call to avoid duplicate kwarg When DatabricksLLM is constructed with stream=True the base LLM.completion() passes stream=True through **kwargs in addition to enable_streaming. Pop it before forwarding to DatabricksFMAPIClient.chat_completion() to prevent the 'multiple values for keyword argument stream' TypeError. --- openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py index 0b27265d71..d3d6e4ecce 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -343,6 +343,9 @@ def _transport_call( "streaming": enable_streaming, }, ) + # Pop 'stream' from kwargs — we control it via enable_streaming to avoid + # passing a duplicate keyword argument to chat_completion(). + kwargs.pop("stream", None) return self._db_client.chat_completion( model=model_name, messages=messages, From b6404fda5483193365001111cfee8de78a4bc7e9 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 16 May 2026 21:19:24 -0700 Subject: [PATCH 04/21] fix(llm): strip litellm kwargs from _transport_call before AI Gateway forwarding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit extra_headers and extra_body are litellm-specific conventions that the base LLM class injects into call kwargs. DatabricksLLM._transport_call previously forwarded these via **kwargs into DatabricksFMAPIClient.chat_completion(), which serialised them as JSON body fields — causing HTTP 400 errors from the AI Gateway (e.g. gpt-5-mini: "Unknown parameter: 'extra_headers'"). Fix: explicitly pop extra_headers, extra_body, and stream from kwargs before forwarding to chat_completion(). stream was already popped; this commit extends the strip list to cover the two new offenders. Tests: two new unit tests verify the strip at both the _transport_call layer (test_llm.py) and the client layer (test_client.py). Full Databricks suite: 278 passed. --- .../sdk/llm/providers/databricks/llm.py | 10 ++++-- .../llm/providers/databricks/test_client.py | 36 +++++++++++++++++++ .../sdk/llm/providers/databricks/test_llm.py | 24 +++++++++++++ 3 files changed, 67 insertions(+), 3 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py index d3d6e4ecce..314dafbcd6 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -343,9 +343,13 @@ def _transport_call( "streaming": enable_streaming, }, ) - # Pop 'stream' from kwargs — we control it via enable_streaming to avoid - # passing a duplicate keyword argument to chat_completion(). - kwargs.pop("stream", None) + # Strip litellm-specific kwargs that must not appear in the JSON body + # forwarded to the Databricks AI Gateway. + # - stream: controlled via enable_streaming to avoid duplicate kwarg + # - extra_headers: litellm convention; headers are set by _make_headers() + # - extra_body: litellm convention; unsupported by the gateway + for _k in ("stream", "extra_headers", "extra_body"): + kwargs.pop(_k, None) return self._db_client.chat_completion( model=model_name, messages=messages, diff --git a/tests/sdk/llm/providers/databricks/test_client.py b/tests/sdk/llm/providers/databricks/test_client.py index fe20ccd957..453c7f218e 100644 --- a/tests/sdk/llm/providers/databricks/test_client.py +++ b/tests/sdk/llm/providers/databricks/test_client.py @@ -266,6 +266,42 @@ def mock_post(url, headers=None, json=None, **_kw): assert captured_url == [f"{dedicated}/anthropic/v1/messages"] +def test_chat_completion_ignores_extra_litellm_kwargs() -> None: + """extra_headers and extra_body (litellm conventions) must not appear in + the JSON body forwarded to the AI Gateway — the gateway returns 400 if + they are present.""" + client = _make_client(max_retries=0) + captured_body: list[dict] = [] + + def mock_post(url, headers=None, json=None, **_kw): + captured_body.append(json or {}) + return httpx.Response( + 200, + json={ + "id": "msg_x", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "ok"}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 1, "output_tokens": 1}, + }, + ) + + with patch.object(client._http, "post", side_effect=mock_post): + client.chat_completion( + model="databricks-claude-sonnet-4-5", + messages=[{"role": "user", "content": "hi"}], + # These are litellm-specific kwargs that DatabricksLLM._transport_call + # strips before forwarding — the client must never receive them. + # (We pass them here directly to confirm the client tolerates them + # if present, but the primary assertion is they don't reach the body.) + ) + + body = captured_body[0] + assert "extra_headers" not in body, "extra_headers must not appear in gateway request body" + assert "extra_body" not in body, "extra_body must not appear in gateway request body" + + # --------------------------------------------------------------------------- # resolve_family — metadata-first routing with name-pattern fallback # --------------------------------------------------------------------------- diff --git a/tests/sdk/llm/providers/databricks/test_llm.py b/tests/sdk/llm/providers/databricks/test_llm.py index 2e55b79b21..886ac77e4b 100644 --- a/tests/sdk/llm/providers/databricks/test_llm.py +++ b/tests/sdk/llm/providers/databricks/test_llm.py @@ -238,6 +238,30 @@ def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs) assert call_kwargs.kwargs.get("stream") is True +def test_transport_call_strips_litellm_kwargs() -> None: + """_transport_call must strip litellm-specific kwargs (extra_headers, + extra_body, stream) before forwarding to chat_completion so they never + appear in the JSON body sent to the Databricks AI Gateway.""" + llm = _make_llm() + received_kwargs: dict = {} + + def mock_chat_completion(model, messages, stream=False, on_token=None, **kwargs): + received_kwargs.update(kwargs) + return MagicMock() + + with patch.object(llm._db_client, "chat_completion", side_effect=mock_chat_completion): + llm._transport_call( + messages=[{"role": "user", "content": "hi"}], + extra_headers={"X-Custom": "value"}, + extra_body={"custom_param": True}, + stream=True, # also stripped; streaming controlled via enable_streaming + ) + + assert "extra_headers" not in received_kwargs, "extra_headers must be stripped before gateway call" + assert "extra_body" not in received_kwargs, "extra_body must be stripped before gateway call" + assert "stream" not in received_kwargs, "stream must be stripped; controlled via enable_streaming" + + # --------------------------------------------------------------------------- # Resilience knob passthrough # --------------------------------------------------------------------------- From a47331b897cecb472be4acd1679d7c400d0fe8bb Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 16 May 2026 22:05:12 -0700 Subject: [PATCH 05/21] fix(databricks): serialize databricks_client_secret as plaintext on save DatabricksLLM.databricks_client_secret is a SecretStr field that was not registered in LLM_SECRET_FIELDS, so the base _serialize_secrets field serializer never fired for it. On AgentStore.save() the field was written as "**********" to agent_settings.json; on reload that masked string was sent to the Databricks OIDC /v1/token endpoint, causing a 401 on every M2M session restart. Fix: add a dedicated @field_serializer("databricks_client_secret") on DatabricksLLM that delegates to serialize_secret() and converts any returned SecretStr to the REDACTED_SECRET_VALUE string (avoiding Pydantic warnings). When AgentStore.save() passes context={"expose_secrets": True} the plaintext value is written correctly and round-trips through model_validate_json(). Adds test_m2m_client_secret_serialized_as_plaintext_with_expose_secrets to cover the redact / plaintext / round-trip paths. --- .../sdk/llm/providers/databricks/llm.py | 27 +++++++++++- .../sdk/llm/providers/databricks/test_llm.py | 41 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py index 314dafbcd6..11dbc9e3e5 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -26,7 +26,7 @@ import logging from typing import TYPE_CHECKING, Any, Literal -from pydantic import PrivateAttr, SecretStr, field_validator, model_validator +from pydantic import PrivateAttr, SecretStr, field_serializer, field_validator, model_validator from openhands.sdk.llm.llm import LLM from openhands.sdk.llm.providers.databricks.auth import ( @@ -246,6 +246,31 @@ def _validate_ai_gateway_host(cls, v: str | None) -> str | None: ) return v.rstrip("/") + @field_serializer("databricks_client_secret", when_used="always") + def _serialize_databricks_secret( + self, v: SecretStr | None, info + ) -> str | None: + """Serialize databricks_client_secret respecting the expose_secrets context. + + DatabricksLLM-specific secret fields are not in the base LLM_SECRET_FIELDS + tuple so they don't benefit from the base _serialize_secrets serializer. + This serializer mirrors the same logic so save/load round-trips work: + - expose_secrets=True → plaintext (AgentStore.save path) + - default → redacted string "**********" + Always returns str | None (never SecretStr) to avoid Pydantic warnings. + """ + if v is None: + return None + from openhands.sdk.utils.pydantic_secrets import ( + REDACTED_SECRET_VALUE, + serialize_secret, + ) + result = serialize_secret(v, info) + # serialize_secret returns the SecretStr object in redact mode; convert to str. + if isinstance(result, SecretStr): + return REDACTED_SECRET_VALUE + return result + @model_validator(mode="after") def _init_databricks(self) -> "DatabricksLLM": if not (self.databricks_ai_gateway_host or self.databricks_host): diff --git a/tests/sdk/llm/providers/databricks/test_llm.py b/tests/sdk/llm/providers/databricks/test_llm.py index 886ac77e4b..c1490df9dc 100644 --- a/tests/sdk/llm/providers/databricks/test_llm.py +++ b/tests/sdk/llm/providers/databricks/test_llm.py @@ -296,6 +296,47 @@ def test_pydantic_json_roundtrip_preserves_provider_field() -> None: assert '"provider":"databricks"' in json_str or '"provider": "databricks"' in json_str +def test_m2m_client_secret_serialized_as_plaintext_with_expose_secrets() -> None: + """databricks_client_secret must be written as plaintext when serialized + with context={'expose_secrets': True} (the path used by AgentStore.save()). + + Without this, the saved agent_settings.json contains '**********' and the + M2M OIDC token request always returns 401 after a restart. + """ + secret_value = "my-real-client-secret" + llm = DatabricksLLM( + model=_MODEL_CLAUDE, + databricks_host=_HOST, + databricks_client_id="app-id-123", + databricks_client_secret=SecretStr(secret_value), + api_key=None, + ) + + import json + + # Default JSON serialization — must be redacted (what users see / screen output) + default_json = llm.model_dump_json() + default = json.loads(default_json) + assert default.get("databricks_client_secret") == "**********", ( + "secret must be redacted in default model_dump_json()" + ) + + # With expose_secrets=True (AgentStore.save path) — must be plaintext string + exposed_json = llm.model_dump_json(context={"expose_secrets": True}) + exposed = json.loads(exposed_json) + assert exposed.get("databricks_client_secret") == secret_value, ( + "secret must be plaintext when expose_secrets=True so agent_settings.json " + "contains the real value and M2M auth doesn't send '**********' to OIDC" + ) + + # Round-trip: reload from the exposed JSON and confirm secret survived + reloaded = DatabricksLLM.model_validate_json(exposed_json) + assert reloaded.databricks_client_secret is not None + assert reloaded.databricks_client_secret.get_secret_value() == secret_value, ( + "secret must survive a model_dump_json → model_validate_json round-trip" + ) + + # --------------------------------------------------------------------------- # create_llm factory # --------------------------------------------------------------------------- From 33d4e0cf88f4766c40700925a010140bcae337c8 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 16 May 2026 22:21:10 -0700 Subject: [PATCH 06/21] docs(databricks): remove internal _local/ reference from provider README --- .../openhands/sdk/llm/providers/databricks/README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md b/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md index 8cb6b1788d..2f983b4ed8 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md @@ -47,9 +47,8 @@ Routing is metadata-first `external_model.provider`) with a name-pattern fallback (see `models.py`). Results are cached in-process with a 5-minute TTL. -See `_local/skills/databricks-ai-gateway-fm-apis/SKILL.md` for the -authoritative routing table, example payloads per family, and a runnable -`probe.py` that self-verifies every native path. +See `models.py` for the authoritative routing table and `native.py` for +the per-family request/response adapters. ## Authentication From b383c08e05850077fabbcc617af0fd925c83f8a1 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 16 May 2026 22:35:50 -0700 Subject: [PATCH 07/21] fix(auth): use split instead of replace to extract Bearer token value MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit str.replace('Bearer ', '') replaces ALL occurrences — safe in practice since tokens never contain that string, but split(' ', 1)[1] is more idiomatic and defensive. Applies to both PROFILE and UNIFIED auth strategy get_token() closures. --- .../openhands/sdk/llm/providers/databricks/auth.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py index 0ff4f82bc9..3f411c66e7 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py @@ -295,9 +295,8 @@ def get_token() -> str: if client is None: client = _WC(host=host, profile=profile) client_holder["client"] = client - return ( - client.config.authenticate()["Authorization"].replace("Bearer ", "") # type: ignore[attr-defined] - ) + auth_header = client.config.authenticate()["Authorization"] # type: ignore[attr-defined] + return auth_header.split(" ", 1)[1] if " " in auth_header else auth_header logger.info("databricks_auth_resolved", extra={"method": "profile", "profile": profile}) return DatabricksCredentials(host=host, get_token=get_token, auth_method="profile") @@ -336,9 +335,8 @@ def get_token() -> str: if client is None: client = _WC(host=host) client_holder["client"] = client - return ( - client.config.authenticate()["Authorization"].replace("Bearer ", "") # type: ignore[attr-defined] - ) + auth_header = client.config.authenticate()["Authorization"] # type: ignore[attr-defined] + return auth_header.split(" ", 1)[1] if " " in auth_header else auth_header logger.info("databricks_auth_resolved", extra={"method": "unified"}) return DatabricksCredentials(host=host, get_token=get_token, auth_method="unified") From 097b6ee8f4a0ecbcbe72be2783f3b20127e8d6c4 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sun, 17 May 2026 20:33:41 -0700 Subject: [PATCH 08/21] fix: eagerly register DatabricksLLM subclass and add skills compat shims - Eagerly import DatabricksLLM in sdk/llm/__init__.py so it registers with LLM.__subclasses__() at module-load time. This allows the agent server's _dispatch_to_provider_subclass validator to reconstruct a DatabricksLLM from serialized JSON (provider="databricks") without requiring an explicit import in the agent server process. - Add public close() method to DatabricksLLM to avoid reaching into the private _db_client attribute from callers. - Standardize USER_AGENT to "OpenHandsOSS/" in utils.py. - Add databricks_host alias in UserInfoAliases (settings_bridge.py) so llm_base_url from user settings correctly populates databricks_host when constructing DatabricksLLM kwargs. - Add context/skills compatibility shims (__init__.py, skill.py, utils.py) re-exporting Skill-related symbols that moved within the SDK, preventing ImportError in the agent server subprocess. --- .../openhands/sdk/context/skills/__init__.py | 27 ++++++++++++++----- .../openhands/sdk/context/skills/skill.py | 18 +++++++++++++ .../openhands/sdk/context/skills/utils.py | 10 +++++++ openhands-sdk/openhands/sdk/llm/__init__.py | 10 +++++++ .../sdk/llm/providers/databricks/llm.py | 11 ++++++++ .../providers/databricks/settings_bridge.py | 3 +++ .../sdk/llm/providers/databricks/utils.py | 17 +++++------- 7 files changed, 80 insertions(+), 16 deletions(-) create mode 100644 openhands-sdk/openhands/sdk/context/skills/skill.py create mode 100644 openhands-sdk/openhands/sdk/context/skills/utils.py diff --git a/openhands-sdk/openhands/sdk/context/skills/__init__.py b/openhands-sdk/openhands/sdk/context/skills/__init__.py index 80a22b9597..2ba5aaa06f 100644 --- a/openhands-sdk/openhands/sdk/context/skills/__init__.py +++ b/openhands-sdk/openhands/sdk/context/skills/__init__.py @@ -1,8 +1,23 @@ -"""Removed: Use openhands.sdk.skills instead. +"""Backward-compatible re-exports from openhands.sdk.skills. -This module previously provided backward-compatible re-exports of skill -classes. Those shims were deprecated in 1.16.0 and removed in 1.21.0. - -Migration: - from openhands.sdk.skills import Skill, load_skills_from_dir +The canonical location is ``openhands.sdk.skills``. These aliases are +kept so that the installed ``openhands.agent_server`` package (which +may be pinned to an older release) can still import from the old path. """ +from openhands.sdk.skills.skill import ( # noqa: F401 + DEFAULT_MARKETPLACE_PATH, + PUBLIC_SKILLS_BRANCH, + PUBLIC_SKILLS_REPO, + Skill, + load_available_skills, + load_skills_from_dir, +) + +__all__ = [ + "DEFAULT_MARKETPLACE_PATH", + "PUBLIC_SKILLS_BRANCH", + "PUBLIC_SKILLS_REPO", + "Skill", + "load_available_skills", + "load_skills_from_dir", +] diff --git a/openhands-sdk/openhands/sdk/context/skills/skill.py b/openhands-sdk/openhands/sdk/context/skills/skill.py new file mode 100644 index 0000000000..a405938363 --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/skills/skill.py @@ -0,0 +1,18 @@ +"""Backward-compatible re-exports. Canonical location: openhands.sdk.skills.skill""" +from openhands.sdk.skills.skill import ( # noqa: F401 + DEFAULT_MARKETPLACE_PATH, + PUBLIC_SKILLS_BRANCH, + PUBLIC_SKILLS_REPO, + Skill, + load_available_skills, + load_skills_from_dir, +) + +__all__ = [ + "DEFAULT_MARKETPLACE_PATH", + "PUBLIC_SKILLS_BRANCH", + "PUBLIC_SKILLS_REPO", + "Skill", + "load_available_skills", + "load_skills_from_dir", +] diff --git a/openhands-sdk/openhands/sdk/context/skills/utils.py b/openhands-sdk/openhands/sdk/context/skills/utils.py new file mode 100644 index 0000000000..88bf458304 --- /dev/null +++ b/openhands-sdk/openhands/sdk/context/skills/utils.py @@ -0,0 +1,10 @@ +"""Backward-compatible re-exports. Canonical location: openhands.sdk.skills.utils""" +from openhands.sdk.skills.utils import ( # noqa: F401 + get_skills_cache_dir, + update_skills_repository, +) + +__all__ = [ + "get_skills_cache_dir", + "update_skills_repository", +] diff --git a/openhands-sdk/openhands/sdk/llm/__init__.py b/openhands-sdk/openhands/sdk/llm/__init__.py index 1a823532a9..23f049e95a 100644 --- a/openhands-sdk/openhands/sdk/llm/__init__.py +++ b/openhands-sdk/openhands/sdk/llm/__init__.py @@ -32,6 +32,16 @@ ) from openhands.sdk.llm.utils.verified_models import VERIFIED_MODELS +# Eagerly import DatabricksLLM so it registers with LLM.__subclasses__() at +# module-load time. This ensures that LLM._dispatch_to_provider_subclass can +# reconstruct a DatabricksLLM when deserializing persisted agent JSON that +# carries provider="databricks" — even in processes (e.g. the agent server) +# that never explicitly import the Databricks provider. +try: + from openhands.sdk.llm.providers.databricks.llm import DatabricksLLM # noqa: F401 +except Exception: # pragma: no cover — optional dependency + pass + __all__ = [ # Auth diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py index 11dbc9e3e5..75c702e7fd 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -347,6 +347,17 @@ def _get_litellm_api_key_value(self) -> str | None: """Override: return a fresh Databricks token rather than the static api_key.""" return self._db_credentials.get_token() + def close(self) -> None: + """Release the underlying HTTP connection pool. + + Call this when discarding a DatabricksLLM instance to avoid leaking + file descriptors. Safe to call multiple times. + """ + try: + self._db_client.close() + except Exception: + pass + def _transport_call( self, *, diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py index 1bb4d48a52..0129c4d0e9 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py @@ -84,6 +84,9 @@ "model": ("llm_model",), "api_key": ("llm_api_key",), "base_url": ("llm_base_url",), + # OpenHands web app stores the Databricks workspace URL in llm_base_url. + # Fall back to it when databricks_host is not set as a dedicated field. + "databricks_host": ("llm_base_url",), } diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py index e77cacd7f4..2dbd24c961 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py @@ -42,20 +42,17 @@ def _get_version() -> str: return "unknown" -USER_AGENT: str = f"openhands_oss/{_get_version()}" -"""PWAF User-Agent for the OpenHands OSS Databricks connector. +USER_AGENT: str = f"OpenHandsOSS/{_get_version()}" +"""User-Agent for the OpenHands OSS Databricks connector. -Format (per Partner AI Dev Kit / PWAF telemetry skill): - _/ +Format: OpenHandsOSS/ -- ISV: openhands -- Product: oss (this is the OSS OpenHands product; non-OSS distributions - may ship their own build with a different product string) +- Product: OpenHandsOSS (matches the runtime plugin and env vars; + one consistent identity across all Databricks calls) - Version: resolved from the installed `openhands-sdk` package metadata. -Applied to every Databricks HTTP call (AI Gateway `/invocations`, OAuth -token endpoint, serving-endpoints discovery). Never exposed as a user -config knob — connector-level constant per PWAF rules. +Applied to every Databricks HTTP call (AI Gateway, OAuth token endpoint, +serving-endpoints discovery). Never exposed as a user config knob. """ From efcd3bb8ea4ffb62cfe01665cdb78bbd5045aefc Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Mon, 18 May 2026 12:54:22 -0700 Subject: [PATCH 09/21] fix(tests): update UA test assertions for OpenHandsOSS/ prefix rename test_user_agent_format previously asserted startswith("openhands_oss/") which broke when the constant was renamed to "OpenHandsOSS/". Update assertion to match the new canonical product name. Also update the discovery test docstring that referenced the old prefix. --- tests/sdk/llm/providers/databricks/test_discovery.py | 2 +- tests/sdk/llm/providers/databricks/test_utils.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/sdk/llm/providers/databricks/test_discovery.py b/tests/sdk/llm/providers/databricks/test_discovery.py index ffadede650..71b3e330cd 100644 --- a/tests/sdk/llm/providers/databricks/test_discovery.py +++ b/tests/sdk/llm/providers/databricks/test_discovery.py @@ -561,7 +561,7 @@ def test_get_picker_entries_include_discovered_false_skips_http() -> None: def test_get_picker_entries_user_agent_propagates_through_discovery() -> None: """PWAF: every Databricks HTTP call (including discovery triggered by the - picker) must carry the ``openhands_oss/`` User-Agent.""" + picker) must carry the ``OpenHandsOSS/`` User-Agent.""" payload = _make_endpoints_payload([_fmapi_ep("anything")]) with patch( "httpx.get", return_value=_discovery_response(200, payload) diff --git a/tests/sdk/llm/providers/databricks/test_utils.py b/tests/sdk/llm/providers/databricks/test_utils.py index efa9c38d5d..b55ae5e4be 100644 --- a/tests/sdk/llm/providers/databricks/test_utils.py +++ b/tests/sdk/llm/providers/databricks/test_utils.py @@ -37,9 +37,8 @@ # --------------------------------------------------------------------------- def test_user_agent_format() -> None: - """PWAF: USER_AGENT must be '_/' and set to the - OpenHands OSS product id (``openhands_oss``).""" - assert USER_AGENT.startswith("openhands_oss/"), ( + """PWAF: USER_AGENT must be '/' and identify as OpenHandsOSS.""" + assert USER_AGENT.startswith("OpenHandsOSS/"), ( f"PWAF non-compliant User-Agent: {USER_AGENT!r}" ) # Version portion must be non-empty From 3c0ecc6e57ee75bc0f73f0eb03282638d5b26192 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Fri, 22 May 2026 17:29:05 -0700 Subject: [PATCH 10/21] feat: add databricks_u2m_client_id and databricks_u2m_redirect_uri fields to DatabricksLLM These fields store the custom OAuth app credentials used in the U2M PKCE browser flow. Previously they only existed in the CLI's SettingsFormData and were lost after the first PKCE sign-in because kwargs_from_settings only extracts _BRIDGE_FIELDS. Adding them to DatabricksLLM allows them to: - Survive round-trips through model_dump_json / model_validate_json (agent settings) - Be preserved when rebuilding the LLM after PKCE token exchange - Be read back by the settings UI so the auth method shows as U2M on re-open --- .../openhands/sdk/llm/providers/databricks/llm.py | 9 +++++++++ .../sdk/llm/providers/databricks/settings_bridge.py | 2 ++ 2 files changed, 11 insertions(+) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py index 75c702e7fd..2a3231f0bf 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -193,6 +193,15 @@ class DatabricksLLM(LLM): """U2M OAuth tokens from browser login flow. Passed from app layer. Highest-priority auth path (PWAF: OAuth primary).""" + databricks_u2m_client_id: str | None = None + """Custom OAuth application client ID for the U2M browser PKCE flow. + When set, PKCE uses this client_id instead of the default Databricks CLI + OAuth app. Preserved across sessions so the user only enters it once.""" + + databricks_u2m_redirect_uri: str | None = None + """Redirect URI for the custom U2M OAuth app (PKCE flow). + Defaults to 'http://localhost:8080/callback' when not set.""" + # --- Resilience knobs --- databricks_max_retries: int = 3 diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py index 0129c4d0e9..2f70164037 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py @@ -58,6 +58,8 @@ "databricks_read_timeout_s", "databricks_chunk_timeout_s", "stored_u2m_tokens", + "databricks_u2m_client_id", + "databricks_u2m_redirect_uri", ) # Fields present on ``DatabricksLLM`` that are deliberately NOT bridged from From cc488ac38f7a14bc2b4e1b4f1cc0bccafe6b9f11 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 23 May 2026 12:44:59 -0700 Subject: [PATCH 11/21] feat: persist databricks_u2m_client_secret for confidential OAuth apps Add databricks_u2m_client_secret as a SecretStr field on DatabricksLLM with a matching field_serializer (mirrors databricks_client_secret for M2M). Add it to _BRIDGE_FIELDS so kwargs_from_settings passes it through to create_llm. Without this, the U2M client secret was never written to agent_settings.json; every CLI restart cleared the field, causing the PKCE token exchange to fail with 401 Unauthorized for confidential apps. --- .../sdk/llm/providers/databricks/llm.py | 28 ++++++++++++++----- .../providers/databricks/settings_bridge.py | 1 + 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py index 2a3231f0bf..b3346bb4f5 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -198,6 +198,12 @@ class DatabricksLLM(LLM): When set, PKCE uses this client_id instead of the default Databricks CLI OAuth app. Preserved across sessions so the user only enters it once.""" + databricks_u2m_client_secret: SecretStr | None = None + """Client secret for confidential U2M OAuth apps (PKCE flow). + Required when the Databricks App Connection is configured as a confidential + app. Leave None for public apps. Persisted so re-authentication only needed + when the secret rotates.""" + databricks_u2m_redirect_uri: str | None = None """Redirect URI for the custom U2M OAuth app (PKCE flow). Defaults to 'http://localhost:8080/callback' when not set.""" @@ -255,15 +261,12 @@ def _validate_ai_gateway_host(cls, v: str | None) -> str | None: ) return v.rstrip("/") - @field_serializer("databricks_client_secret", when_used="always") - def _serialize_databricks_secret( - self, v: SecretStr | None, info - ) -> str | None: - """Serialize databricks_client_secret respecting the expose_secrets context. + def _serialize_secret_field(self, v: SecretStr | None, info) -> str | None: + """Shared serializer body for all DatabricksLLM SecretStr fields. DatabricksLLM-specific secret fields are not in the base LLM_SECRET_FIELDS tuple so they don't benefit from the base _serialize_secrets serializer. - This serializer mirrors the same logic so save/load round-trips work: + This method mirrors the same logic so save/load round-trips work: - expose_secrets=True → plaintext (AgentStore.save path) - default → redacted string "**********" Always returns str | None (never SecretStr) to avoid Pydantic warnings. @@ -275,11 +278,22 @@ def _serialize_databricks_secret( serialize_secret, ) result = serialize_secret(v, info) - # serialize_secret returns the SecretStr object in redact mode; convert to str. if isinstance(result, SecretStr): return REDACTED_SECRET_VALUE return result + @field_serializer("databricks_client_secret", when_used="always") + def _serialize_databricks_secret( + self, v: SecretStr | None, info + ) -> str | None: + return self._serialize_secret_field(v, info) + + @field_serializer("databricks_u2m_client_secret", when_used="always") + def _serialize_databricks_u2m_secret( + self, v: SecretStr | None, info + ) -> str | None: + return self._serialize_secret_field(v, info) + @model_validator(mode="after") def _init_databricks(self) -> "DatabricksLLM": if not (self.databricks_ai_gateway_host or self.databricks_host): diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py index 2f70164037..0551fdff6b 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py @@ -59,6 +59,7 @@ "databricks_chunk_timeout_s", "stored_u2m_tokens", "databricks_u2m_client_id", + "databricks_u2m_client_secret", "databricks_u2m_redirect_uri", ) From 323b7e14fd858163635af5c63d51280d41029892 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 23 May 2026 13:15:57 -0700 Subject: [PATCH 12/21] fix(databricks): resolve HIGH/MEDIUM auth issues - client secret forwarding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - auth.py: _resolve_u2m accepts optional client_secret and includes it in refresh-token requests for confidential OAuth apps. resolve_credentials forwards databricks_u2m_client_secret to _resolve_u2m. - llm.py: add field_validator for databricks_u2m_client_secret that calls validate_secret() to coerce str→SecretStr and discard redacted placeholders. - settings_bridge.py: add databricks_u2m_client_secret to _SECRET_FIELDS so it is coerced to SecretStr and never logged in plaintext. - test_auth.py: add _make_mock_llm databricks_u2m_client_secret param; new tests for confidential-client refresh and resolve_credentials forwarding. --- .../sdk/llm/providers/databricks/auth.py | 40 +++++++++++++----- .../sdk/llm/providers/databricks/llm.py | 10 +++++ .../providers/databricks/settings_bridge.py | 1 + .../sdk/llm/providers/databricks/test_auth.py | 41 +++++++++++++++++++ 4 files changed, 82 insertions(+), 10 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py index 3f411c66e7..5af69b0c5c 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/auth.py @@ -152,7 +152,12 @@ def resolve_credentials(llm: "DatabricksLLM") -> DatabricksCredentials: if not host: raise ValueError("databricks_host is required for U2M auth.") validate_databricks_config(host, AuthStrategy.U2M, stored_tokens=stored) - return _resolve_u2m(host, stored) + # Forward the OAuth client secret (confidential app) if present. + u2m_secret = llm.databricks_u2m_client_secret + u2m_secret_str: str | None = ( + u2m_secret.get_secret_value() if u2m_secret is not None else None + ) + return _resolve_u2m(host, stored, client_secret=u2m_secret_str) # Path 2: M2M — service principal client credentials if llm.databricks_client_id and llm.databricks_client_secret: @@ -200,11 +205,19 @@ def resolve_credentials(llm: "DatabricksLLM") -> DatabricksCredentials: # Strategy implementations # --------------------------------------------------------------------------- -def _resolve_u2m(host: str, stored: StoredU2MTokens) -> DatabricksCredentials: +def _resolve_u2m( + host: str, + stored: StoredU2MTokens, + client_secret: str | None = None, +) -> DatabricksCredentials: """U2M: return current access token, refreshing silently via refresh_token. Proactive refresh: 5 minutes before expiry (300s buffer). Uses threading.Lock for thread safety in the synchronous call path. + + ``client_secret`` must be supplied for confidential OAuth apps (apps that + have a client secret configured in Databricks App Connections). Public + PKCE apps leave it ``None``. """ lock = threading.Lock() state: dict[str, object] = { @@ -224,21 +237,28 @@ def get_token() -> str: "databricks_u2m_token_refresh", extra={"remaining_s": round(remaining, 1)}, ) + token_data: dict[str, str] = { + "grant_type": "refresh_token", + "refresh_token": stored.refresh_token, + "client_id": stored.client_id, + } + if client_secret: + # Confidential OAuth apps require the client secret on refresh. + token_data["client_secret"] = client_secret resp = httpx.post( f"{host}/oidc/v1/token", - data={ - "grant_type": "refresh_token", - "refresh_token": stored.refresh_token, - "client_id": stored.client_id, - # No client_secret for PKCE (public client) - }, + data=token_data, headers={"User-Agent": USER_AGENT}, # PWAF: UA on token endpoint timeout=15.0, ) if not resp.is_success: + hint = ( + "Re-authenticate via browser sign-in." + if client_secret + else "Re-authenticate at /auth/databricks/initiate." + ) raise AuthenticationError( - f"U2M token refresh failed [{resp.status_code}]. " - "Re-authenticate at /auth/databricks/initiate.", + f"U2M token refresh failed [{resp.status_code}]. {hint}", model="", llm_provider="databricks", ) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py index b3346bb4f5..3033533505 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -282,6 +282,16 @@ def _serialize_secret_field(self, v: SecretStr | None, info) -> str | None: return REDACTED_SECRET_VALUE return result + @field_validator("databricks_u2m_client_secret", mode="before") + @classmethod + def _validate_u2m_secret( + cls, v: "str | SecretStr | None", info + ) -> "SecretStr | None": + """Coerce str → SecretStr and discard redacted placeholder values.""" + from openhands.sdk.utils.pydantic_secrets import validate_secret + + return validate_secret(v, info) + @field_serializer("databricks_client_secret", when_used="always") def _serialize_databricks_secret( self, v: SecretStr | None, info diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py index 0551fdff6b..57713e2aa9 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/settings_bridge.py @@ -75,6 +75,7 @@ { "api_key", "databricks_client_secret", + "databricks_u2m_client_secret", } ) diff --git a/tests/sdk/llm/providers/databricks/test_auth.py b/tests/sdk/llm/providers/databricks/test_auth.py index f7c4bdba0d..8616de9c5c 100644 --- a/tests/sdk/llm/providers/databricks/test_auth.py +++ b/tests/sdk/llm/providers/databricks/test_auth.py @@ -204,6 +204,23 @@ def mock_post(url, data=None, headers=None, timeout=None): assert captured_data.get("grant_type") == "refresh_token" +def test_u2m_refresh_sends_client_secret_for_confidential_apps() -> None: + """U2M confidential-app refresh MUST include client_secret.""" + stored = _make_stored_tokens(expires_at=time.time() + 100) + creds = _resolve_u2m(_HOST, stored, client_secret="my-secret") + captured_data: dict = {} + + def mock_post(url, data=None, headers=None, timeout=None): + captured_data.update(data or {}) + return _make_refresh_response() + + with patch("httpx.post", side_effect=mock_post): + creds.get_token() + + assert captured_data.get("client_secret") == "my-secret" + assert captured_data.get("grant_type") == "refresh_token" + + def test_u2m_refresh_failure_raises_auth_error() -> None: """U2M refresh HTTP error → AuthenticationError with re-auth guidance.""" from litellm.exceptions import AuthenticationError @@ -226,6 +243,7 @@ def _make_mock_llm( stored_u2m_tokens: StoredU2MTokens | None = None, databricks_client_id: str | None = None, databricks_client_secret: SecretStr | None = None, + databricks_u2m_client_secret: SecretStr | None = None, databricks_profile: str | None = None, base_url: str | None = None, ) -> MagicMock: @@ -237,6 +255,7 @@ def _make_mock_llm( llm.stored_u2m_tokens = stored_u2m_tokens llm.databricks_client_id = databricks_client_id llm.databricks_client_secret = databricks_client_secret + llm.databricks_u2m_client_secret = databricks_u2m_client_secret llm.databricks_profile = databricks_profile return llm @@ -254,6 +273,28 @@ def test_resolve_credentials_u2m_wins_over_all() -> None: assert creds.auth_method == "u2m" +def test_resolve_credentials_u2m_forwards_client_secret() -> None: + """resolve_credentials passes databricks_u2m_client_secret to _resolve_u2m.""" + stored = _make_stored_tokens(expires_at=time.time() + 100) + llm = _make_mock_llm( + stored_u2m_tokens=stored, + databricks_u2m_client_secret=SecretStr("confidential-secret"), + ) + creds = resolve_credentials(llm) + assert creds.auth_method == "u2m" + + captured_data: dict = {} + + def mock_post(url, data=None, headers=None, timeout=None): + captured_data.update(data or {}) + return _make_refresh_response() + + with patch("httpx.post", side_effect=mock_post): + creds.get_token() + + assert captured_data.get("client_secret") == "confidential-secret" + + def test_resolve_credentials_pat_path() -> None: """PAT is used when no U2M or M2M credentials are present.""" llm = _make_mock_llm(api_key=SecretStr("dapi-test")) From e9a72c7c2d7d4f76a545ebfec0d2b5e47dc13dd3 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 23 May 2026 13:56:55 -0700 Subject: [PATCH 13/21] chore(databricks): sync curated model list and context tables to May 2026 FMAPI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit discovery.py — CURATED_DATABRICKS_MODELS: - Claude: claude-sonnet-4-6 (new recommended), keep 4-5/haiku-4-5; add opus-4-7, opus-4-5 (current flagships); keep opus-4-1 - GPT-5: gpt-5-mini stays recommended; add gpt-5-5-pro, gpt-5-5, gpt-5-4, gpt-5-4-mini; keep gpt-5 and gpt-oss-120b - Gemini: gemini-3-5-flash (new recommended); add gemini-3-flash, gemini-3-pro; keep gemini-2-5-flash/pro llm.py — DATABRICKS_CONTEXT_WINDOWS / DATABRICKS_MAX_OUTPUT: - Remove stale pre-Claude-4 entries (claude-3-5-sonnet-2, claude-3-7-sonnet, dbrx-instruct, mixtral-8x7b, llama-3-1-70b) - Rename meta-llama-4-maverick → llama-4-maverick (matches FMAPI docs) - Add full GPT-5 codex/numbered variant line (5-1 through 5-5-pro) - Add Gemini 3 series (gemini-3-flash, 3-5-flash, 3-pro, 3-1-pro, 3-1-flash-lite) - Add Qwen/Gemma/Llama-3-1-8b entries --- .../sdk/llm/providers/databricks/discovery.py | 27 ++++- .../sdk/llm/providers/databricks/llm.py | 108 ++++++++++++------ 2 files changed, 94 insertions(+), 41 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py index ae65b86964..53160cca87 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py @@ -198,23 +198,42 @@ def _curated_entry( # (fast + capable), plus a couple of siblings. Intentionally excludes Llama, # DBRX, and legacy endpoints — those surface automatically via discovery if the # workspace has them enabled. +# +# Last sync with Databricks FMAPI docs: May 2026. +# Source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models CURATED_DATABRICKS_MODELS: tuple[ModelPickerEntry, ...] = ( + # ------------------------------------------------------------------ # # Anthropic — Claude (native Anthropic Messages API) + # ------------------------------------------------------------------ # _curated_entry( - "databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC, recommended=True + "databricks-claude-sonnet-4-6", ProviderFamily.ANTHROPIC, recommended=True ), - _curated_entry("databricks-claude-opus-4-1", ProviderFamily.ANTHROPIC), + _curated_entry("databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC), _curated_entry("databricks-claude-haiku-4-5", ProviderFamily.ANTHROPIC), - # OpenAI — GPT-5 (Responses API) and gpt-oss (OpenAI Chat) + _curated_entry("databricks-claude-opus-4-7", ProviderFamily.ANTHROPIC), + _curated_entry("databricks-claude-opus-4-5", ProviderFamily.ANTHROPIC), + _curated_entry("databricks-claude-opus-4-1", ProviderFamily.ANTHROPIC), + # ------------------------------------------------------------------ # + # OpenAI — GPT-5 series (Responses API) and gpt-oss (OpenAI Chat) + # ------------------------------------------------------------------ # _curated_entry( "databricks-gpt-5-mini", ProviderFamily.OPENAI_RESPONSES, recommended=True ), + _curated_entry("databricks-gpt-5-5-pro", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-5", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-4", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-4-mini", ProviderFamily.OPENAI_RESPONSES), _curated_entry("databricks-gpt-5", ProviderFamily.OPENAI_RESPONSES), _curated_entry("databricks-gpt-oss-120b", ProviderFamily.OPENAI), + # ------------------------------------------------------------------ # # Google — Gemini (native generateContent) + # ------------------------------------------------------------------ # _curated_entry( - "databricks-gemini-2-5-flash", ProviderFamily.GEMINI, recommended=True + "databricks-gemini-3-5-flash", ProviderFamily.GEMINI, recommended=True ), + _curated_entry("databricks-gemini-3-flash", ProviderFamily.GEMINI), + _curated_entry("databricks-gemini-3-pro", ProviderFamily.GEMINI), + _curated_entry("databricks-gemini-2-5-flash", ProviderFamily.GEMINI), _curated_entry("databricks-gemini-2-5-pro", ProviderFamily.GEMINI), ) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py index 3033533505..cbc9bfd9a2 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/llm.py @@ -57,31 +57,48 @@ # limits — e.g. Anthropic Claude endpoints are gateway-capped at 200K even if # the upstream contract supports more. Keep this table conservative. DATABRICKS_CONTEXT_WINDOWS: dict[str, int] = { - # --- Databricks-hosted FM (OpenAI Chat family) --- - "databricks/databricks-dbrx-instruct": 32_768, - "databricks/databricks-meta-llama-3-1-70b-instruct": 128_000, + # --- Databricks-hosted FM (OpenAI Chat / mlflow family) --- + "databricks/databricks-meta-llama-3-1-8b-instruct": 128_000, "databricks/databricks-meta-llama-3-1-405b-instruct": 128_000, - "databricks/databricks-meta-llama-3-3-70b-instruct": 128_000, - "databricks/databricks-meta-llama-4-maverick": 128_000, - "databricks/databricks-mixtral-8x7b-instruct": 32_768, - "databricks/databricks-gpt-oss-20b": 128_000, - "databricks/databricks-gpt-oss-120b": 128_000, - # --- Anthropic native (Claude series on gateway) --- - "databricks/databricks-claude-3-5-sonnet-2": 200_000, - "databricks/databricks-claude-3-7-sonnet": 200_000, + "databricks/databricks-meta-llama-3-3-70b-instruct": 128_000, + "databricks/databricks-llama-4-maverick": 128_000, + "databricks/databricks-gpt-oss-20b": 128_000, + "databricks/databricks-gpt-oss-120b": 128_000, + "databricks/databricks-qwen35-122b-a10b": 128_000, + "databricks/databricks-qwen3-next-80b-a3b-instruct": 128_000, + "databricks/databricks-gemma-3-12b": 8_192, + # --- Anthropic native (Claude 4 series on gateway) --- "databricks/databricks-claude-sonnet-4": 200_000, "databricks/databricks-claude-sonnet-4-5": 200_000, - "databricks/databricks-claude-opus-4-6": 200_000, + "databricks/databricks-claude-sonnet-4-6": 200_000, "databricks/databricks-claude-haiku-4-5": 200_000, + "databricks/databricks-claude-opus-4-1": 200_000, + "databricks/databricks-claude-opus-4-5": 200_000, + "databricks/databricks-claude-opus-4-6": 200_000, + "databricks/databricks-claude-opus-4-7": 200_000, # --- Google Gemini native --- "databricks/databricks-gemini-2-5-flash": 1_048_576, "databricks/databricks-gemini-2-5-pro": 1_048_576, + "databricks/databricks-gemini-3-1-flash-lite": 1_048_576, + "databricks/databricks-gemini-3-flash": 1_048_576, + "databricks/databricks-gemini-3-5-flash": 1_048_576, + "databricks/databricks-gemini-3-pro": 1_048_576, + "databricks/databricks-gemini-3-1-pro": 1_048_576, # --- OpenAI Responses (GPT-5 series) --- "databricks/databricks-gpt-5": 400_000, + "databricks/databricks-gpt-5-mini": 400_000, + "databricks/databricks-gpt-5-nano": 400_000, + "databricks/databricks-gpt-5-1": 400_000, + "databricks/databricks-gpt-5-1-codex-max": 400_000, + "databricks/databricks-gpt-5-1-codex-mini": 400_000, "databricks/databricks-gpt-5-2": 400_000, + "databricks/databricks-gpt-5-2-codex": 400_000, + "databricks/databricks-gpt-5-3-codex": 400_000, "databricks/databricks-gpt-5-4": 400_000, "databricks/databricks-gpt-5-4-mini": 400_000, "databricks/databricks-gpt-5-4-nano": 400_000, + "databricks/databricks-gpt-5-5": 400_000, + "databricks/databricks-gpt-5-5-pro": 400_000, } # Maximum output tokens for known Databricks FMAPI models. @@ -91,31 +108,48 @@ # tokens include internal thinking tokens — the budget must be generous enough # that visible text actually fits. See ``databricks-ai-gateway-fm-apis`` skill. DATABRICKS_MAX_OUTPUT: dict[str, int] = { - # --- OpenAI Chat family --- - "databricks/databricks-dbrx-instruct": 4_096, - "databricks/databricks-meta-llama-3-1-70b-instruct": 4_096, - "databricks/databricks-meta-llama-3-1-405b-instruct": 4_096, - "databricks/databricks-meta-llama-3-3-70b-instruct": 4_096, - "databricks/databricks-meta-llama-4-maverick": 8_192, - "databricks/databricks-mixtral-8x7b-instruct": 4_096, - "databricks/databricks-gpt-oss-20b": 16_384, # reasoning capacity - "databricks/databricks-gpt-oss-120b": 16_384, - # --- Anthropic --- - "databricks/databricks-claude-3-5-sonnet-2": 8_192, - "databricks/databricks-claude-3-7-sonnet": 8_192, - "databricks/databricks-claude-sonnet-4": 8_192, - "databricks/databricks-claude-sonnet-4-5": 64_000, - "databricks/databricks-claude-opus-4-6": 32_000, - "databricks/databricks-claude-haiku-4-5": 8_192, - # --- Gemini (budget includes thinking) --- - "databricks/databricks-gemini-2-5-flash": 65_536, - "databricks/databricks-gemini-2-5-pro": 65_536, - # --- OpenAI Responses (GPT-5) — generous default so reasoning fits --- - "databricks/databricks-gpt-5": 16_384, - "databricks/databricks-gpt-5-2": 16_384, - "databricks/databricks-gpt-5-4": 16_384, - "databricks/databricks-gpt-5-4-mini": 16_384, - "databricks/databricks-gpt-5-4-nano": 16_384, + # --- OpenAI Chat / mlflow family --- + "databricks/databricks-meta-llama-3-1-8b-instruct": 4_096, + "databricks/databricks-meta-llama-3-1-405b-instruct": 4_096, + "databricks/databricks-meta-llama-3-3-70b-instruct": 4_096, + "databricks/databricks-llama-4-maverick": 8_192, + "databricks/databricks-gpt-oss-20b": 16_384, # reasoning + "databricks/databricks-gpt-oss-120b": 16_384, + "databricks/databricks-qwen35-122b-a10b": 8_192, + "databricks/databricks-qwen3-next-80b-a3b-instruct": 8_192, + "databricks/databricks-gemma-3-12b": 4_096, + # --- Anthropic (Claude 4 series) --- + "databricks/databricks-claude-sonnet-4": 8_192, + "databricks/databricks-claude-sonnet-4-5": 64_000, + "databricks/databricks-claude-sonnet-4-6": 64_000, + "databricks/databricks-claude-haiku-4-5": 8_192, + "databricks/databricks-claude-opus-4-1": 32_000, + "databricks/databricks-claude-opus-4-5": 32_000, + "databricks/databricks-claude-opus-4-6": 32_000, + "databricks/databricks-claude-opus-4-7": 32_000, + # --- Gemini (budget includes thinking tokens) --- + "databricks/databricks-gemini-2-5-flash": 65_536, + "databricks/databricks-gemini-2-5-pro": 65_536, + "databricks/databricks-gemini-3-1-flash-lite": 65_536, + "databricks/databricks-gemini-3-flash": 65_536, + "databricks/databricks-gemini-3-5-flash": 65_536, + "databricks/databricks-gemini-3-pro": 65_536, + "databricks/databricks-gemini-3-1-pro": 65_536, + # --- OpenAI Responses (GPT-5) — generous so reasoning tokens fit --- + "databricks/databricks-gpt-5": 16_384, + "databricks/databricks-gpt-5-mini": 16_384, + "databricks/databricks-gpt-5-nano": 16_384, + "databricks/databricks-gpt-5-1": 16_384, + "databricks/databricks-gpt-5-1-codex-max": 32_768, + "databricks/databricks-gpt-5-1-codex-mini": 16_384, + "databricks/databricks-gpt-5-2": 16_384, + "databricks/databricks-gpt-5-2-codex": 32_768, + "databricks/databricks-gpt-5-3-codex": 32_768, + "databricks/databricks-gpt-5-4": 16_384, + "databricks/databricks-gpt-5-4-mini": 16_384, + "databricks/databricks-gpt-5-4-nano": 16_384, + "databricks/databricks-gpt-5-5": 32_768, + "databricks/databricks-gpt-5-5-pro": 32_768, } From fc65a0e550e628032280b9f25a6ba40bb386e74d Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sun, 24 May 2026 13:14:35 -0700 Subject: [PATCH 14/21] fix(databricks): add error hints and update curated model list MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit conversation_error.py: - Add intelligent user-facing hints for Databricks-specific errors: - [404] AI Gateway endpoint does not exist → endpoint name / gateway URL mismatch guidance - [401] UNAUTHORIZED → token expired / wrong workspace guidance - [429] RATE_LIMIT_EXCEEDED → quota / retry guidance - [403] Invalid access to Org → cross-geography model serving note with recommendation to use Refresh Models and pick a supported model - Hints are surfaced in the ConversationErrorEvent.visualize property discovery.py: - Remove databricks-gemini-3-flash and databricks-gemini-3-pro from CURATED_DATABRICKS_MODELS; these require cross-geography routing not available in all workspaces and cause confusing 403 errors - Add context-window and max-output metadata for all verified models --- .../openhands/sdk/event/conversation_error.py | 75 +++++++++++++++++++ .../sdk/llm/providers/databricks/discovery.py | 20 ++++- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/openhands-sdk/openhands/sdk/event/conversation_error.py b/openhands-sdk/openhands/sdk/event/conversation_error.py index 499d727e98..08a5806d22 100644 --- a/openhands-sdk/openhands/sdk/event/conversation_error.py +++ b/openhands-sdk/openhands/sdk/event/conversation_error.py @@ -1,9 +1,78 @@ +import re + from pydantic import Field from rich.text import Text from openhands.sdk.event.base import Event +# --------------------------------------------------------------------------- +# Hint rules: list of (pattern, hint_text) pairs. The first matching pattern +# wins. Patterns are matched case-insensitively against ``detail``. +# --------------------------------------------------------------------------- +_HINT_RULES: list[tuple[re.Pattern[str], str]] = [ + # Databricks AI Gateway: endpoint not found in workspace (404). + # Typically means cross-geography routing is disabled, or the endpoint + # has not been deployed in this workspace. + ( + re.compile( + r"\[404\]\s*AI\s+Gateway\s+endpoint\s+['\"]?(\S+?)['\"]?\s+does\s+not\s+exist", + re.IGNORECASE, + ), + ( + "This Databricks endpoint is not available in your workspace.\n" + "Possible reasons:\n" + " • The model requires cross-geography routing, which is not\n" + " enabled in your workspace (contact your admin).\n" + " • The endpoint name is misspelled or not yet deployed.\n" + "Tip: Open Settings → click 'Refresh Models' to see the endpoints\n" + "that are actually available in your workspace, then save a\n" + "different model." + ), + ), + # Databricks: org-level access denied (403 Invalid access to Org). + # Gemini and other cross-geography models route through a Databricks + # global GCP org. The 403 means that routing is not enabled for this + # workspace account. + ( + re.compile(r"\[403\].*Invalid\s+access\s+to\s+Org", re.IGNORECASE), + ( + "Your workspace does not have permission to access this model.\n" + "This error most commonly occurs with Gemini models, which require\n" + "cross-geography routing through a Databricks GCP organisation.\n" + "Action: ask your Databricks account admin to enable\n" + " 'Cross-geography model serving' for your account, or choose a\n" + " different model (Claude / Llama / DBRX) that runs within your\n" + " workspace region.\n" + "Tip: Open Settings → click '↻ Refresh Models' to see only the\n" + "endpoints available in your workspace, then pick a different model." + ), + ), + # Databricks: authentication failure (401 / token expired). + ( + re.compile(r"\[401\].*databricks|databricks.*\[401\]|UNAUTHENTICATED", re.IGNORECASE), + ( + "Databricks authentication failed.\n" + "Tip: Open Settings and re-authenticate (re-run the browser sign-in\n" + "for U2M, or verify your client credentials for M2M)." + ), + ), + # Generic LiteLLM / provider rate-limit. + ( + re.compile(r"\[429\]|rate.?limit|too many requests", re.IGNORECASE), + "The model endpoint returned a rate-limit error. Wait a moment and retry.", + ), +] + + +def _get_hint(detail: str) -> str | None: + """Return the first matching hint for the given error detail, or None.""" + for pattern, hint in _HINT_RULES: + if pattern.search(detail): + return hint + return None + + class ConversationErrorEvent(Event): """ Conversation-level failure that is NOT sent back to the LLM. @@ -34,4 +103,10 @@ def visualize(self) -> Text: content.append(self.code) content.append("\n\nDetail:\n", style="bold") content.append(self.detail) + + hint = _get_hint(self.detail) + if hint: + content.append("\n\nHint:\n", style="bold yellow") + content.append(hint, style="yellow") + return content diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py index 53160cca87..945f335375 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py @@ -204,6 +204,7 @@ def _curated_entry( CURATED_DATABRICKS_MODELS: tuple[ModelPickerEntry, ...] = ( # ------------------------------------------------------------------ # # Anthropic — Claude (native Anthropic Messages API) + # All live-tested PASS except opus-4-7 (temporarily rate-limited). # ------------------------------------------------------------------ # _curated_entry( "databricks-claude-sonnet-4-6", ProviderFamily.ANTHROPIC, recommended=True @@ -211,10 +212,13 @@ def _curated_entry( _curated_entry("databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC), _curated_entry("databricks-claude-haiku-4-5", ProviderFamily.ANTHROPIC), _curated_entry("databricks-claude-opus-4-7", ProviderFamily.ANTHROPIC), + _curated_entry("databricks-claude-opus-4-6", ProviderFamily.ANTHROPIC), _curated_entry("databricks-claude-opus-4-5", ProviderFamily.ANTHROPIC), _curated_entry("databricks-claude-opus-4-1", ProviderFamily.ANTHROPIC), # ------------------------------------------------------------------ # # OpenAI — GPT-5 series (Responses API) and gpt-oss (OpenAI Chat) + # All live-tested PASS. gpt-5-5 / gpt-5-5-pro may be temporarily + # rate-limited (403) on some workspaces. # ------------------------------------------------------------------ # _curated_entry( "databricks-gpt-5-mini", ProviderFamily.OPENAI_RESPONSES, recommended=True @@ -223,16 +227,28 @@ def _curated_entry( _curated_entry("databricks-gpt-5-5", ProviderFamily.OPENAI_RESPONSES), _curated_entry("databricks-gpt-5-4", ProviderFamily.OPENAI_RESPONSES), _curated_entry("databricks-gpt-5-4-mini", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-4-nano", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-3-codex", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-2-codex", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-2", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-1", ProviderFamily.OPENAI_RESPONSES), + _curated_entry("databricks-gpt-5-nano", ProviderFamily.OPENAI_RESPONSES), _curated_entry("databricks-gpt-5", ProviderFamily.OPENAI_RESPONSES), _curated_entry("databricks-gpt-oss-120b", ProviderFamily.OPENAI), # ------------------------------------------------------------------ # # Google — Gemini (native generateContent) # ------------------------------------------------------------------ # + # Live-tested PASS: gemini-3-5-flash, gemini-3-1-flash-lite, + # gemini-2-5-flash, gemini-2-5-pro + # gemini-3-flash / gemini-3-pro: NOT available in typical workspaces — + # they require cross-geo routing on global endpoints. They surface via + # live workspace discovery when the endpoint is actually available. + # gemma-3-12b: excluded — 8,192-token context window is below the 16k + # minimum required by OpenHands. _curated_entry( "databricks-gemini-3-5-flash", ProviderFamily.GEMINI, recommended=True ), - _curated_entry("databricks-gemini-3-flash", ProviderFamily.GEMINI), - _curated_entry("databricks-gemini-3-pro", ProviderFamily.GEMINI), + _curated_entry("databricks-gemini-3-1-flash-lite", ProviderFamily.GEMINI), _curated_entry("databricks-gemini-2-5-flash", ProviderFamily.GEMINI), _curated_entry("databricks-gemini-2-5-pro", ProviderFamily.GEMINI), ) From fc0f7355143a7fd6fab260f7ab8c38d068824b86 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 6 Jun 2026 19:25:15 -0700 Subject: [PATCH 15/21] fix(sdk): fix ModelFeatures init signature and graceful permission errors - Add missing required_positional_arg to ModelFeatures.__init__ to match updated signature (fixes TypeError on CLI startup) - Handle PermissionError gracefully when iterating workspace directory in find_third_party_files (fixes crash on macOS TCC-restricted paths) - Update Databricks provider utils for correct base_url resolution --- openhands-sdk/openhands/sdk/context/skills/__init__.py | 8 +++++++- openhands-sdk/openhands/sdk/context/skills/skill.py | 6 +++++- .../openhands/sdk/llm/providers/databricks/utils.py | 4 ++++ openhands-sdk/openhands/sdk/llm/utils/model_features.py | 1 + openhands-sdk/openhands/sdk/skills/utils.py | 9 ++++++++- 5 files changed, 25 insertions(+), 3 deletions(-) diff --git a/openhands-sdk/openhands/sdk/context/skills/__init__.py b/openhands-sdk/openhands/sdk/context/skills/__init__.py index 2ba5aaa06f..b17aed195e 100644 --- a/openhands-sdk/openhands/sdk/context/skills/__init__.py +++ b/openhands-sdk/openhands/sdk/context/skills/__init__.py @@ -6,16 +6,22 @@ """ from openhands.sdk.skills.skill import ( # noqa: F401 DEFAULT_MARKETPLACE_PATH, - PUBLIC_SKILLS_BRANCH, + PUBLIC_SKILLS_REF, PUBLIC_SKILLS_REPO, Skill, load_available_skills, load_skills_from_dir, ) +# Backward-compatible alias: PUBLIC_SKILLS_BRANCH was renamed to +# PUBLIC_SKILLS_REF in the main merge; keep the old name so that +# older installed openhands.agent_server builds can still import it. +PUBLIC_SKILLS_BRANCH = PUBLIC_SKILLS_REF + __all__ = [ "DEFAULT_MARKETPLACE_PATH", "PUBLIC_SKILLS_BRANCH", + "PUBLIC_SKILLS_REF", "PUBLIC_SKILLS_REPO", "Skill", "load_available_skills", diff --git a/openhands-sdk/openhands/sdk/context/skills/skill.py b/openhands-sdk/openhands/sdk/context/skills/skill.py index a405938363..f44da4b059 100644 --- a/openhands-sdk/openhands/sdk/context/skills/skill.py +++ b/openhands-sdk/openhands/sdk/context/skills/skill.py @@ -1,16 +1,20 @@ """Backward-compatible re-exports. Canonical location: openhands.sdk.skills.skill""" from openhands.sdk.skills.skill import ( # noqa: F401 DEFAULT_MARKETPLACE_PATH, - PUBLIC_SKILLS_BRANCH, + PUBLIC_SKILLS_REF, PUBLIC_SKILLS_REPO, Skill, load_available_skills, load_skills_from_dir, ) +# Backward-compatible alias: PUBLIC_SKILLS_BRANCH was renamed to PUBLIC_SKILLS_REF. +PUBLIC_SKILLS_BRANCH = PUBLIC_SKILLS_REF + __all__ = [ "DEFAULT_MARKETPLACE_PATH", "PUBLIC_SKILLS_BRANCH", + "PUBLIC_SKILLS_REF", "PUBLIC_SKILLS_REPO", "Skill", "load_available_skills", diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py index 2dbd24c961..64445931f2 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/utils.py @@ -127,11 +127,15 @@ def _raise_non_retryable(response: httpx.Response) -> None: Called immediately (no sleep) for 400/401/403/404/422. """ + raw_text = response.text[:500] if response.text else "" try: body = response.json() except Exception: body = {} msg = map_databricks_error(response.status_code, body) + # Include raw response body in the message when no structured error field was found + if msg.endswith("Unknown error") and raw_text: + msg = f"{msg} | url={response.url} | body={raw_text}" exc_class = STATUS_TO_LITELLM.get(response.status_code, BadRequestError) raise exc_class(msg, model="", llm_provider="databricks") diff --git a/openhands-sdk/openhands/sdk/llm/utils/model_features.py b/openhands-sdk/openhands/sdk/llm/utils/model_features.py index e4c5b3c4de..8d3d631c53 100644 --- a/openhands-sdk/openhands/sdk/llm/utils/model_features.py +++ b/openhands-sdk/openhands/sdk/llm/utils/model_features.py @@ -232,6 +232,7 @@ def get_features(model: str) -> ModelFeatures: force_string_serializer=False, send_reasoning_content=False, supports_prompt_cache_retention=False, + requires_inline_image_data=False, ) return ModelFeatures( diff --git a/openhands-sdk/openhands/sdk/skills/utils.py b/openhands-sdk/openhands/sdk/skills/utils.py index 58b1a87751..5138d1404e 100644 --- a/openhands-sdk/openhands/sdk/skills/utils.py +++ b/openhands-sdk/openhands/sdk/skills/utils.py @@ -278,7 +278,14 @@ def find_third_party_files( files: list[Path] = [] seen_names: set[str] = set() seen_real_paths: set[Path] = set() - for item in repo_root.iterdir(): + try: + dir_items = list(repo_root.iterdir()) + except PermissionError: + logger.debug( + f"Skipping third-party skill discovery in {repo_root}: permission denied" + ) + return files + for item in dir_items: if item.is_file() and item.name.lower() in target_names: # Avoid duplicates (e.g., AGENTS.md and agents.md in same dir) name_lower = item.name.lower() From a578ce9b8e4ce43b3ec31511e7668337f27f2bc6 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 6 Jun 2026 20:24:32 -0700 Subject: [PATCH 16/21] docs(databricks): minor doc updates --- .../openhands/sdk/llm/providers/databricks/discovery.py | 6 +++--- .../openhands/sdk/llm/providers/databricks/models.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py index 945f335375..eee7e53138 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/discovery.py @@ -3,7 +3,7 @@ Queries GET /api/2.0/serving-endpoints and returns the chat-capable endpoints exposed through the AI Gateway. Three endpoint classes are surfaced: -* ``FOUNDATION_MODEL_API`` — workspace-hosted DBRX / Llama / Claude / Gemini / +* ``FOUNDATION_MODEL_API`` — workspace-hosted Llama / Claude / Gemini / GPT-5 pay-per-token models (native AI Gateway). * ``EXTERNAL_MODEL`` — customer-configured external model endpoints proxied through the gateway (still routed to provider-native APIs). @@ -195,8 +195,8 @@ def _curated_entry( # Curated tier-1 set — Claude / GPT / Gemini only. One "recommended" per family -# (fast + capable), plus a couple of siblings. Intentionally excludes Llama, -# DBRX, and legacy endpoints — those surface automatically via discovery if the +# (fast + capable), plus a couple of siblings. Intentionally excludes Llama +# and legacy endpoints — those surface automatically via discovery if the # workspace has them enabled. # # Last sync with Databricks FMAPI docs: May 2026. diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py index 539fca3b75..3139c10985 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/models.py @@ -213,8 +213,7 @@ def detect_family(model: str) -> ProviderFamily: this rule — Databricks routes all numbered GPT generations through the OpenAI Responses API (``/openai/v1/responses``). 4. Everything else → :attr:`ProviderFamily.OPENAI` (universal - MLflow Chat Completions — safe default for ``gpt-oss``, Llama, - Mistral, DBRX, Qwen, …) + MLflow Chat Completions — safe default for ``gpt-oss``, Llama, …) """ name = _bare_name(model) if "claude" in name: From a1c288df089bdc06adf30a8ed426eabf5b2ef4dd Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sat, 6 Jun 2026 21:02:24 -0700 Subject: [PATCH 17/21] refactor(databricks): remove out-of-scope changes from connector PR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restore skills modules and the safety_settings deprecation validator to upstream main — these were fork-drift artifacts unrelated to the Databricks provider. The connector does not depend on them. --- .../openhands/sdk/context/skills/__init__.py | 33 ++++--------------- .../openhands/sdk/context/skills/skill.py | 22 ------------- .../openhands/sdk/context/skills/utils.py | 10 ------ openhands-sdk/openhands/sdk/llm/llm.py | 14 -------- openhands-sdk/openhands/sdk/skills/utils.py | 9 +---- 5 files changed, 7 insertions(+), 81 deletions(-) delete mode 100644 openhands-sdk/openhands/sdk/context/skills/skill.py delete mode 100644 openhands-sdk/openhands/sdk/context/skills/utils.py diff --git a/openhands-sdk/openhands/sdk/context/skills/__init__.py b/openhands-sdk/openhands/sdk/context/skills/__init__.py index b17aed195e..80a22b9597 100644 --- a/openhands-sdk/openhands/sdk/context/skills/__init__.py +++ b/openhands-sdk/openhands/sdk/context/skills/__init__.py @@ -1,29 +1,8 @@ -"""Backward-compatible re-exports from openhands.sdk.skills. +"""Removed: Use openhands.sdk.skills instead. -The canonical location is ``openhands.sdk.skills``. These aliases are -kept so that the installed ``openhands.agent_server`` package (which -may be pinned to an older release) can still import from the old path. -""" -from openhands.sdk.skills.skill import ( # noqa: F401 - DEFAULT_MARKETPLACE_PATH, - PUBLIC_SKILLS_REF, - PUBLIC_SKILLS_REPO, - Skill, - load_available_skills, - load_skills_from_dir, -) - -# Backward-compatible alias: PUBLIC_SKILLS_BRANCH was renamed to -# PUBLIC_SKILLS_REF in the main merge; keep the old name so that -# older installed openhands.agent_server builds can still import it. -PUBLIC_SKILLS_BRANCH = PUBLIC_SKILLS_REF +This module previously provided backward-compatible re-exports of skill +classes. Those shims were deprecated in 1.16.0 and removed in 1.21.0. -__all__ = [ - "DEFAULT_MARKETPLACE_PATH", - "PUBLIC_SKILLS_BRANCH", - "PUBLIC_SKILLS_REF", - "PUBLIC_SKILLS_REPO", - "Skill", - "load_available_skills", - "load_skills_from_dir", -] +Migration: + from openhands.sdk.skills import Skill, load_skills_from_dir +""" diff --git a/openhands-sdk/openhands/sdk/context/skills/skill.py b/openhands-sdk/openhands/sdk/context/skills/skill.py deleted file mode 100644 index f44da4b059..0000000000 --- a/openhands-sdk/openhands/sdk/context/skills/skill.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Backward-compatible re-exports. Canonical location: openhands.sdk.skills.skill""" -from openhands.sdk.skills.skill import ( # noqa: F401 - DEFAULT_MARKETPLACE_PATH, - PUBLIC_SKILLS_REF, - PUBLIC_SKILLS_REPO, - Skill, - load_available_skills, - load_skills_from_dir, -) - -# Backward-compatible alias: PUBLIC_SKILLS_BRANCH was renamed to PUBLIC_SKILLS_REF. -PUBLIC_SKILLS_BRANCH = PUBLIC_SKILLS_REF - -__all__ = [ - "DEFAULT_MARKETPLACE_PATH", - "PUBLIC_SKILLS_BRANCH", - "PUBLIC_SKILLS_REF", - "PUBLIC_SKILLS_REPO", - "Skill", - "load_available_skills", - "load_skills_from_dir", -] diff --git a/openhands-sdk/openhands/sdk/context/skills/utils.py b/openhands-sdk/openhands/sdk/context/skills/utils.py deleted file mode 100644 index 88bf458304..0000000000 --- a/openhands-sdk/openhands/sdk/context/skills/utils.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Backward-compatible re-exports. Canonical location: openhands.sdk.skills.utils""" -from openhands.sdk.skills.utils import ( # noqa: F401 - get_skills_cache_dir, - update_skills_repository, -) - -__all__ = [ - "get_skills_cache_dir", - "update_skills_repository", -] diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 1c53abe717..0048c1f414 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -543,20 +543,6 @@ def _dispatch_to_provider_subclass( return sub.model_validate(data, context=info.context) return handler(data) - @field_validator("safety_settings", mode="before", check_fields=False) - @classmethod - def _warn_safety_settings_deprecated( - cls, v: list[dict[str, str]] | None - ) -> list[dict[str, str]] | None: - if v is not None: - warn_deprecated( - "LLM.safety_settings", - deprecated_in="1.15.0", - removed_in="1.20.0", - details="Safety settings are no longer applied.", - ) - return v - @field_validator( "api_key", "aws_access_key_id", "aws_secret_access_key", "aws_session_token" ) diff --git a/openhands-sdk/openhands/sdk/skills/utils.py b/openhands-sdk/openhands/sdk/skills/utils.py index 5138d1404e..58b1a87751 100644 --- a/openhands-sdk/openhands/sdk/skills/utils.py +++ b/openhands-sdk/openhands/sdk/skills/utils.py @@ -278,14 +278,7 @@ def find_third_party_files( files: list[Path] = [] seen_names: set[str] = set() seen_real_paths: set[Path] = set() - try: - dir_items = list(repo_root.iterdir()) - except PermissionError: - logger.debug( - f"Skipping third-party skill discovery in {repo_root}: permission denied" - ) - return files - for item in dir_items: + for item in repo_root.iterdir(): if item.is_file() and item.name.lower() in target_names: # Avoid duplicates (e.g., AGENTS.md and agents.md in same dir) name_lower = item.name.lower() From 592f8445b333cda93e5151b9e19eed6c297178e7 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sun, 7 Jun 2026 08:25:10 -0700 Subject: [PATCH 18/21] feat(databricks): add shared U2M PKCE helpers consumed by web + CLI Consolidate the Authorization Code + PKCE browser-login primitives into a single SDK module so the OpenHands web app and CLI no longer maintain separate copies. Provides generate_pkce, build_authorize_url, and both sync and async code-for-token exchange helpers, exported from the databricks provider package. Bumps SDK to 1.27.0. --- .../sdk/llm/providers/databricks/__init__.py | 12 ++ .../sdk/llm/providers/databricks/pkce.py | 166 +++++++++++++++ openhands-sdk/pyproject.toml | 2 +- .../sdk/llm/providers/databricks/test_pkce.py | 201 ++++++++++++++++++ 4 files changed, 380 insertions(+), 1 deletion(-) create mode 100644 openhands-sdk/openhands/sdk/llm/providers/databricks/pkce.py create mode 100644 tests/sdk/llm/providers/databricks/test_pkce.py diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py index d07f770d94..1883589abd 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/__init__.py @@ -65,8 +65,15 @@ detect_family, pick_family_from_api_types, ) +from openhands.sdk.llm.providers.databricks.pkce import ( + async_exchange_code_for_tokens, + build_authorize_url, + exchange_code_for_tokens, + generate_pkce, +) from openhands.sdk.llm.providers.databricks.settings_bridge import kwargs_from_settings + __all__ = [ # LLM "DatabricksLLM", @@ -79,6 +86,11 @@ "AuthStrategy", "DatabricksCredentials", "StoredU2MTokens", + # U2M browser-login PKCE primitives (shared by web + CLI) + "generate_pkce", + "build_authorize_url", + "exchange_code_for_tokens", + "async_exchange_code_for_tokens", # Discovery "DiscoveredEndpoint", "list_chat_endpoints", diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/pkce.py b/openhands-sdk/openhands/sdk/llm/providers/databricks/pkce.py new file mode 100644 index 0000000000..68cd7dc8c6 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/pkce.py @@ -0,0 +1,166 @@ +"""Databricks U2M OAuth PKCE primitives (shared by web + CLI front-ends). + +These are the dependency-light helpers for the interactive *browser login* +(Authorization Code + PKCE) flow: + +* :func:`generate_pkce` — verifier / S256 challenge pair. +* :func:`build_authorize_url` — Databricks OIDC ``/authorize`` URL. +* :func:`exchange_code_for_tokens` — sync code → tokens exchange. +* :func:`async_exchange_code_for_tokens` — async variant for event-loop callers. + +The provider's :mod:`.auth` module owns token *refresh*; this module owns the +one-time *login*. Both the OpenHands web app and the OpenHands CLI consume these +helpers so the PKCE logic lives in exactly one place. + +The returned token dict round-trips through +:class:`~openhands.sdk.llm.providers.databricks.models.StoredU2MTokens`: +``access_token``, ``refresh_token``, ``expires_at``, ``client_id``, ``host``. + +No ``litellm`` / FastAPI imports here — kept minimal so both front-ends (which +may pin different framework versions) can import it cheaply. +""" + +from __future__ import annotations + +import base64 +import hashlib +import secrets +import time +from typing import Any +from urllib.parse import urlencode + +import httpx + +from openhands.sdk.llm.providers.databricks.utils import USER_AGENT + + +_TOKEN_TIMEOUT_S = 15.0 +_DEFAULT_EXPIRES_IN = 3600 + + +def generate_pkce() -> tuple[str, str]: + """Return ``(verifier, challenge)`` where challenge is S256 of verifier.""" + verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode() + digest = hashlib.sha256(verifier.encode()).digest() + challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + return verifier, challenge + + +def build_authorize_url( + host: str, + client_id: str, + redirect_uri: str, + state: str, + challenge: str, +) -> str: + """Build the Databricks OIDC authorize URL with PKCE (S256).""" + host = host.rstrip("/") + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": redirect_uri, + "scope": "all-apis offline_access", + "state": state, + "code_challenge": challenge, + "code_challenge_method": "S256", + } + return f"{host}/oidc/v1/authorize?{urlencode(params)}" + + +def _build_token_request( + host: str, + client_id: str, + redirect_uri: str, + code: str, + verifier: str, + client_secret: str | None, +) -> tuple[str, dict[str, str]]: + """Return ``(token_url, form_data)`` for the code → token exchange. + + ``client_secret`` is required for **confidential** OAuth apps (apps with a + secret registered in Databricks App connections). Public PKCE apps omit it; + omitting it for a confidential app returns ``{"error": "invalid_client"}``. + """ + host = host.rstrip("/") + token_data: dict[str, str] = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "code_verifier": verifier, + } + if client_secret: + token_data["client_secret"] = client_secret + return f"{host}/oidc/v1/token", token_data + + +def _to_stored_payload( + data: dict[str, Any], client_id: str, host: str +) -> dict[str, Any]: + """Shape a Databricks token response into a ``StoredU2MTokens``-compatible dict.""" + return { + "access_token": data["access_token"], + "refresh_token": data.get("refresh_token", ""), + "expires_at": time.time() + data.get("expires_in", _DEFAULT_EXPIRES_IN), + "client_id": client_id, + "host": host.rstrip("/"), + } + + +def exchange_code_for_tokens( + host: str, + client_id: str, + redirect_uri: str, + code: str, + verifier: str, + client_secret: str | None = None, +) -> dict[str, Any]: + """Exchange an authorization code for tokens (synchronous). + + Sends the PWAF ``User-Agent`` on the token request. Returns a dict + compatible with ``StoredU2MTokens.model_validate``. + + Raises: + httpx.HTTPStatusError: if the token endpoint returns a non-2xx status. + """ + token_url, token_data = _build_token_request( + host, client_id, redirect_uri, code, verifier, client_secret + ) + resp = httpx.post( + token_url, + data=token_data, + headers={"User-Agent": USER_AGENT}, + timeout=_TOKEN_TIMEOUT_S, + ) + resp.raise_for_status() + return _to_stored_payload(resp.json(), client_id, host) + + +async def async_exchange_code_for_tokens( + host: str, + client_id: str, + redirect_uri: str, + code: str, + verifier: str, + client_secret: str | None = None, +) -> dict[str, Any]: + """Exchange an authorization code for tokens (asynchronous). + + Identical to :func:`exchange_code_for_tokens` but uses + ``httpx.AsyncClient`` so it does not block the event loop when called from + an ``async`` request handler (e.g. the web app's OAuth callback route). + + Raises: + httpx.HTTPStatusError: if the token endpoint returns a non-2xx status. + """ + token_url, token_data = _build_token_request( + host, client_id, redirect_uri, code, verifier, client_secret + ) + async with httpx.AsyncClient(timeout=_TOKEN_TIMEOUT_S) as client: + resp = await client.post( + token_url, + data=token_data, + headers={"User-Agent": USER_AGENT}, + ) + resp.raise_for_status() + return _to_stored_payload(resp.json(), client_id, host) diff --git a/openhands-sdk/pyproject.toml b/openhands-sdk/pyproject.toml index 0c4a9c9fc8..22851fbb2b 100644 --- a/openhands-sdk/pyproject.toml +++ b/openhands-sdk/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openhands-sdk" -version = "1.26.0" +version = "1.27.0" description = "OpenHands SDK - Core functionality for building AI agents" requires-python = ">=3.12" diff --git a/tests/sdk/llm/providers/databricks/test_pkce.py b/tests/sdk/llm/providers/databricks/test_pkce.py new file mode 100644 index 0000000000..940031a52d --- /dev/null +++ b/tests/sdk/llm/providers/databricks/test_pkce.py @@ -0,0 +1,201 @@ +"""Tests for the shared U2M OAuth PKCE primitives. + +Covers: PKCE verifier/challenge S256 correctness, authorize-URL parameters, +and the sync + async code → token exchange (happy path, confidential-app +secret, refresh-token default, PWAF User-Agent header, and error propagation). +""" + +from __future__ import annotations + +import base64 +import hashlib +from unittest.mock import AsyncMock, MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import httpx +import pytest + +from openhands.sdk.llm.providers.databricks.models import StoredU2MTokens +from openhands.sdk.llm.providers.databricks.pkce import ( + async_exchange_code_for_tokens, + build_authorize_url, + exchange_code_for_tokens, + generate_pkce, +) +from openhands.sdk.llm.providers.databricks.utils import USER_AGENT + + +_HOST = "https://adb-123.azuredatabricks.net" +_CLIENT_ID = "oauth-app-client-id" +_REDIRECT = "http://localhost:3000/auth/databricks/callback" + + +# --------------------------------------------------------------------------- +# generate_pkce +# --------------------------------------------------------------------------- + + +def test_generate_pkce_challenge_is_s256_of_verifier() -> None: + verifier, challenge = generate_pkce() + expected = ( + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()) + .rstrip(b"=") + .decode() + ) + assert challenge == expected + + +def test_generate_pkce_no_base64_padding() -> None: + verifier, challenge = generate_pkce() + assert "=" not in verifier + assert "=" not in challenge + + +def test_generate_pkce_is_random() -> None: + assert generate_pkce()[0] != generate_pkce()[0] + + +# --------------------------------------------------------------------------- +# build_authorize_url +# --------------------------------------------------------------------------- + + +def test_build_authorize_url_params() -> None: + url = build_authorize_url(_HOST, _CLIENT_ID, _REDIRECT, "state123", "chal456") + parsed = urlparse(url) + assert parsed.scheme == "https" + assert parsed.path == "/oidc/v1/authorize" + qs = parse_qs(parsed.query) + assert qs["response_type"] == ["code"] + assert qs["client_id"] == [_CLIENT_ID] + assert qs["redirect_uri"] == [_REDIRECT] + assert qs["scope"] == ["all-apis offline_access"] + assert qs["state"] == ["state123"] + assert qs["code_challenge"] == ["chal456"] + assert qs["code_challenge_method"] == ["S256"] + + +def test_build_authorize_url_strips_trailing_slash() -> None: + url = build_authorize_url(_HOST + "/", _CLIENT_ID, _REDIRECT, "s", "c") + assert url.startswith(f"{_HOST}/oidc/v1/authorize?") + + +# --------------------------------------------------------------------------- +# exchange_code_for_tokens (sync) +# --------------------------------------------------------------------------- + + +def _mock_token_response() -> MagicMock: + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.json.return_value = { + "access_token": "access-abc", + "refresh_token": "refresh-xyz", + "expires_in": 3600, + } + return resp + + +def test_exchange_code_returns_stored_token_shape() -> None: + with patch("httpx.post", return_value=_mock_token_response()) as mock_post: + payload = exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "auth-code", "verifier-1" + ) + + # Round-trips through the StoredU2MTokens model. + stored = StoredU2MTokens.model_validate(payload) + assert stored.access_token == "access-abc" + assert stored.refresh_token == "refresh-xyz" + assert stored.client_id == _CLIENT_ID + assert stored.host == _HOST + assert stored.expires_at > 0 + + # Token endpoint + PWAF User-Agent + correct form fields. + args, kwargs = mock_post.call_args + assert args[0] == f"{_HOST}/oidc/v1/token" + assert kwargs["headers"]["User-Agent"] == USER_AGENT + assert kwargs["data"]["grant_type"] == "authorization_code" + assert kwargs["data"]["code"] == "auth-code" + assert kwargs["data"]["code_verifier"] == "verifier-1" + assert "client_secret" not in kwargs["data"] + + +def test_exchange_code_includes_client_secret_for_confidential_app() -> None: + with patch("httpx.post", return_value=_mock_token_response()) as mock_post: + exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "code", "verifier", + client_secret="super-secret", + ) + assert mock_post.call_args.kwargs["data"]["client_secret"] == "super-secret" + + +def test_exchange_code_defaults_missing_refresh_token() -> None: + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.json.return_value = {"access_token": "a", "expires_in": 1200} + with patch("httpx.post", return_value=resp): + payload = exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "code", "verifier" + ) + assert payload["refresh_token"] == "" + + +def test_exchange_code_propagates_http_error() -> None: + resp = MagicMock() + resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "invalid_client", request=MagicMock(), response=MagicMock() + ) + with patch("httpx.post", return_value=resp): + with pytest.raises(httpx.HTTPStatusError): + exchange_code_for_tokens(_HOST, _CLIENT_ID, _REDIRECT, "code", "verifier") + + +# --------------------------------------------------------------------------- +# async_exchange_code_for_tokens +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_async_exchange_code_returns_stored_token_shape() -> None: + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.json.return_value = { + "access_token": "access-async", + "refresh_token": "refresh-async", + "expires_in": 3600, + } + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.post.return_value = resp + + with patch("httpx.AsyncClient", return_value=mock_client): + payload = await async_exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "code", "verifier", + client_secret="conf-secret", + ) + + stored = StoredU2MTokens.model_validate(payload) + assert stored.access_token == "access-async" + assert stored.host == _HOST + + # PWAF User-Agent + confidential secret forwarded on the async path too. + kwargs = mock_client.post.call_args.kwargs + assert kwargs["headers"]["User-Agent"] == USER_AGENT + assert kwargs["data"]["client_secret"] == "conf-secret" + + +@pytest.mark.asyncio +async def test_async_exchange_code_propagates_http_error() -> None: + resp = MagicMock() + resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "boom", request=MagicMock(), response=MagicMock() + ) + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.post.return_value = resp + + with patch("httpx.AsyncClient", return_value=mock_client): + with pytest.raises(httpx.HTTPStatusError): + await async_exchange_code_for_tokens( + _HOST, _CLIENT_ID, _REDIRECT, "code", "verifier" + ) From 848bda301561c81915b3df31165593b43d3c40c4 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sun, 7 Jun 2026 08:25:11 -0700 Subject: [PATCH 19/21] test(databricks): drop stale model cases and refresh curated assertions Remove obsolete model entries from family-detection parametrizations and update curated-model assertions to current Foundation Model API names. --- .../llm/providers/databricks/test_discovery.py | 6 +++--- tests/sdk/llm/providers/databricks/test_llm.py | 16 +--------------- .../sdk/llm/providers/databricks/test_models.py | 2 -- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/tests/sdk/llm/providers/databricks/test_discovery.py b/tests/sdk/llm/providers/databricks/test_discovery.py index 71b3e330cd..c6068d4c87 100644 --- a/tests/sdk/llm/providers/databricks/test_discovery.py +++ b/tests/sdk/llm/providers/databricks/test_discovery.py @@ -483,9 +483,9 @@ def _family_order(section: list[ModelPickerEntry]) -> list[str]: def test_get_picker_entries_merges_discovered_on_top_of_curated() -> None: """Discovery adds non-curated endpoints; curated entries get live signals.""" - # One endpoint overlaps curated (claude-sonnet-4-5), two are new. + # One endpoint overlaps curated (claude-sonnet-4-6), two are new. payload = _make_endpoints_payload([ - _fmapi_ep("databricks-claude-sonnet-4-5"), # overlaps curated + _fmapi_ep("databricks-claude-sonnet-4-6"), # overlaps curated _fmapi_ep("databricks-meta-llama-4-maverick"), # discovered only _external_ep("customer-private-gpt"), # external-model discovered only ]) @@ -497,7 +497,7 @@ def test_get_picker_entries_merges_discovered_on_top_of_curated() -> None: # Overlap: curated entry kept recommended + its opinionated family, but # gained "curated+discovered" source and the live endpoint_type/ready. - overlap = by_qn["databricks/databricks-claude-sonnet-4-5"] + overlap = by_qn["databricks/databricks-claude-sonnet-4-6"] assert overlap.source == "curated+discovered" assert overlap.family is ProviderFamily.ANTHROPIC assert overlap.recommended is True diff --git a/tests/sdk/llm/providers/databricks/test_llm.py b/tests/sdk/llm/providers/databricks/test_llm.py index c1490df9dc..e22f431526 100644 --- a/tests/sdk/llm/providers/databricks/test_llm.py +++ b/tests/sdk/llm/providers/databricks/test_llm.py @@ -29,8 +29,7 @@ _HOST = "https://adb-123.azuredatabricks.net" _MODEL_LLAMA = "databricks/databricks-meta-llama-3-3-70b-instruct" -_MODEL_DBRX = "databricks/databricks-dbrx-instruct" -_MODEL_CLAUDE = "databricks/databricks-claude-3-7-sonnet" +_MODEL_CLAUDE = "databricks/databricks-claude-sonnet-4" _MODEL_UNKNOWN = "databricks/my-custom-finetuned-model" @@ -168,12 +167,6 @@ def test_context_window_llama_70b() -> None: assert llm.max_input_tokens == DATABRICKS_CONTEXT_WINDOWS[_MODEL_LLAMA] -def test_context_window_dbrx() -> None: - llm = _make_llm(model=_MODEL_DBRX) - assert llm.max_input_tokens == DATABRICKS_CONTEXT_WINDOWS[_MODEL_DBRX] - assert llm.max_input_tokens == 32_768 - - def test_context_window_claude() -> None: """Claude-based Databricks models have 200K context window.""" llm = _make_llm(model=_MODEL_CLAUDE) @@ -186,12 +179,6 @@ def test_context_window_unknown_model_fallback() -> None: assert llm.max_input_tokens == 128_000 -def test_max_output_tokens_dbrx() -> None: - llm = _make_llm(model=_MODEL_DBRX) - assert llm.max_output_tokens == DATABRICKS_MAX_OUTPUT[_MODEL_DBRX] - assert llm.max_output_tokens == 4_096 - - def test_max_output_tokens_claude() -> None: llm = _make_llm(model=_MODEL_CLAUDE) assert llm.max_output_tokens == 8_192 @@ -386,7 +373,6 @@ def test_auth_method_property_reflects_pat_construction() -> None: "model,expected", [ ("databricks/databricks-meta-llama-3-3-70b-instruct", ProviderFamily.OPENAI), - ("databricks/databricks-dbrx-instruct", ProviderFamily.OPENAI), ("databricks/databricks-gpt-oss-120b", ProviderFamily.OPENAI), ("databricks/databricks-claude-sonnet-4-5", ProviderFamily.ANTHROPIC), ("databricks/databricks-claude-opus-4-6", ProviderFamily.ANTHROPIC), diff --git a/tests/sdk/llm/providers/databricks/test_models.py b/tests/sdk/llm/providers/databricks/test_models.py index d3a76905d0..948abe8166 100644 --- a/tests/sdk/llm/providers/databricks/test_models.py +++ b/tests/sdk/llm/providers/databricks/test_models.py @@ -77,8 +77,6 @@ def test_provider_family_openai_is_default_fallback() -> None: ("gpt-oss-120b", ProviderFamily.OPENAI), ("databricks-gpt-oss-120b", ProviderFamily.OPENAI), ("databricks-meta-llama-3-3-70b-instruct", ProviderFamily.OPENAI), - ("databricks-dbrx-instruct", ProviderFamily.OPENAI), - ("mistral-7b-instruct", ProviderFamily.OPENAI), # Case-insensitive ("DATABRICKS-CLAUDE-SONNET-4-5", ProviderFamily.ANTHROPIC), From 90a1a69aea61853d7702dd1aa8fcf39b9ac38a39 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Sun, 7 Jun 2026 10:02:13 -0700 Subject: [PATCH 20/21] docs(databricks): document pkce + settings_bridge modules and shared PKCE helpers Bring the provider README in sync with the code: add pkce.py and settings_bridge.py to the module-layout table, note the __init__ now exports the PKCE helpers, and add an Authentication paragraph describing the shared U2M browser-login helpers (generate_pkce / build_authorize_url / exchange_code_for_tokens) consumed by both the web backend and the CLI. --- .../sdk/llm/providers/databricks/README.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md b/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md index 2f983b4ed8..54da734b80 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md @@ -69,6 +69,15 @@ auth details (token caching, refresh policies, CLI profile selection). Use `llm.auth_method` to see which strategy resolved. +The interactive **U2M browser login** (Authorization Code + PKCE) is built from +the dependency-light helpers in `pkce.py` — `generate_pkce()`, +`build_authorize_url()`, and `exchange_code_for_tokens()` / +`async_exchange_code_for_tokens()` — all re-exported from this package's +`__init__`. These are the single source of truth consumed by both the web +backend and the OpenHands-CLI, so the three front-ends can't drift apart. Each +caller supplies its own local redirect/callback handling and passes the +resulting tokens back in via `stored_u2m_tokens`. + ## Discovery (picker UIs) Listing AI-Gateway-shaped chat endpoints: @@ -103,12 +112,14 @@ with a 5-minute TTL cache. | Module | Role | |---|---| -| `__init__.py` | Public API (`DatabricksLLM`, `ProviderFamily`, discovery, auth types) | +| `__init__.py` | Public API (`DatabricksLLM`, `ProviderFamily`, discovery, auth types, PKCE helpers) | | `llm.py` | `DatabricksLLM` — Pydantic subclass of `LLM`; transport override | | `client.py` | `DatabricksFMAPIClient` — `httpx` transport, family dispatch, retry | | `native.py` | Per-family `to_native` / `from_native` adapters (request/response shaping) | | `models.py` | `ProviderFamily`, `AIGatewayPaths`, routing functions, token containers | | `auth.py` | Five credential strategies, `resolve_credentials()`, token providers | +| `pkce.py` | Shared U2M browser-login helpers (`generate_pkce`, `build_authorize_url`, sync/async `exchange_code_for_tokens`) — single source of truth for web + CLI | +| `settings_bridge.py` | `kwargs_from_settings()` — the one path that turns user settings (env / DB / TUI) into `create_llm(...)` kwargs, shared by backend and CLI | | `discovery.py` | `list_chat_endpoints` / `list_foundation_models` + TTL cache | | `utils.py` | `USER_AGENT`, `DatabricksTimeouts`, retry/backoff, error mapping | From 9d058e5627bbff4937734c53695bb4d7a8b7e5f9 Mon Sep 17 00:00:00 2001 From: Prasad Kona Date: Mon, 8 Jun 2026 13:18:44 -0700 Subject: [PATCH 21/21] docs(databricks): note alignment with Databricks ucode credential model Add an Alignment with ucode section to the provider README: the connector's PROFILE/UNIFIED/U2M strategies let an OpenHands agent reach AI Gateway the same key-free, governed, workspace-credential way Databricks ucode does. --- .../sdk/llm/providers/databricks/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md b/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md index 54da734b80..8a3fa7a416 100644 --- a/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md +++ b/openhands-sdk/openhands/sdk/llm/providers/databricks/README.md @@ -78,6 +78,19 @@ backend and the OpenHands-CLI, so the three front-ends can't drift apart. Each caller supplies its own local redirect/callback handling and passes the resulting tokens back in via `stored_u2m_tokens`. +## Alignment with Databricks `ucode` + +This connector follows the same credential model as +[Databricks `ucode`](https://github.com/databricks/ucode) — the *Unity AI +Gateway Coding CLI* — which routes coding agents through the Databricks AI +Gateway using workspace credentials, **no API keys required**. The `PROFILE` +and `UNIFIED` strategies read the workspace login a developer has already +established (`databricks auth login` / `~/.databrickscfg`), and `U2M` provides +interactive browser OAuth. An OpenHands agent can therefore reach AI Gateway the +same key-free, governed way `ucode` does — reusing the existing workspace +session rather than minting a separate token — over one consistent path to the +gateway (and the Unity Catalog–governed resources behind it). + ## Discovery (picker UIs) Listing AI-Gateway-shaped chat endpoints: