diff --git a/integrations/langchain/src/databricks_langchain/chat_models.py b/integrations/langchain/src/databricks_langchain/chat_models.py index 51f76da3..3eff87b7 100644 --- a/integrations/langchain/src/databricks_langchain/chat_models.py +++ b/integrations/langchain/src/databricks_langchain/chat_models.py @@ -108,6 +108,14 @@ class ChatDatabricks(BaseChatModel): use_responses_api=True, ) + For AI Gateway V2 endpoints, set ``use_ai_gateway=True`` (MLflow API) or + ``use_ai_gateway_native_api=True`` (native OpenAI API): + .. code-block:: python + llm = ChatDatabricks( + model="my-gateway-endpoint", + use_ai_gateway=True, + ) + **Invoke**: .. code-block:: python @@ -294,6 +302,12 @@ class GetPopulation(BaseModel): """Any extra parameters to pass to the endpoint.""" use_responses_api: bool = False """Whether to use the Responses API to format inputs and outputs.""" + use_ai_gateway: bool = False + """If True, route requests through AI Gateway V2 using the MLflow API + (``{host}/ai-gateway/mlflow/v1``). Cannot be combined with use_ai_gateway_native_api.""" + use_ai_gateway_native_api: bool = False + """If True, route requests through AI Gateway V2 using the native OpenAI API + (``{host}/ai-gateway/openai/v1``). Cannot be combined with use_ai_gateway.""" timeout: Optional[float] = None """Timeout in seconds for the HTTP request. If None, uses the default timeout.""" max_retries: Optional[int] = None @@ -326,6 +340,10 @@ def _get_client_kwargs(self) -> Dict[str, Any]: openai_kwargs["timeout"] = self.timeout if self.max_retries is not None: openai_kwargs["max_retries"] = self.max_retries + if self.use_ai_gateway: + openai_kwargs["use_ai_gateway"] = True + if self.use_ai_gateway_native_api: + openai_kwargs["use_ai_gateway_native_api"] = True return openai_kwargs @cached_property diff --git a/integrations/langchain/src/databricks_langchain/utils.py b/integrations/langchain/src/databricks_langchain/utils.py index 87e94cff..13e6075c 100644 --- a/integrations/langchain/src/databricks_langchain/utils.py +++ b/integrations/langchain/src/databricks_langchain/utils.py @@ -22,25 +22,43 @@ def get_deployment_client(target_uri: str) -> Any: ) from e -def get_openai_client(workspace_client: Any = None, **kwargs) -> OpenAI: +def get_openai_client( + workspace_client: Any = None, + use_ai_gateway: bool = False, + use_ai_gateway_native_api: bool = False, + **kwargs, +) -> OpenAI: """Get an OpenAI client configured for Databricks. Args: workspace_client: Optional WorkspaceClient instance to use for authentication. If not provided, creates a default WorkspaceClient. + use_ai_gateway: If True, route requests through AI Gateway V2 using the MLflow + API (``{host}/ai-gateway/mlflow/v1``). Cannot be combined with + ``use_ai_gateway_native_api``. + use_ai_gateway_native_api: If True, route requests through AI Gateway V2 using + the native OpenAI-compatible API (``{host}/ai-gateway/openai/v1``). Cannot + be combined with ``use_ai_gateway``. **kwargs: Additional keyword arguments to pass to get_open_ai_client(), - such as timeout and max_retries. + such as timeout and max_retries. Ignored when ``use_ai_gateway`` or + ``use_ai_gateway_native_api`` is True. """ try: from databricks.sdk import WorkspaceClient - # If workspace_client is provided, use it directly - if workspace_client is not None: - return workspace_client.serving_endpoints.get_open_ai_client(**kwargs) - else: - # Otherwise, create default workspace client + if workspace_client is None: workspace_client = WorkspaceClient() - return workspace_client.serving_endpoints.get_open_ai_client(**kwargs) + + if use_ai_gateway or use_ai_gateway_native_api: + from databricks_openai import DatabricksOpenAI + + return DatabricksOpenAI( + workspace_client=workspace_client, + use_ai_gateway=use_ai_gateway, + use_ai_gateway_native_api=use_ai_gateway_native_api, + ) + + return workspace_client.serving_endpoints.get_open_ai_client(**kwargs) except ImportError as e: raise ImportError( @@ -50,24 +68,37 @@ def get_openai_client(workspace_client: Any = None, **kwargs) -> OpenAI: ) from e -def get_async_openai_client(workspace_client: Any = None, **kwargs) -> AsyncOpenAI: +def get_async_openai_client( + workspace_client: Any = None, + use_ai_gateway: bool = False, + use_ai_gateway_native_api: bool = False, + **kwargs, +) -> AsyncOpenAI: """Get an async OpenAI client configured for Databricks using databricks-openai. Args: workspace_client: Optional WorkspaceClient instance to use for authentication. If not provided, creates a default WorkspaceClient. + use_ai_gateway: If True, route requests through AI Gateway V2 using the MLflow + API (``{host}/ai-gateway/mlflow/v1``). Cannot be combined with + ``use_ai_gateway_native_api``. + use_ai_gateway_native_api: If True, route requests through AI Gateway V2 using + the native OpenAI-compatible API (``{host}/ai-gateway/openai/v1``). Cannot + be combined with ``use_ai_gateway``. **kwargs: Additional keyword arguments to pass to AsyncDatabricksOpenAI(), such as timeout and max_retries. """ from databricks.sdk import WorkspaceClient - # If workspace_client is provided, use it directly - if workspace_client is not None: - return AsyncDatabricksOpenAI(workspace_client=workspace_client, **kwargs) - else: - # Otherwise, create default workspace client and use it + if workspace_client is None: workspace_client = WorkspaceClient() - return AsyncDatabricksOpenAI(workspace_client=workspace_client, **kwargs) + + return AsyncDatabricksOpenAI( + workspace_client=workspace_client, + use_ai_gateway=use_ai_gateway, + use_ai_gateway_native_api=use_ai_gateway_native_api, + **kwargs, + ) # Utility function for Maximal Marginal Relevance (MMR) reranking. diff --git a/integrations/langchain/tests/unit_tests/test_chat_models.py b/integrations/langchain/tests/unit_tests/test_chat_models.py index 5e85ff7e..4c6dd8ec 100644 --- a/integrations/langchain/tests/unit_tests/test_chat_models.py +++ b/integrations/langchain/tests/unit_tests/test_chat_models.py @@ -170,6 +170,40 @@ def test_default_workspace_client() -> None: mock_get_client.assert_called_once_with(workspace_client=None) +def test_use_ai_gateway_parameter() -> None: + """Test that use_ai_gateway flag is forwarded to get_openai_client.""" + from unittest.mock import Mock, patch + + mock_openai_client = Mock() + + with patch( + "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client + ) as mock_get_client: + llm = ChatDatabricks(model="test-model", use_ai_gateway=True) + _ = llm.client + + mock_get_client.assert_called_once_with(workspace_client=None, use_ai_gateway=True) + assert llm.use_ai_gateway is True + assert llm.use_ai_gateway_native_api is False + + +def test_use_ai_gateway_native_api_parameter() -> None: + """Test that use_ai_gateway_native_api flag is forwarded to get_openai_client.""" + from unittest.mock import Mock, patch + + mock_openai_client = Mock() + + with patch( + "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client + ) as mock_get_client: + llm = ChatDatabricks(model="test-model", use_ai_gateway_native_api=True) + _ = llm.client + + mock_get_client.assert_called_once_with(workspace_client=None, use_ai_gateway_native_api=True) + assert llm.use_ai_gateway is False + assert llm.use_ai_gateway_native_api is True + + def test_target_uri_deprecation_warning() -> None: """Test that using target_uri shows deprecation warning.""" from unittest.mock import Mock, patch diff --git a/integrations/langchain/tests/unit_tests/test_utils.py b/integrations/langchain/tests/unit_tests/test_utils.py index 8ee0f58e..cbe3e9b8 100644 --- a/integrations/langchain/tests/unit_tests/test_utils.py +++ b/integrations/langchain/tests/unit_tests/test_utils.py @@ -60,3 +60,60 @@ def test_get_openai_client_without_timeout_and_retries() -> None: # Verify the client is returned assert client == mock_openai_client + + +def test_get_openai_client_with_use_ai_gateway() -> None: + """Test use_ai_gateway=True constructs DatabricksOpenAI instead of the SDK helper.""" + + mock_workspace_client = Mock() + mock_databricks_openai_client = Mock() + + with patch( + "databricks_openai.DatabricksOpenAI", return_value=mock_databricks_openai_client + ) as mock_databricks_openai: + client = get_openai_client(workspace_client=mock_workspace_client, use_ai_gateway=True) + + mock_databricks_openai.assert_called_once_with( + workspace_client=mock_workspace_client, + use_ai_gateway=True, + use_ai_gateway_native_api=False, + ) + mock_workspace_client.serving_endpoints.get_open_ai_client.assert_not_called() + assert client == mock_databricks_openai_client + + +def test_get_openai_client_with_use_ai_gateway_native_api() -> None: + """Test use_ai_gateway_native_api=True constructs DatabricksOpenAI with that flag.""" + + mock_workspace_client = Mock() + mock_databricks_openai_client = Mock() + + with patch( + "databricks_openai.DatabricksOpenAI", return_value=mock_databricks_openai_client + ) as mock_databricks_openai: + client = get_openai_client( + workspace_client=mock_workspace_client, use_ai_gateway_native_api=True + ) + + mock_databricks_openai.assert_called_once_with( + workspace_client=mock_workspace_client, + use_ai_gateway=False, + use_ai_gateway_native_api=True, + ) + mock_workspace_client.serving_endpoints.get_open_ai_client.assert_not_called() + assert client == mock_databricks_openai_client + + +def test_get_openai_client_without_gateway_uses_serving_endpoints() -> None: + """Test that DatabricksOpenAI is NOT constructed when no gateway flags are set.""" + + mock_workspace_client = Mock() + mock_openai_client = Mock() + mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client + + with patch("databricks_openai.DatabricksOpenAI") as mock_databricks_openai: + client = get_openai_client(workspace_client=mock_workspace_client) + + mock_databricks_openai.assert_not_called() + mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with() + assert client == mock_openai_client