Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions integrations/langchain/src/databricks_langchain/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ class ChatDatabricks(BaseChatModel):
use_responses_api=True,
)

For AI Gateway V2 endpoints, set ``use_ai_gateway=True`` (MLflow API) or
``use_ai_gateway_native_api=True`` (native OpenAI API):
.. code-block:: python
llm = ChatDatabricks(
model="my-gateway-endpoint",
use_ai_gateway=True,
)

**Invoke**:

.. code-block:: python
Expand Down Expand Up @@ -294,6 +302,12 @@ class GetPopulation(BaseModel):
"""Any extra parameters to pass to the endpoint."""
use_responses_api: bool = False
"""Whether to use the Responses API to format inputs and outputs."""
use_ai_gateway: bool = False
"""If True, route requests through AI Gateway V2 using the MLflow API
(``{host}/ai-gateway/mlflow/v1``). Cannot be combined with use_ai_gateway_native_api."""
use_ai_gateway_native_api: bool = False
"""If True, route requests through AI Gateway V2 using the native OpenAI API
(``{host}/ai-gateway/openai/v1``). Cannot be combined with use_ai_gateway."""
timeout: Optional[float] = None
"""Timeout in seconds for the HTTP request. If None, uses the default timeout."""
max_retries: Optional[int] = None
Expand Down Expand Up @@ -326,6 +340,10 @@ def _get_client_kwargs(self) -> Dict[str, Any]:
openai_kwargs["timeout"] = self.timeout
if self.max_retries is not None:
openai_kwargs["max_retries"] = self.max_retries
if self.use_ai_gateway:
openai_kwargs["use_ai_gateway"] = True
if self.use_ai_gateway_native_api:
openai_kwargs["use_ai_gateway_native_api"] = True
return openai_kwargs

@cached_property
Expand Down
61 changes: 46 additions & 15 deletions integrations/langchain/src/databricks_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,43 @@ def get_deployment_client(target_uri: str) -> Any:
) from e


def get_openai_client(workspace_client: Any = None, **kwargs) -> OpenAI:
def get_openai_client(
workspace_client: Any = None,
use_ai_gateway: bool = False,
use_ai_gateway_native_api: bool = False,
**kwargs,
) -> OpenAI:
"""Get an OpenAI client configured for Databricks.

Args:
workspace_client: Optional WorkspaceClient instance to use for authentication.
If not provided, creates a default WorkspaceClient.
use_ai_gateway: If True, route requests through AI Gateway V2 using the MLflow
API (``{host}/ai-gateway/mlflow/v1``). Cannot be combined with
``use_ai_gateway_native_api``.
use_ai_gateway_native_api: If True, route requests through AI Gateway V2 using
the native OpenAI-compatible API (``{host}/ai-gateway/openai/v1``). Cannot
be combined with ``use_ai_gateway``.
**kwargs: Additional keyword arguments to pass to get_open_ai_client(),
such as timeout and max_retries.
such as timeout and max_retries. Ignored when ``use_ai_gateway`` or
``use_ai_gateway_native_api`` is True.
"""
try:
from databricks.sdk import WorkspaceClient

# If workspace_client is provided, use it directly
if workspace_client is not None:
return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)
else:
# Otherwise, create default workspace client
if workspace_client is None:
workspace_client = WorkspaceClient()
return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)

if use_ai_gateway or use_ai_gateway_native_api:
from databricks_openai import DatabricksOpenAI

return DatabricksOpenAI(
workspace_client=workspace_client,
use_ai_gateway=use_ai_gateway,
use_ai_gateway_native_api=use_ai_gateway_native_api,
)

return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)

except ImportError as e:
raise ImportError(
Expand All @@ -50,24 +68,37 @@ def get_openai_client(workspace_client: Any = None, **kwargs) -> OpenAI:
) from e


def get_async_openai_client(workspace_client: Any = None, **kwargs) -> AsyncOpenAI:
def get_async_openai_client(
workspace_client: Any = None,
use_ai_gateway: bool = False,
use_ai_gateway_native_api: bool = False,
**kwargs,
) -> AsyncOpenAI:
"""Get an async OpenAI client configured for Databricks using databricks-openai.

Args:
workspace_client: Optional WorkspaceClient instance to use for authentication.
If not provided, creates a default WorkspaceClient.
use_ai_gateway: If True, route requests through AI Gateway V2 using the MLflow
API (``{host}/ai-gateway/mlflow/v1``). Cannot be combined with
``use_ai_gateway_native_api``.
use_ai_gateway_native_api: If True, route requests through AI Gateway V2 using
the native OpenAI-compatible API (``{host}/ai-gateway/openai/v1``). Cannot
be combined with ``use_ai_gateway``.
**kwargs: Additional keyword arguments to pass to AsyncDatabricksOpenAI(),
such as timeout and max_retries.
"""
from databricks.sdk import WorkspaceClient

# If workspace_client is provided, use it directly
if workspace_client is not None:
return AsyncDatabricksOpenAI(workspace_client=workspace_client, **kwargs)
else:
# Otherwise, create default workspace client and use it
if workspace_client is None:
workspace_client = WorkspaceClient()
return AsyncDatabricksOpenAI(workspace_client=workspace_client, **kwargs)

return AsyncDatabricksOpenAI(
workspace_client=workspace_client,
use_ai_gateway=use_ai_gateway,
use_ai_gateway_native_api=use_ai_gateway_native_api,
**kwargs,
)


# Utility function for Maximal Marginal Relevance (MMR) reranking.
Expand Down
34 changes: 34 additions & 0 deletions integrations/langchain/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,40 @@ def test_default_workspace_client() -> None:
mock_get_client.assert_called_once_with(workspace_client=None)


def test_use_ai_gateway_parameter() -> None:
"""Test that use_ai_gateway flag is forwarded to get_openai_client."""
from unittest.mock import Mock, patch

mock_openai_client = Mock()

with patch(
"databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client
) as mock_get_client:
llm = ChatDatabricks(model="test-model", use_ai_gateway=True)
_ = llm.client

mock_get_client.assert_called_once_with(workspace_client=None, use_ai_gateway=True)
assert llm.use_ai_gateway is True
assert llm.use_ai_gateway_native_api is False


def test_use_ai_gateway_native_api_parameter() -> None:
"""Test that use_ai_gateway_native_api flag is forwarded to get_openai_client."""
from unittest.mock import Mock, patch

mock_openai_client = Mock()

with patch(
"databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client
) as mock_get_client:
llm = ChatDatabricks(model="test-model", use_ai_gateway_native_api=True)
_ = llm.client

mock_get_client.assert_called_once_with(workspace_client=None, use_ai_gateway_native_api=True)
assert llm.use_ai_gateway is False
assert llm.use_ai_gateway_native_api is True


def test_target_uri_deprecation_warning() -> None:
"""Test that using target_uri shows deprecation warning."""
from unittest.mock import Mock, patch
Expand Down
57 changes: 57 additions & 0 deletions integrations/langchain/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,60 @@ def test_get_openai_client_without_timeout_and_retries() -> None:

# Verify the client is returned
assert client == mock_openai_client


def test_get_openai_client_with_use_ai_gateway() -> None:
"""Test use_ai_gateway=True constructs DatabricksOpenAI instead of the SDK helper."""

mock_workspace_client = Mock()
mock_databricks_openai_client = Mock()

with patch(
"databricks_openai.DatabricksOpenAI", return_value=mock_databricks_openai_client
) as mock_databricks_openai:
client = get_openai_client(workspace_client=mock_workspace_client, use_ai_gateway=True)

mock_databricks_openai.assert_called_once_with(
workspace_client=mock_workspace_client,
use_ai_gateway=True,
use_ai_gateway_native_api=False,
)
mock_workspace_client.serving_endpoints.get_open_ai_client.assert_not_called()
assert client == mock_databricks_openai_client


def test_get_openai_client_with_use_ai_gateway_native_api() -> None:
"""Test use_ai_gateway_native_api=True constructs DatabricksOpenAI with that flag."""

mock_workspace_client = Mock()
mock_databricks_openai_client = Mock()

with patch(
"databricks_openai.DatabricksOpenAI", return_value=mock_databricks_openai_client
) as mock_databricks_openai:
client = get_openai_client(
workspace_client=mock_workspace_client, use_ai_gateway_native_api=True
)

mock_databricks_openai.assert_called_once_with(
workspace_client=mock_workspace_client,
use_ai_gateway=False,
use_ai_gateway_native_api=True,
)
mock_workspace_client.serving_endpoints.get_open_ai_client.assert_not_called()
assert client == mock_databricks_openai_client


def test_get_openai_client_without_gateway_uses_serving_endpoints() -> None:
"""Test that DatabricksOpenAI is NOT constructed when no gateway flags are set."""

mock_workspace_client = Mock()
mock_openai_client = Mock()
mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client

with patch("databricks_openai.DatabricksOpenAI") as mock_databricks_openai:
client = get_openai_client(workspace_client=mock_workspace_client)

mock_databricks_openai.assert_not_called()
mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with()
assert client == mock_openai_client