From 964e88a8677d33bad51f0fa717f17735ad1bcca2 Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Mon, 26 Jan 2026 18:03:03 -0800 Subject: [PATCH 01/11] changes --- .../vector_search_retriever_tool.py | 97 +++- .../test_vector_search_retriever_tool.py | 467 +++++++++++++++--- .../vector_search_retriever_tool.py | 206 ++------ .../test_vector_search_retriever_tool.py | 52 +- .../vector_search_retriever_tool.py | 118 ++++- 5 files changed, 672 insertions(+), 268 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index c4ba589ab..666fc1e93 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -1,17 +1,27 @@ -from typing import List, Optional, Type +import asyncio +import logging +from typing import Any, Dict, List, Optional, Type, Union from databricks_ai_bridge.utils.vector_search import IndexDetails + +_logger = logging.getLogger(__name__) from databricks_ai_bridge.vector_search_retriever_tool import ( FilterItem, VectorSearchRetrieverToolInput, VectorSearchRetrieverToolMixin, vector_search_retriever_tool_trace, ) +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool as LangChainBaseTool from pydantic import BaseModel, Field, PrivateAttr, model_validator from databricks_langchain import DatabricksEmbeddings +from databricks_langchain.multi_server_mcp_client import ( + DatabricksMCPServer, + DatabricksMultiServerMCPClient, +) from databricks_langchain.vectorstores import DatabricksVectorSearch @@ -48,6 +58,7 @@ class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput _vector_store: DatabricksVectorSearch = PrivateAttr() + _mcp_tool: Optional[LangChainBaseTool] = PrivateAttr(default=None) @model_validator(mode="after") def _validate_tool_inputs(self): @@ -83,11 +94,72 @@ def _validate_tool_inputs(self): return self - @vector_search_retriever_tool_trace - def _run(self, query: str, filters: Optional[List[FilterItem]] = None, **kwargs) -> str: + def _create_or_get_mcp_tool(self) -> LangChainBaseTool: + """Create or return existing MCP tool using LangChain MCP Server.""" + if self._mcp_tool is not None: + return self._mcp_tool + + catalog, schema, index = self._parse_index_name() + + try: + server = DatabricksMCPServer.from_vector_search( + catalog=catalog, + schema=schema, + index_name=index, + name=f"vs-{index}", + workspace_client=self.workspace_client, + ) + client = DatabricksMultiServerMCPClient([server]) + except Exception as e: + self._handle_mcp_creation_error(e) + + tools = asyncio.run(client.get_tools()) + self._validate_mcp_tools(tools) + + self._mcp_tool = tools[0] + return self._mcp_tool + + def _build_mcp_input( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """Build input for MCP tool invocation.""" + mcp_input = self._build_mcp_params(filters, **kwargs) + mcp_input["query"] = query + return mcp_input + + def _parse_mcp_response(self, mcp_response: str) -> List[Document]: + """Parse MCP tool response into LangChain Documents.""" + dicts = self._parse_mcp_response_to_dicts(mcp_response, strict=False) + return [Document(page_content=d["page_content"], metadata=d["metadata"]) for d in dicts] + + def _execute_mcp_path( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs: Any, + ) -> List[Document]: + """Execute vector search via LangChain MCP infrastructure.""" + try: + mcp_tool = self._create_or_get_mcp_tool() + mcp_input = self._build_mcp_input(query, filters, **kwargs) + result = mcp_tool.invoke(mcp_input) + return self._parse_mcp_response(result) + except Exception as e: + self._handle_mcp_execution_error(e) + + def _execute_direct_api_path( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs: Any, + ) -> List[Document]: + """Execute vector search via direct DatabricksVectorSearch API.""" kwargs = {**kwargs, **(self.model_extra or {})} - # Since LLM can generate either a dict or FilterItem, convert to dict always - filters_dict = {dict(item)["key"]: dict(item)["value"] for item in (filters or [])} + # Normalize filters to dict format + filters_dict = self._normalize_filters(filters) combined_filters = {**filters_dict, **(self.filters or {})} # Allow kwargs to override the default values upon invocation @@ -104,3 +176,18 @@ def _run(self, query: str, filters: Optional[List[FilterItem]] = None, **kwargs) } ) return self._vector_store.similarity_search(**kwargs) + + @vector_search_retriever_tool_trace + def _run( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs, + ) -> List[Document]: + """Execute vector search with automatic routing.""" + index_details = IndexDetails(self._vector_store.index) + + if index_details.is_databricks_managed_embeddings(): + return self._execute_mcp_path(query, filters, **kwargs) + else: + return self._execute_direct_api_path(query, filters, **kwargs) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index a8ed42acd..b51f41167 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,8 +1,9 @@ import json import os import threading +import uuid from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, create_autospec, patch +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch import mlflow import pytest @@ -13,12 +14,14 @@ ALL_INDEX_NAMES, DELTA_SYNC_INDEX, DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, + DIRECT_ACCESS_INDEX, INPUT_TEXTS, _get_index, mock_vs_client, mock_workspace_client, ) from databricks_ai_bridge.vector_search_retriever_tool import FilterItem +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.tools import BaseTool from mlflow.entities import SpanType @@ -41,6 +44,73 @@ ) +def _create_mcp_response_json(texts: List[str] = None) -> str: + """Create a mock MCP response in JSON format.""" + texts = texts or INPUT_TEXTS + return json.dumps( + [ + {"id": str(uuid.uuid4()), "text": text, "score": 0.85 - (i * 0.1)} + for i, text in enumerate(texts) + ] + ) + + +@pytest.fixture +def mock_mcp_infrastructure(): + """Mock MCP infrastructure for tests that need it.""" + # Create mock MCP tool that returns JSON response + mock_tool = MagicMock() + mock_tool.invoke = MagicMock(return_value=_create_mcp_response_json()) + + # Create mock MCP client + mock_client_instance = MagicMock() + mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool]) + + # Create mock MCP server + mock_server_instance = MagicMock() + + with patch( + "databricks_langchain.vector_search_retriever_tool.DatabricksMultiServerMCPClient" + ) as mock_client_class, patch( + "databricks_langchain.vector_search_retriever_tool.DatabricksMCPServer" + ) as mock_server_class: + mock_client_class.return_value = mock_client_instance + mock_server_class.from_vector_search.return_value = mock_server_instance + yield { + "client_class": mock_client_class, + "client_instance": mock_client_instance, + "server_class": mock_server_class, + "server_instance": mock_server_instance, + "tool": mock_tool, + } + + +@pytest.fixture(params=["mcp", "direct_api"]) +def execution_path(request, mock_mcp_infrastructure): + """Parametrized fixture that sets up mocks for MCP or Direct API path.""" + if request.param == "mcp": + yield { + "path": "mcp", + "index_name": DELTA_SYNC_INDEX, + "mock_tool": mock_mcp_infrastructure["tool"], + "mock_mcp": mock_mcp_infrastructure, + } + else: + # For direct API, use an index that requires self-managed embeddings + yield { + "path": "direct_api", + "index_name": DIRECT_ACCESS_INDEX, + "mock_tool": None, + "mock_mcp": mock_mcp_infrastructure, + } + + +def setup_tool_for_path(execution_path, tool): + """Set up mock for the tool based on execution path.""" + if execution_path["path"] == "direct_api": + tool._vector_store.similarity_search = MagicMock(return_value=[]) + + def init_vector_search_tool( index_name: str, columns: Optional[List[str]] = None, @@ -93,40 +163,55 @@ def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: assert isinstance(response, AIMessage) -def test_filters_are_passed_through() -> None: - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - vector_search_tool._vector_store.similarity_search = MagicMock() +def test_filters_are_passed_through(execution_path) -> None: + """Test filters are passed through correctly on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"]) + setup_tool_for_path(execution_path, tool) - vector_search_tool.invoke( + tool.invoke( { "query": "what cities are in Germany", "filters": [FilterItem(key="country", value="Germany")], } ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what cities are in Germany", - k=vector_search_tool.num_results, - filter={"country": "Germany"}, - query_type=vector_search_tool.query_type, - ) - -def test_filters_are_combined() -> None: - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) - vector_search_tool._vector_store.similarity_search = MagicMock() - - vector_search_tool.invoke( + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what cities are in Germany" + # MCP path: filters are JSON stringified + assert json.loads(call_args["filters"]) == {"country": "Germany"} + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what cities are in Germany" + assert call_args[1]["filter"] == {"country": "Germany"} + + +def test_filters_are_combined(execution_path) -> None: + """Test filters are combined correctly (predefined + runtime) on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], filters={"city LIKE": "Berlin"}) + setup_tool_for_path(execution_path, tool) + + tool.invoke( { "query": "what cities are in Germany", "filters": [FilterItem(key="country", value="Germany")], } ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what cities are in Germany", - k=vector_search_tool.num_results, - filter={"city LIKE": "Berlin", "country": "Germany"}, - query_type=vector_search_tool.query_type, - ) + + expected_filters = {"city LIKE": "Berlin", "country": "Germany"} + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what cities are in Germany" + # MCP path: filters are JSON stringified + assert json.loads(call_args["filters"]) == expected_filters + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what cities are in Germany" + assert call_args[1]["filter"] == expected_filters @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @@ -136,6 +221,7 @@ def test_filters_are_combined() -> None: @pytest.mark.parametrize("embedding", [None, EMBEDDING_MODEL]) @pytest.mark.parametrize("text_column", [None, "text"]) def test_vector_search_retriever_tool_combinations( + mock_mcp_infrastructure, index_name: str, columns: Optional[List[str]], tool_name: Optional[str], @@ -160,7 +246,8 @@ def test_vector_search_retriever_tool_combinations( assert result is not None -def test_vector_search_retriever_tool_combinations() -> None: +def test_vector_search_retriever_tool_doc_uri_primary_key(mock_mcp_infrastructure) -> None: + """Test that doc_uri and primary_key work correctly with MCP path.""" vector_search_tool = init_vector_search_tool( index_name=DELTA_SYNC_INDEX, doc_uri="uri", @@ -168,8 +255,13 @@ def test_vector_search_retriever_tool_combinations() -> None: ) assert isinstance(vector_search_tool, BaseTool) result = vector_search_tool.invoke("Databricks Agent Framework") - assert all(item.metadata.keys() == {"doc_uri", "chunk_id"} for item in result) - assert all(item.page_content for item in result) + # With MCP path, results are parsed from mock JSON response + assert result is not None + assert len(result) > 0 + assert all(isinstance(doc, Document) for doc in result) + # Verify Documents have expected structure from mock response + assert all(doc.page_content for doc in result) + assert all("id" in doc.metadata for doc in result) @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @@ -191,16 +283,20 @@ def test_vector_search_retriever_tool_description_generation(index_name: str) -> @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @pytest.mark.parametrize("tool_name", [None, "test_tool"]) -def test_vs_tool_tracing(index_name: str, tool_name: Optional[str]) -> None: +def test_vs_tool_tracing(mock_mcp_infrastructure, index_name: str, tool_name: Optional[str]) -> None: vector_search_tool = init_vector_search_tool(index_name, tool_name=tool_name) vector_search_tool._run("Databricks Agent Framework") + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) spans = trace.search_spans(name=tool_name or index_name, span_type=SpanType.RETRIEVER) assert len(spans) == 1 inputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanInputs"]) assert inputs["query"] == "Databricks Agent Framework" outputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanOutputs"]) - assert [d["page_content"] in INPUT_TEXTS for d in outputs] + # Verify outputs are Documents with page_content + assert len(outputs) > 0 + assert all("page_content" in d for d in outputs) + assert all(d["page_content"] for d in outputs) # page_content is not empty @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @@ -344,36 +440,44 @@ def test_vector_search_client_with_sp_workspace_client(): ) -def test_kwargs_are_passed_through() -> None: - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, score_threshold=0.5) - vector_search_tool._vector_store.similarity_search = MagicMock() +def test_kwargs_are_passed_through(execution_path) -> None: + """Test kwargs are passed through correctly on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], score_threshold=0.5) + setup_tool_for_path(execution_path, tool) - vector_search_tool.invoke( - {"query": "what cities are in Germany", "extra_param": "something random"}, - ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what cities are in Germany", - k=vector_search_tool.num_results, - query_type=vector_search_tool.query_type, - filter={}, - score_threshold=0.5, - extra_param="something random", - ) + tool.invoke({"query": "what cities are in Germany"}) + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what cities are in Germany" + assert call_args["score_threshold"] == 0.5 + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what cities are in Germany" + assert call_args[1]["score_threshold"] == 0.5 -def test_kwargs_override_both_num_results_and_query_type() -> None: - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN") - vector_search_tool._vector_store.similarity_search = MagicMock() - vector_search_tool.invoke( - {"query": "what cities are in Germany", "k": 3, "query_type": "HYBRID"}, - ) - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what cities are in Germany", - k=3, # Should use overridden value - query_type="HYBRID", # Should use overridden value - filter={}, - ) +def test_kwargs_override_both_num_results_and_query_type(execution_path) -> None: + """Test kwargs can override num_results and query_type on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], num_results=10, query_type="ANN") + setup_tool_for_path(execution_path, tool) + + tool.invoke({"query": "what cities are in Germany", "k": 3, "query_type": "HYBRID"}) + + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what cities are in Germany" + assert call_args["num_results"] == 3 + assert call_args["query_type"] == "HYBRID" + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what cities are in Germany" + assert call_args[1]["k"] == 3 + assert call_args[1]["query_type"] == "HYBRID" def test_enhanced_filter_description_with_column_metadata() -> None: @@ -458,34 +562,37 @@ def test_cannot_use_both_dynamic_filter_and_predefined_filters() -> None: ) -def test_predefined_filters_work_without_dynamic_filter() -> None: - """Test that predefined filters work correctly when dynamic_filter is False.""" - # Initialize tool with only predefined filters (dynamic_filter=False by default) - vector_search_tool = init_vector_search_tool( - DELTA_SYNC_INDEX, filters={"status": "active", "category": "electronics"} +def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: + """Test that predefined filters work correctly when dynamic_filter is False on both paths.""" + tool = init_vector_search_tool( + execution_path["index_name"], filters={"status": "active", "category": "electronics"} ) + setup_tool_for_path(execution_path, tool) # The filters parameter should NOT be exposed since dynamic_filter=False - args_schema = vector_search_tool.args_schema + args_schema = tool.args_schema assert "filters" not in args_schema.model_fields - # Test that predefined filters are used - vector_search_tool._vector_store.similarity_search = MagicMock() + tool.invoke({"query": "what electronics are available"}) - vector_search_tool.invoke({"query": "what electronics are available"}) + expected_filters = {"status": "active", "category": "electronics"} + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what electronics are available" + # MCP path: filters are JSON stringified + assert json.loads(call_args["filters"]) == expected_filters + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what electronics are available" + assert call_args[1]["filter"] == expected_filters - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="what electronics are available", - k=vector_search_tool.num_results, - query_type=vector_search_tool.query_type, - filter={"status": "active", "category": "electronics"}, # Only predefined filters - ) - -def test_filter_item_serialization() -> None: - """Test that FilterItem objects are properly converted to dictionaries.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - vector_search_tool._vector_store.similarity_search = MagicMock() +def test_filter_item_serialization(execution_path) -> None: + """Test that FilterItem objects are properly converted to dictionaries on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"]) + setup_tool_for_path(execution_path, tool) # Test various filter types filters = [ @@ -495,7 +602,7 @@ def test_filter_item_serialization() -> None: FilterItem(key="tags", value=["wireless", "bluetooth"]), ] - vector_search_tool.invoke({"query": "find products", "filters": filters}) + tool.invoke({"query": "find products", "filters": filters}) expected_filters = { "category": "electronics", @@ -504,9 +611,209 @@ def test_filter_item_serialization() -> None: "tags": ["wireless", "bluetooth"], } - vector_search_tool._vector_store.similarity_search.assert_called_once_with( - query="find products", - k=vector_search_tool.num_results, - query_type=vector_search_tool.query_type, - filter=expected_filters, + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "find products" + # MCP path: filters are JSON stringified + assert json.loads(call_args["filters"]) == expected_filters + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "find products" + assert call_args[1]["filter"] == expected_filters + + +# ============================================================================= +# MCP Path Specific Tests +# ============================================================================= + + +def test_mcp_path_is_used_for_databricks_managed_embeddings(mock_mcp_infrastructure) -> None: + """Test that MCP path is used for Databricks-managed embeddings indexes.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + # Invoke the tool (should use MCP path for DELTA_SYNC_INDEX which has managed embeddings) + result = vector_search_tool._run("test query") + + # Verify MCP server was created with correct parameters + mock_mcp_infrastructure["server_class"].from_vector_search.assert_called_once() + call_kwargs = mock_mcp_infrastructure["server_class"].from_vector_search.call_args[1] + assert call_kwargs["catalog"] == "test" + assert call_kwargs["schema"] == "delta_sync" + assert call_kwargs["index_name"] == "index" + + # Verify MCP client was used + mock_mcp_infrastructure["client_class"].assert_called_once() + + # Verify MCP tool was invoked + mock_mcp_infrastructure["tool"].invoke.assert_called_once() + + +def test_direct_api_path_is_used_for_self_managed_embeddings(mock_mcp_infrastructure) -> None: + """Test that direct API path is used for self-managed embeddings indexes.""" + # Use an index that requires self-managed embeddings + index_name = "test.direct_access.index" + vector_search_tool = init_vector_search_tool(index_name) + vector_search_tool._vector_store.similarity_search = MagicMock(return_value=[]) + + # Invoke the tool (should use direct API path) + result = vector_search_tool._run("test query") + + # Verify similarity_search was called directly + vector_search_tool._vector_store.similarity_search.assert_called_once() + + # Verify MCP was NOT used for self-managed embeddings + mock_mcp_infrastructure["tool"].invoke.assert_not_called() + + +def test_mcp_tool_is_cached(mock_mcp_infrastructure) -> None: + """Test that MCP tool is cached and not recreated on subsequent calls.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + # Call _run multiple times + vector_search_tool._run("query 1") + vector_search_tool._run("query 2") + vector_search_tool._run("query 3") + + # MCP server should only be created once + assert mock_mcp_infrastructure["server_class"].from_vector_search.call_count == 1 + + # MCP client should only be created once + assert mock_mcp_infrastructure["client_class"].call_count == 1 + + # But MCP tool should be invoked 3 times + assert mock_mcp_infrastructure["tool"].invoke.call_count == 3 + + +def test_mcp_response_parsing_json_array() -> None: + """Test that MCP JSON array response is parsed correctly into Documents.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + json_response = json.dumps( + [ + {"id": "doc1", "text": "content1", "score": 0.9}, + {"id": "doc2", "text": "content2", "score": 0.8}, + ] ) + + docs = vector_search_tool._parse_mcp_response(json_response) + + assert len(docs) == 2 + assert all(isinstance(doc, Document) for doc in docs) + assert docs[0].page_content == "content1" + assert docs[0].metadata == {"id": "doc1", "score": 0.9} + assert docs[1].page_content == "content2" + assert docs[1].metadata == {"id": "doc2", "score": 0.8} + + +def test_mcp_response_parsing_non_json() -> None: + """Test that non-JSON MCP response is treated as a single document.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + plain_text_response = "This is a plain text response" + + docs = vector_search_tool._parse_mcp_response(plain_text_response) + + assert len(docs) == 1 + assert docs[0].page_content == plain_text_response + + +def test_mcp_response_parsing_non_list_json() -> None: + """Test that non-list JSON is converted to a single document.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + json_response = json.dumps({"message": "single object response"}) + + docs = vector_search_tool._parse_mcp_response(json_response) + + assert len(docs) == 1 + assert docs[0].page_content == "{'message': 'single object response'}" + + +def test_normalize_filters_with_filter_items() -> None: + """Test that FilterItem list is normalized to dict.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + filters = [ + FilterItem(key="category", value="electronics"), + FilterItem(key="price >=", value=100), + ] + + result = vector_search_tool._normalize_filters(filters) + + assert result == {"category": "electronics", "price >=": 100} + + +def test_normalize_filters_with_dict() -> None: + """Test that dict filters are passed through unchanged.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + filters = {"category": "electronics", "price >=": 100} + + result = vector_search_tool._normalize_filters(filters) + + assert result == filters + + +def test_normalize_filters_with_none() -> None: + """Test that None filters return empty dict.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + result = vector_search_tool._normalize_filters(None) + + assert result == {} + + +def test_build_mcp_input() -> None: + """Test MCP input building with various parameters.""" + from databricks.vector_search.reranker import DatabricksReranker + + # Basic parameters + tool = init_vector_search_tool(DELTA_SYNC_INDEX) + mcp_input = tool._build_mcp_input("test query") + assert mcp_input["query"] == "test query" + assert mcp_input["num_results"] == tool.num_results + assert mcp_input["query_type"] == tool.query_type + assert mcp_input["include_score"] == "false" # Default + + # With filters (JSON stringified for MCP - parse back to compare) + filters = [FilterItem(key="category", value="electronics")] + mcp_input = tool._build_mcp_input("test query", filters=filters) + assert json.loads(mcp_input["filters"]) == {"category": "electronics"} + + # Combines predefined and runtime filters + tool_with_filters = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"status": "active"}) + runtime_filters = [FilterItem(key="category", value="electronics")] + mcp_input = tool_with_filters._build_mcp_input("test query", filters=runtime_filters) + expected_filters = {"status": "active", "category": "electronics"} + assert json.loads(mcp_input["filters"]) == expected_filters + + # kwargs override defaults + tool_with_defaults = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN") + mcp_input = tool_with_defaults._build_mcp_input("test query", num_results=5, query_type="HYBRID") + assert mcp_input["num_results"] == 5 + assert mcp_input["query_type"] == "HYBRID" + + # With columns (comma-separated for MCP) + tool_with_columns = init_vector_search_tool(DELTA_SYNC_INDEX, columns=["id", "text", "score"]) + mcp_input = tool_with_columns._build_mcp_input("test query") + assert mcp_input["columns"] == "id,text,score" + + # With score_threshold (converted to float) + mcp_input = tool._build_mcp_input("test query", score_threshold=0.7) + assert mcp_input["score_threshold"] == 0.7 + assert isinstance(mcp_input["score_threshold"], float) + + # With include_score=True + tool_with_score = init_vector_search_tool(DELTA_SYNC_INDEX, include_score=True) + mcp_input = tool_with_score._build_mcp_input("test query") + assert mcp_input["include_score"] == "true" + + # With reranker + reranker = DatabricksReranker(columns_to_rerank=["text", "title"]) + tool_with_reranker = init_vector_search_tool(DELTA_SYNC_INDEX, reranker=reranker) + mcp_input = tool_with_reranker._build_mcp_input("test query") + assert mcp_input["columns_to_rerank"] == "text,title" + + diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index b026b4ef7..0802261d9 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -1,5 +1,4 @@ import inspect -import json import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -198,40 +197,6 @@ def _validate_tool_inputs(self): return self - @vector_search_retriever_tool_trace - def execute( - self, - query: str, - filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, - openai_client: OpenAI = None, - **kwargs: Any, - ) -> List[Dict]: - """ - Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the - self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages. - - Execute vector search with automatic routing: - - MCP path: Used for Databricks-managed embeddings (no embedding model configuration needed) - - Direct API path: Used for self-managed embeddings (requires openai_client) - - Args: - query: The query text to use for the retrieval. - filters: Optional filters to refine vector search results. - openai_client: The OpenAI client object used to generate embeddings for retrieval queries. - Only used for self-managed embeddings. If not provided, the default OpenAI - client in the current environment will be used. - **kwargs: Additional search parameters (e.g., num_results, query_type, score_threshold, reranker). - For Databricks-managed embeddings, these are passed as MCP metadata. - For self-managed embeddings, these are passed to similarity_search(). - - Returns: - A list of document dictionaries. Format may vary between MCP and Direct API paths. - """ - if self._index_details.is_databricks_managed_embeddings(): - return self._execute_mcp_path(query, filters, **kwargs) - else: - return self._execute_direct_api_path(query, filters, openai_client, **kwargs) - def _create_or_get_mcp_toolkit(self) -> Callable: """ If it does not exist, create the MCP tool execution function for this index. @@ -243,12 +208,7 @@ def _create_or_get_mcp_toolkit(self) -> Callable: if self._mcp_tool_execute is not None: return self._mcp_tool_execute - parts = self.index_name.split(".") - if len(parts) != 3: - raise ValueError( - f"Invalid index name format: {self.index_name}. Expected 'catalog.schema.index'" - ) - catalog, schema, index = parts + catalog, schema, index = self._parse_index_name() try: self._mcp_toolkit = McpServerToolkit.from_vector_search( @@ -258,125 +218,37 @@ def _create_or_get_mcp_toolkit(self) -> Callable: workspace_client=self.workspace_client, ) except Exception as e: - raise RuntimeError( - f"Failed to initialize MCP toolkit for index {self.index_name}. " - f"Ensure the index exists and is configured for Databricks-managed embeddings. " - f"Error: {e}" - ) from e + self._handle_mcp_creation_error(e) tools = self._mcp_toolkit.get_tools() - if len(tools) < 1: - raise ValueError( - f"Expected exactly 1 MCP tool for index {self.index_name}, but got {len(tools)}" - ) + self._validate_mcp_tools(tools) self._mcp_tool_execute = tools[0].execute return self._mcp_tool_execute - def _normalize_filters( - self, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] - ) -> Dict[str, Any]: - """ - Normalize filters to a dict format. - - Args: - filters: Either a dict or List[FilterItem] - - Returns: - Dict of filter key-value pairs - """ - if filters is None: - return {} - if isinstance(filters, dict): - return filters - return {item.model_dump()["key"]: item.model_dump()["value"] for item in filters} - def _build_mcp_meta( self, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, **kwargs: Any ) -> Dict[str, Any]: - kwargs = {**(self.model_extra or {}), **kwargs} - - meta = {} - - num_results = kwargs.pop("num_results", self.num_results) - meta["num_results"] = num_results - - if self.query_type or "query_type" in kwargs: - query_type = kwargs.pop("query_type", self.query_type) - if query_type: - meta["query_type"] = query_type - - if self.columns: - meta["columns"] = ",".join(self.columns) - - combined_filters = {**self._normalize_filters(filters), **(self.filters or {})} - if combined_filters: - try: - meta["filters"] = json.dumps(combined_filters) - except (TypeError, ValueError) as e: - raise ValueError(f"Filters must be JSON serializable: {e}") from e - - if "score_threshold" in kwargs: - meta["score_threshold"] = float(kwargs.pop("score_threshold")) - - # Always send include_score explicitly to override backend defaults - meta["include_score"] = "true" if self.include_score else "false" - - reranker = kwargs.pop("reranker", self.reranker) - if reranker and hasattr(reranker, "columns_to_rerank"): - meta["columns_to_rerank"] = ",".join(reranker.columns_to_rerank) - - # Warn about any unknown kwargs - if kwargs: - _logger.warning( - f"Ignoring unsupported kwargs for MCP vector search: {list(kwargs.keys())}" - ) - - return meta - - def _normalize_mcp_result(self, result: Dict) -> Dict: - """ - Normalize MCP result to page_content/metadata format for backward compatibility. - - MCP returns: {"id": "doc1", "text": "content", "score": 0.95} - We convert to: {"page_content": "content", "metadata": {"id": "doc1", "score": 0.95}} - - This ensures callers get consistent output regardless of MCP vs Direct API path. - """ - text_column = self.text_column - page_content = result.get(text_column, "") - - metadata = {k: v for k, v in result.items() if k != text_column} - - return {"page_content": page_content, "metadata": metadata} + """Build metadata dict for MCP tool invocation.""" + return self._build_mcp_params(filters, **kwargs) def _parse_mcp_response(self, mcp_response: str) -> List[Dict]: - """ - Parse MCP JSON response and normalize to page_content/metadata format. + """Parse MCP JSON response and normalize to page_content/metadata format.""" + return self._parse_mcp_response_to_dicts(mcp_response, strict=True) - The Vector Search MCP server returns a JSON array of flat result dicts. - We parse and normalize each result for consistent output format. - """ + def _execute_mcp_path( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs: Any, + ) -> List[Dict]: try: - parsed = json.loads(mcp_response) - except json.JSONDecodeError as e: - _logger.error(f"Failed to parse MCP response as JSON: {mcp_response[:200]}...") - raise ValueError( - f"Unable to parse MCP response. Expected JSON format. Error: {e}" - ) from e - - if not isinstance(parsed, list): - # Show preview of what we got (limit to 500 chars for readability) - response_preview = str(parsed)[:500] - _logger.error( - f"MCP response is not a list: {type(parsed).__name__}. Content: {response_preview}" - ) - raise ValueError( - f"Expected MCP vector search to return a JSON array of results, " - f"but got {type(parsed).__name__}: {response_preview}" - ) - - return [self._normalize_mcp_result(result) for result in parsed] + mcp_execute = self._create_or_get_mcp_toolkit() + meta = self._build_mcp_meta(filters, **kwargs) + mcp_response = mcp_execute(query=query, _meta=meta) + return self._parse_mcp_response(mcp_response) + except Exception as e: + self._handle_mcp_execution_error(e) def _execute_direct_api_path( self, @@ -439,20 +311,36 @@ def _execute_direct_api_path( ) return [doc for doc, _ in docs_with_score] - def _execute_mcp_path( + @vector_search_retriever_tool_trace + def execute( self, query: str, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + openai_client: OpenAI = None, **kwargs: Any, ) -> List[Dict]: - try: - mcp_execute = self._create_or_get_mcp_toolkit() - meta = self._build_mcp_meta(filters, **kwargs) - mcp_response = mcp_execute(query=query, _meta=meta) - documents = self._parse_mcp_response(mcp_response) - return documents - except Exception as e: - _logger.error(f"MCP vector search failed: {e}", exc_info=True) - raise RuntimeError( - f"Vector search via MCP failed for index {self.index_name}. Error: {e}" - ) from e + """ + Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the + self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages. + + Execute vector search with automatic routing: + - MCP path: Used for Databricks-managed embeddings (no embedding model configuration needed) + - Direct API path: Used for self-managed embeddings (requires openai_client) + + Args: + query: The query text to use for the retrieval. + filters: Optional filters to refine vector search results. + openai_client: The OpenAI client object used to generate embeddings for retrieval queries. + Only used for self-managed embeddings. If not provided, the default OpenAI + client in the current environment will be used. + **kwargs: Additional search parameters (e.g., num_results, query_type, score_threshold, reranker). + For Databricks-managed embeddings, these are passed as MCP metadata. + For self-managed embeddings, these are passed to similarity_search(). + + Returns: + A list of document dictionaries. Format may vary between MCP and Direct API paths. + """ + if self._index_details.is_databricks_managed_embeddings(): + return self._execute_mcp_path(query, filters, **kwargs) + else: + return self._execute_direct_api_path(query, filters, openai_client, **kwargs) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 4ac64f20e..2558610d8 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -854,38 +854,44 @@ def test_reranker_is_overriden(execution_path) -> None: class TestMCPResponseNormalization: """Test that MCP responses are normalized to match Direct API format.""" - def test_normalize_mcp_result_basic(self) -> None: - """Test basic normalization of a single MCP result.""" + def test_parse_mcp_response_basic_normalization(self) -> None: + """Test basic normalization of MCP results via _parse_mcp_response.""" vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - mcp_result = { - "id": "doc-123", - "text": "This is the document content", - "score": 0.95, - } + mcp_response = json.dumps([ + { + "id": "doc-123", + "text": "This is the document content", + "score": 0.95, + } + ]) - normalized = vector_search_tool._normalize_mcp_result(mcp_result) + results = vector_search_tool._parse_mcp_response(mcp_response) - assert normalized["page_content"] == "This is the document content" - assert normalized["metadata"]["id"] == "doc-123" - assert normalized["metadata"]["score"] == 0.95 - assert "text" not in normalized["metadata"] # text column moved to page_content + assert len(results) == 1 + assert results[0]["page_content"] == "This is the document content" + assert results[0]["metadata"]["id"] == "doc-123" + assert results[0]["metadata"]["score"] == 0.95 + assert "text" not in results[0]["metadata"] # text column moved to page_content - def test_normalize_mcp_result_missing_text_column(self) -> None: + def test_parse_mcp_response_missing_text_column(self) -> None: """Test normalization handles missing text column gracefully.""" vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - mcp_result = { - "id": "doc-789", - "score": 0.75, - # "text" column is missing - } + mcp_response = json.dumps([ + { + "id": "doc-789", + "score": 0.75, + # "text" column is missing + } + ]) - normalized = vector_search_tool._normalize_mcp_result(mcp_result) + results = vector_search_tool._parse_mcp_response(mcp_response) - assert normalized["page_content"] == "" # Empty string when text column missing - assert normalized["metadata"]["id"] == "doc-789" - assert normalized["metadata"]["score"] == 0.75 + assert len(results) == 1 + # When text column is missing, the dict is converted to string + assert results[0]["metadata"]["id"] == "doc-789" + assert results[0]["metadata"]["score"] == 0.75 def test_parse_mcp_response_empty_list(self) -> None: """Test parsing empty MCP response.""" @@ -911,5 +917,5 @@ def test_parse_mcp_response_not_a_list(self) -> None: # MCP should return a list, not a dict mcp_response = json.dumps({"error": "something went wrong"}) - with pytest.raises(ValueError, match="Expected MCP vector search to return a JSON array"): + with pytest.raises(ValueError, match="Expected JSON array, got"): vector_search_tool._parse_mcp_response(mcp_response) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 6b2aed2d9..4286a2411 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -1,7 +1,8 @@ +import json import logging import re from functools import wraps -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import mlflow from databricks.sdk import WorkspaceClient @@ -298,3 +299,118 @@ def _get_tool_name(self) -> str: ) return tool_name[-64:] return tool_name + + def _normalize_filters( + self, filters: Optional[Union[Dict[str, Any], List["FilterItem"]]] + ) -> Dict[str, Any]: + """Normalize filters to dict format.""" + if filters is None: + return {} + if isinstance(filters, dict): + return filters + return {item.model_dump()["key"]: item.model_dump()["value"] for item in filters} + + def _parse_index_name(self) -> Tuple[str, str, str]: + """Parse index_name into (catalog, schema, index) tuple.""" + parts = self.index_name.split(".") + if len(parts) != 3: + raise ValueError( + f"Invalid index name format: {self.index_name}. Expected 'catalog.schema.index'" + ) + return parts[0], parts[1], parts[2] + + def _handle_mcp_creation_error(self, error: Exception) -> None: + """Raise standardized error for MCP initialization failures.""" + raise RuntimeError( + f"Failed to initialize MCP tool for index {self.index_name}. " + f"Ensure the index exists and is configured for Databricks-managed embeddings. " + f"Error: {error}" + ) from error + + def _validate_mcp_tools(self, tools: list) -> None: + """Validate that MCP tools were returned.""" + if not tools: + raise ValueError(f"No MCP tools found for index {self.index_name}") + + def _handle_mcp_execution_error(self, error: Exception) -> None: + """Log and raise standardized error for MCP execution failures.""" + _logger.error(f"MCP vector search failed: {error}", exc_info=True) + raise RuntimeError( + f"Vector search via MCP failed for index {self.index_name}. Error: {error}" + ) from error + + def _build_mcp_params( + self, + filters: Optional[Union[Dict[str, Any], List["FilterItem"]]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """Build common MCP parameters dict (excludes query).""" + kwargs = {**(self.model_extra or {}), **kwargs} + params: Dict[str, Any] = {} + + num_results = kwargs.pop("num_results", kwargs.pop("k", self.num_results)) + if num_results: + params["num_results"] = num_results + + query_type = kwargs.pop("query_type", self.query_type) + if query_type: + params["query_type"] = query_type + + combined_filters = {**self._normalize_filters(filters), **(self.filters or {})} + if combined_filters: + try: + params["filters"] = json.dumps(combined_filters) + except (TypeError, ValueError) as e: + raise ValueError(f"Filters must be JSON serializable: {e}") from e + + if self.columns: + params["columns"] = ",".join(self.columns) + + if "score_threshold" in kwargs: + params["score_threshold"] = float(kwargs.pop("score_threshold")) + + params["include_score"] = "true" if self.include_score else "false" + + reranker = kwargs.pop("reranker", self.reranker) + if reranker and hasattr(reranker, "columns_to_rerank"): + params["columns_to_rerank"] = ",".join(reranker.columns_to_rerank) + + if kwargs: + _logger.warning(f"Ignoring unsupported kwargs for MCP vector search: {list(kwargs.keys())}") + + return params + + def _parse_mcp_response_to_dicts( + self, + mcp_response: str, + text_column: Optional[str] = None, + strict: bool = True + ) -> List[Dict[str, Any]]: + """Parse MCP JSON response to list of dicts with page_content/metadata structure.""" + text_col = text_column or getattr(self, 'text_column', None) or "text" + + try: + parsed = json.loads(mcp_response) + except json.JSONDecodeError as e: + if strict: + _logger.error(f"Failed to parse MCP response as JSON: {mcp_response[:200]}...") + raise ValueError(f"Unable to parse MCP response. Expected JSON format. Error: {e}") from e + return [{"page_content": mcp_response, "metadata": {}}] + + if not isinstance(parsed, list): + if strict: + response_preview = str(parsed)[:500] + _logger.error(f"MCP response is not a list: {type(parsed).__name__}. Content: {response_preview}") + raise ValueError(f"Expected JSON array, got {type(parsed).__name__}: {response_preview}") + return [{"page_content": str(parsed), "metadata": {}}] + + results = [] + for item in parsed: + if isinstance(item, dict): + page_content = item.get(text_col, str(item)) + metadata = {k: v for k, v in item.items() if k != text_col} + results.append({"page_content": page_content, "metadata": metadata}) + else: + results.append({"page_content": str(item), "metadata": {}}) + + return results From d5bebb893e227e41cf23964d0a4894a1088d4ec9 Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Mon, 26 Jan 2026 18:04:42 -0800 Subject: [PATCH 02/11] ruff --- .../test_vector_search_retriever_tool.py | 23 +++++++------ .../test_vector_search_retriever_tool.py | 32 +++++++++++-------- .../vector_search_retriever_tool.py | 23 +++++++------ 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index b51f41167..ff87c9079 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -69,11 +69,14 @@ def mock_mcp_infrastructure(): # Create mock MCP server mock_server_instance = MagicMock() - with patch( - "databricks_langchain.vector_search_retriever_tool.DatabricksMultiServerMCPClient" - ) as mock_client_class, patch( - "databricks_langchain.vector_search_retriever_tool.DatabricksMCPServer" - ) as mock_server_class: + with ( + patch( + "databricks_langchain.vector_search_retriever_tool.DatabricksMultiServerMCPClient" + ) as mock_client_class, + patch( + "databricks_langchain.vector_search_retriever_tool.DatabricksMCPServer" + ) as mock_server_class, + ): mock_client_class.return_value = mock_client_instance mock_server_class.from_vector_search.return_value = mock_server_instance yield { @@ -283,7 +286,9 @@ def test_vector_search_retriever_tool_description_generation(index_name: str) -> @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @pytest.mark.parametrize("tool_name", [None, "test_tool"]) -def test_vs_tool_tracing(mock_mcp_infrastructure, index_name: str, tool_name: Optional[str]) -> None: +def test_vs_tool_tracing( + mock_mcp_infrastructure, index_name: str, tool_name: Optional[str] +) -> None: vector_search_tool = init_vector_search_tool(index_name, tool_name=tool_name) vector_search_tool._run("Databricks Agent Framework") @@ -791,7 +796,9 @@ def test_build_mcp_input() -> None: # kwargs override defaults tool_with_defaults = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN") - mcp_input = tool_with_defaults._build_mcp_input("test query", num_results=5, query_type="HYBRID") + mcp_input = tool_with_defaults._build_mcp_input( + "test query", num_results=5, query_type="HYBRID" + ) assert mcp_input["num_results"] == 5 assert mcp_input["query_type"] == "HYBRID" @@ -815,5 +822,3 @@ def test_build_mcp_input() -> None: tool_with_reranker = init_vector_search_tool(DELTA_SYNC_INDEX, reranker=reranker) mcp_input = tool_with_reranker._build_mcp_input("test query") assert mcp_input["columns_to_rerank"] == "text,title" - - diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index 2558610d8..e2f03008e 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -858,13 +858,15 @@ def test_parse_mcp_response_basic_normalization(self) -> None: """Test basic normalization of MCP results via _parse_mcp_response.""" vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - mcp_response = json.dumps([ - { - "id": "doc-123", - "text": "This is the document content", - "score": 0.95, - } - ]) + mcp_response = json.dumps( + [ + { + "id": "doc-123", + "text": "This is the document content", + "score": 0.95, + } + ] + ) results = vector_search_tool._parse_mcp_response(mcp_response) @@ -878,13 +880,15 @@ def test_parse_mcp_response_missing_text_column(self) -> None: """Test normalization handles missing text column gracefully.""" vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - mcp_response = json.dumps([ - { - "id": "doc-789", - "score": 0.75, - # "text" column is missing - } - ]) + mcp_response = json.dumps( + [ + { + "id": "doc-789", + "score": 0.75, + # "text" column is missing + } + ] + ) results = vector_search_tool._parse_mcp_response(mcp_response) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index 4286a2411..a975e9bc9 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -376,32 +376,37 @@ def _build_mcp_params( params["columns_to_rerank"] = ",".join(reranker.columns_to_rerank) if kwargs: - _logger.warning(f"Ignoring unsupported kwargs for MCP vector search: {list(kwargs.keys())}") + _logger.warning( + f"Ignoring unsupported kwargs for MCP vector search: {list(kwargs.keys())}" + ) return params def _parse_mcp_response_to_dicts( - self, - mcp_response: str, - text_column: Optional[str] = None, - strict: bool = True + self, mcp_response: str, text_column: Optional[str] = None, strict: bool = True ) -> List[Dict[str, Any]]: """Parse MCP JSON response to list of dicts with page_content/metadata structure.""" - text_col = text_column or getattr(self, 'text_column', None) or "text" + text_col = text_column or getattr(self, "text_column", None) or "text" try: parsed = json.loads(mcp_response) except json.JSONDecodeError as e: if strict: _logger.error(f"Failed to parse MCP response as JSON: {mcp_response[:200]}...") - raise ValueError(f"Unable to parse MCP response. Expected JSON format. Error: {e}") from e + raise ValueError( + f"Unable to parse MCP response. Expected JSON format. Error: {e}" + ) from e return [{"page_content": mcp_response, "metadata": {}}] if not isinstance(parsed, list): if strict: response_preview = str(parsed)[:500] - _logger.error(f"MCP response is not a list: {type(parsed).__name__}. Content: {response_preview}") - raise ValueError(f"Expected JSON array, got {type(parsed).__name__}: {response_preview}") + _logger.error( + f"MCP response is not a list: {type(parsed).__name__}. Content: {response_preview}" + ) + raise ValueError( + f"Expected JSON array, got {type(parsed).__name__}: {response_preview}" + ) return [{"page_content": str(parsed), "metadata": {}}] results = [] From 2a8d33fa17633ea5a9e1db0ae844637c3823b467 Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Tue, 27 Jan 2026 10:08:32 -0800 Subject: [PATCH 03/11] clean up langchain tests --- .../test_vector_search_retriever_tool.py | 1096 +++++------------ 1 file changed, 309 insertions(+), 787 deletions(-) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index ff87c9079..ff6d641a1 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,824 +1,346 @@ -import json -import os -import threading -import uuid -from typing import Any, Dict, List, Optional -from unittest.mock import AsyncMock, MagicMock, create_autospec, patch - -import mlflow -import pytest -from databricks.sdk import WorkspaceClient -from databricks.sdk.credentials_provider import ModelServingUserCredentials -from databricks.vector_search.utils import CredentialStrategy -from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 - ALL_INDEX_NAMES, - DELTA_SYNC_INDEX, - DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, - DIRECT_ACCESS_INDEX, - INPUT_TEXTS, - _get_index, - mock_vs_client, - mock_workspace_client, +import inspect +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from databricks.vector_search.client import VectorSearchIndex +from databricks_ai_bridge.utils.vector_search import ( + IndexDetails, + RetrieverSchema, + parse_vector_search_response, + validate_and_get_return_columns, + validate_and_get_text_column, ) -from databricks_ai_bridge.vector_search_retriever_tool import FilterItem -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.tools import BaseTool -from mlflow.entities import SpanType -from mlflow.models.resources import ( - DatabricksServingEndpoint, - DatabricksVectorSearchIndex, +from databricks_ai_bridge.vector_search_retriever_tool import ( + FilterItem, + VectorSearchRetrieverToolMixin, + vector_search_retriever_tool_trace, ) +from openai import OpenAI, pydantic_function_tool +from openai.types.chat import ChatCompletionToolParam +from pydantic import Field, PrivateAttr, model_validator -from databricks_langchain import ( - ChatDatabricks, - VectorSearchRetrieverTool, -) -from tests.utils.chat_models import llm, mock_client # noqa: F401 -from tests.utils.vector_search import ( - EMBEDDING_MODEL, - embeddings, # noqa: F401 -) -from tests.utils.vector_search import ( - mock_client as mock_embeddings_client, # noqa: F401 -) +from databricks_openai.mcp_server_toolkit import McpServerToolkit +_logger = logging.getLogger(__name__) -def _create_mcp_response_json(texts: List[str] = None) -> str: - """Create a mock MCP response in JSON format.""" - texts = texts or INPUT_TEXTS - return json.dumps( - [ - {"id": str(uuid.uuid4()), "text": text, "score": 0.85 - (i * 0.1)} - for i, text in enumerate(texts) - ] - ) - - -@pytest.fixture -def mock_mcp_infrastructure(): - """Mock MCP infrastructure for tests that need it.""" - # Create mock MCP tool that returns JSON response - mock_tool = MagicMock() - mock_tool.invoke = MagicMock(return_value=_create_mcp_response_json()) - - # Create mock MCP client - mock_client_instance = MagicMock() - mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool]) - - # Create mock MCP server - mock_server_instance = MagicMock() - - with ( - patch( - "databricks_langchain.vector_search_retriever_tool.DatabricksMultiServerMCPClient" - ) as mock_client_class, - patch( - "databricks_langchain.vector_search_retriever_tool.DatabricksMCPServer" - ) as mock_server_class, - ): - mock_client_class.return_value = mock_client_instance - mock_server_class.from_vector_search.return_value = mock_server_instance - yield { - "client_class": mock_client_class, - "client_instance": mock_client_instance, - "server_class": mock_server_class, - "server_instance": mock_server_instance, - "tool": mock_tool, - } +class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): + """ + A utility class to create a vector search-based retrieval tool for querying indexed embeddings. + This class integrates with Databricks Vector Search and provides a convenient interface + for tool calling using the OpenAI SDK. -@pytest.fixture(params=["mcp", "direct_api"]) -def execution_path(request, mock_mcp_infrastructure): - """Parametrized fixture that sets up mocks for MCP or Direct API path.""" - if request.param == "mcp": - yield { - "path": "mcp", - "index_name": DELTA_SYNC_INDEX, - "mock_tool": mock_mcp_infrastructure["tool"], - "mock_mcp": mock_mcp_infrastructure, - } - else: - # For direct API, use an index that requires self-managed embeddings - yield { - "path": "direct_api", - "index_name": DIRECT_ACCESS_INDEX, - "mock_tool": None, - "mock_mcp": mock_mcp_infrastructure, - } + Example: + Step 1: Call model with VectorSearchRetrieverTool defined + .. code-block:: python -def setup_tool_for_path(execution_path, tool): - """Set up mock for the tool based on execution path.""" - if execution_path["path"] == "direct_api": - tool._vector_store.similarity_search = MagicMock(return_value=[]) - - -def init_vector_search_tool( - index_name: str, - columns: Optional[List[str]] = None, - tool_name: Optional[str] = None, - tool_description: Optional[str] = None, - embedding: Optional[Embeddings] = None, - text_column: Optional[str] = None, - doc_uri: Optional[str] = None, - primary_key: Optional[str] = None, - filters: Optional[Dict[str, Any]] = None, - **kwargs: Any, -) -> VectorSearchRetrieverTool: - kwargs.update( - { - "index_name": index_name, - "columns": columns, - "tool_name": tool_name, - "tool_description": tool_description, - "embedding": embedding, - "text_column": text_column, - "doc_uri": doc_uri, - "primary_key": primary_key, - "filters": filters, - } - ) - if index_name != DELTA_SYNC_INDEX: - kwargs.update( - { - "embedding": EMBEDDING_MODEL, - "text_column": "text", - } - ) - return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] - - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_init(index_name: str) -> None: - vector_search_tool = init_vector_search_tool(index_name) - assert isinstance(vector_search_tool, BaseTool) - assert "'additionalProperties': true" not in str(vector_search_tool.args) + dbvs_tool = VectorSearchRetrieverTool(index_name="catalog.schema.my_index_name") + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "Using the Databricks documentation, answer what are AI Gateway inference tables?", + }, + ] + first_response = client.chat.completions.create( + model="gpt-4o", messages=messages, tools=[dbvs_tool.tool] + ) + Step 2: Execute function code – parse the model's response and handle function calls. -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: - from langchain_core.messages import AIMessage + .. code-block:: python - vector_search_tool = init_vector_search_tool(index_name) - llm_with_tools = llm.bind_tools([vector_search_tool]) - response = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") - assert isinstance(response, AIMessage) + tool_call = first_response.choices[0].message.tool_calls[0] + args = json.loads(tool_call.function.arguments) + result = dbvs_tool.execute( + query=args["query"], + filters={"category": "governance", "status": "general"}, + num_results=5, + score_threshold=0.7, + ) + Step 3: Supply model with results – so it can incorporate them into its final response. -def test_filters_are_passed_through(execution_path) -> None: - """Test filters are passed through correctly on both paths.""" - tool = init_vector_search_tool(execution_path["index_name"]) - setup_tool_for_path(execution_path, tool) + .. code-block:: python - tool.invoke( - { - "query": "what cities are in Germany", - "filters": [FilterItem(key="country", value="Germany")], - } - ) + messages.append(first_response.choices[0].message) + messages.append( + {"role": "tool", "tool_call_id": tool_call.id, "content": json.dumps(result)} + ) + second_response = client.chat.completions.create( + model="gpt-4o", messages=messages, tools=tools + ) - if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what cities are in Germany" - # MCP path: filters are JSON stringified - assert json.loads(call_args["filters"]) == {"country": "Germany"} - else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what cities are in Germany" - assert call_args[1]["filter"] == {"country": "Germany"} - - -def test_filters_are_combined(execution_path) -> None: - """Test filters are combined correctly (predefined + runtime) on both paths.""" - tool = init_vector_search_tool(execution_path["index_name"], filters={"city LIKE": "Berlin"}) - setup_tool_for_path(execution_path, tool) - - tool.invoke( - { - "query": "what cities are in Germany", - "filters": [FilterItem(key="country", value="Germany")], - } + **Note**: Any additional keyword arguments passed to the constructor will be passed along when executing the tool. + The ``execute()`` method supports meta parameters such as ``num_results``, ``score_threshold``, ``query_type``, + ``filters``, ``columns``, and ``columns_to_rerank``. See + :class:`~databricks_ai_bridge.vector_search_retriever_tool.VectorSearchRetrieverToolMixin` for additional supported + constructor arguments. + + WorkspaceClient instances with auth types PAT, OAuth-M2M (client ID and client secret), or model serving credential + strategy will be used to instantiate the underlying VectorSearchClient. + """ + + text_column: Optional[str] = Field( + None, + description="The name of the text column to use for the embeddings. " + "Required for direct-access index or delta-sync index with " + "self-managed embeddings.", ) - - expected_filters = {"city LIKE": "Berlin", "country": "Germany"} - if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what cities are in Germany" - # MCP path: filters are JSON stringified - assert json.loads(call_args["filters"]) == expected_filters - else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what cities are in Germany" - assert call_args[1]["filter"] == expected_filters - - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -@pytest.mark.parametrize("columns", [None, ["id", "text"]]) -@pytest.mark.parametrize("tool_name", [None, "test_tool"]) -@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) -@pytest.mark.parametrize("embedding", [None, EMBEDDING_MODEL]) -@pytest.mark.parametrize("text_column", [None, "text"]) -def test_vector_search_retriever_tool_combinations( - mock_mcp_infrastructure, - index_name: str, - columns: Optional[List[str]], - tool_name: Optional[str], - tool_description: Optional[str], - embedding: Optional[Any], - text_column: Optional[str], -) -> None: - if index_name == DELTA_SYNC_INDEX: - embedding = None - text_column = None - - vector_search_tool = init_vector_search_tool( - index_name=index_name, - columns=columns, - tool_name=tool_name, - tool_description=tool_description, - embedding=embedding, - text_column=text_column, + embedding_model_name: Optional[str] = Field( + None, + description="The name of the embedding model to use for embedding the query text." + "Required for direct-access index or delta-sync index with " + "self-managed embeddings.", ) - assert isinstance(vector_search_tool, BaseTool) - result = vector_search_tool.invoke("Databricks Agent Framework") - assert result is not None - -def test_vector_search_retriever_tool_doc_uri_primary_key(mock_mcp_infrastructure) -> None: - """Test that doc_uri and primary_key work correctly with MCP path.""" - vector_search_tool = init_vector_search_tool( - index_name=DELTA_SYNC_INDEX, - doc_uri="uri", - primary_key="id", + tool: ChatCompletionToolParam = Field( + None, description="The tool input used in the OpenAI chat completion SDK" ) - assert isinstance(vector_search_tool, BaseTool) - result = vector_search_tool.invoke("Databricks Agent Framework") - # With MCP path, results are parsed from mock JSON response - assert result is not None - assert len(result) > 0 - assert all(isinstance(doc, Document) for doc in result) - # Verify Documents have expected structure from mock response - assert all(doc.page_content for doc in result) - assert all("id" in doc.metadata for doc in result) - - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_vector_search_retriever_tool_description_generation(index_name: str) -> None: - vector_search_tool = init_vector_search_tool(index_name) - assert vector_search_tool.name != "" - assert vector_search_tool.description != "" - assert vector_search_tool.name == index_name.replace(".", "__") - assert ( - "A vector search-based retrieval tool for querying indexed embeddings." - in vector_search_tool.description - ) - assert vector_search_tool.args_schema.model_fields["query"] is not None - assert vector_search_tool.args_schema.model_fields["query"].description == ( - "The string used to query the index with and identify the most similar " - "vectors and return the associated documents." - ) - + _index: VectorSearchIndex = PrivateAttr() + _index_details: IndexDetails = PrivateAttr() + _mcp_toolkit: Optional[McpServerToolkit] = PrivateAttr(default=None) + _mcp_tool_execute: Optional[Callable] = PrivateAttr(default=None) + + @model_validator(mode="after") + def _validate_tool_inputs(self): + from databricks.vector_search.client import ( + VectorSearchClient, # import here so we can mock in tests + ) + from databricks.vector_search.utils import CredentialStrategy -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -@pytest.mark.parametrize("tool_name", [None, "test_tool"]) -def test_vs_tool_tracing( - mock_mcp_infrastructure, index_name: str, tool_name: Optional[str] -) -> None: - vector_search_tool = init_vector_search_tool(index_name, tool_name=tool_name) - vector_search_tool._run("Databricks Agent Framework") - - trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) - spans = trace.search_spans(name=tool_name or index_name, span_type=SpanType.RETRIEVER) - assert len(spans) == 1 - inputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanInputs"]) - assert inputs["query"] == "Databricks Agent Framework" - outputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanOutputs"]) - # Verify outputs are Documents with page_content - assert len(outputs) > 0 - assert all("page_content" in d for d in outputs) - assert all(d["page_content"] for d in outputs) # page_content is not empty - - -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_vector_search_retriever_tool_resources( - mock_embeddings_client, - embeddings, - index_name: str, -) -> None: - text_column = "text" - if index_name == DELTA_SYNC_INDEX: - embeddings = None - text_column = None - - vector_search_tool = VectorSearchRetrieverTool( - index_name=index_name, embedding=embeddings, text_column=text_column - ) - expected_resources = ( - [DatabricksVectorSearchIndex(index_name=index_name)] - + ([DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else []) - + ( - [ - DatabricksServingEndpoint( - endpoint_name=DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME + splits = self.index_name.split(".") + if len(splits) != 3: + raise ValueError( + f"Index name {self.index_name} is not in the expected format 'catalog.schema.index'." + ) + client_args = { + "disable_notice": True, + } + if self.workspace_client is not None: + config = self.workspace_client.config + if config.auth_type == "model_serving_user_credentials": + client_args.setdefault( + "credential_strategy", CredentialStrategy.MODEL_SERVING_USER_CREDENTIALS ) - ] - if index_name == DELTA_SYNC_INDEX - else [] + elif config.auth_type == "pat": + client_args.setdefault("personal_access_token", config.token) + elif config.auth_type == "oauth-m2m": + client_args.setdefault("workspace_url", config.host) + client_args.setdefault("service_principal_client_id", config.client_id) + client_args.setdefault("service_principal_client_secret", config.client_secret) + self._index = VectorSearchClient(**client_args).get_index(index_name=self.index_name) + self._index_details = IndexDetails(self._index) + self.text_column = validate_and_get_text_column(self.text_column, self._index_details) + self.columns = validate_and_get_return_columns( + self.columns or [], + self.text_column, + self._index_details, + self.doc_uri, + self.primary_key, ) - ) - assert [res.to_dict() for res in vector_search_tool.resources] == [ - res.to_dict() for res in expected_resources - ] - - -@pytest.mark.parametrize("tool_name", [None, "valid_tool_name", "test_tool"]) -def test_tool_name_validation_valid(tool_name: Optional[str]) -> None: - index_name = "catalog.schema.index" - tool = init_vector_search_tool(index_name, tool_name=tool_name) - assert tool.tool_name == tool_name - if tool_name: - assert tool.name == tool_name - - -@pytest.mark.parametrize("tool_name", ["test.tool.name", "tool&name"]) -def test_tool_name_validation_invalid(tool_name: str) -> None: - index_name = "catalog.schema.index" - with pytest.raises(ValueError): - init_vector_search_tool(index_name, tool_name=tool_name) - - -@pytest.mark.parametrize( - "index_name,name", - [ - ("catalog.schema.index", "catalog__schema__index"), - ("cata_log.schema_.index", "cata_log__schema___index"), - ], -) -def test_index_name_to_tool_name(index_name: str, name: str) -> None: - vector_search_tool = init_vector_search_tool(index_name) - assert vector_search_tool.name == name - - -def test_vector_search_client_model_serving_environment(): - with patch("os.path.isfile", return_value=True): - # Simulate Model Serving Environment - os.environ["IS_IN_DB_MODEL_SERVING_ENV"] = "true" - - # Fake credential token - current_thread = threading.current_thread() - thread_data = current_thread.__dict__ - thread_data["invokers_token"] = "abc" - - w = WorkspaceClient( - host="testDogfod.com", credentials_strategy=ModelServingUserCredentials() + self._retriever_schema = RetrieverSchema( + text_column=self.text_column, + doc_uri=self.doc_uri, + primary_key=self.primary_key, + other_columns=self.columns, ) - with patch("databricks.vector_search.client.VectorSearchClient") as mockVSClient: - mock_instance = mockVSClient.return_value - mock_instance.get_index.side_effect = _get_index - with patch("databricks.sdk.service.serving.ServingEndpointsAPI.get", return_value=None): - vsTool = VectorSearchRetrieverTool( - index_name="test.delta_sync.index", - tool_description="desc", - workspace_client=w, - ) - mockVSClient.assert_called_once_with( - disable_notice=True, - credential_strategy=CredentialStrategy.MODEL_SERVING_USER_CREDENTIALS, - ) + if ( + not self._index_details.is_databricks_managed_embeddings() + and not self.embedding_model_name + ): + raise ValueError( + "The embedding model name is required for non-Databricks-managed " + "embeddings Vector Search indexes in order to generate embeddings for retrieval queries." + ) + tool_name = self._get_tool_name() -def test_vector_search_client_non_model_serving_environment(): - with patch("databricks.vector_search.client.VectorSearchClient") as mockVSClient: - mock_instance = mockVSClient.return_value - mock_instance.get_index.side_effect = _get_index - vsTool = VectorSearchRetrieverTool( - index_name="test.delta_sync.index", - tool_description="desc", + # Create tool input model based on dynamic_filter setting + if self.dynamic_filter: + tool_input_class = self._create_enhanced_input_model() + else: + tool_input_class = self._create_basic_input_model() + + self.tool = pydantic_function_tool( + tool_input_class, + name=tool_name, + description=self.tool_description + or self._get_default_tool_description(self._index_details), ) - mockVSClient.assert_called_once_with(disable_notice=True) - - -def test_vector_search_client_with_pat_workspace_client(): - w = WorkspaceClient(host="testDogfod.com", token="fakeToken") - with patch("databricks.vector_search.client.VectorSearchClient") as mockVSClient: - with patch("databricks.sdk.service.serving.ServingEndpointsAPI.get", return_value=None): - mock_instance = mockVSClient.return_value - mock_instance.get_index.side_effect = _get_index - VectorSearchRetrieverTool( - index_name="test.delta_sync.index", - tool_description="desc", - workspace_client=w, - ) - mockVSClient.assert_called_once_with( - disable_notice=True, personal_access_token="fakeToken" + # We need to remove strict: True from the tool in order to support arbitrary filters + if "function" in self.tool and "strict" in self.tool["function"]: + del self.tool["function"]["strict"] + # We need to remove additionalProperties from the tool in order to support arbitrary kwargs + if ( + "function" in self.tool + and "parameters" in self.tool["function"] + and "additionalProperties" in self.tool["function"]["parameters"] + ): + del self.tool["function"]["parameters"]["additionalProperties"] + + try: + from databricks.sdk import WorkspaceClient + from databricks.sdk.errors.platform import ResourceDoesNotExist + + if self.workspace_client is not None: + self.workspace_client.serving_endpoints.get(self.embedding_model_name) + else: + WorkspaceClient().serving_endpoints.get(self.embedding_model_name) + self.resources = self._get_resources( + self.index_name, self.embedding_model_name, self._index_details ) - - -def test_vector_search_client_with_sp_workspace_client(): - # Create a proper mock workspace client that passes isinstance check - w = create_autospec(WorkspaceClient, instance=True) - w.config.auth_type = "oauth-m2m" - w.config.host = "testDogfod.com" - w.config.client_id = "fakeClientId" - w.config.client_secret = "fakeClientSecret" - - with patch("databricks.vector_search.client.VectorSearchClient") as mockVSClient: - with patch("databricks.sdk.service.serving.ServingEndpointsAPI.get", return_value=None): - mock_instance = mockVSClient.return_value - mock_instance.get_index.side_effect = _get_index - VectorSearchRetrieverTool( - index_name="test.delta_sync.index", - tool_description="desc", - workspace_client=w, + except ResourceDoesNotExist: + self.resources = self._get_resources(self.index_name, None, self._index_details) + + return self + + def _create_or_get_mcp_toolkit(self) -> Callable: + """ + If it does not exist, create the MCP tool execution function for this index. + Otherwise, return the execution function. + + Uses McpServerToolkit.from_vector_search(catalog, schema, index_name, workspace_client) + to access tools for the specified vector search index. + """ + if self._mcp_tool_execute is not None: + return self._mcp_tool_execute + + catalog, schema, index = self._parse_index_name() + + try: + self._mcp_toolkit = McpServerToolkit.from_vector_search( + catalog=catalog, + schema=schema, + index_name=index, + workspace_client=self.workspace_client, ) - mockVSClient.assert_called_once_with( - disable_notice=True, - workspace_url="testDogfod.com", - service_principal_client_id="fakeClientId", - service_principal_client_secret="fakeClientSecret", + except Exception as e: + self._handle_mcp_creation_error(e) + + tools = self._mcp_toolkit.get_tools() + self._validate_mcp_tools(tools) + + self._mcp_tool_execute = tools[0].execute + return self._mcp_tool_execute + + def _build_mcp_meta( + self, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Build metadata dict for MCP tool invocation.""" + return self._build_mcp_params(filters, **kwargs) + + def _parse_mcp_response(self, mcp_response: str) -> List[Dict]: + """Parse MCP JSON response and normalize to page_content/metadata format.""" + return self._parse_mcp_response_to_dicts(mcp_response, strict=True) + + def _execute_mcp_path( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + **kwargs: Any, + ) -> List[Dict]: + try: + mcp_execute = self._create_or_get_mcp_toolkit() + meta = self._build_mcp_meta(filters, **kwargs) + mcp_response = mcp_execute(query=query, _meta=meta) + return self._parse_mcp_response(mcp_response) + except Exception as e: + self._handle_mcp_execution_error(e) + + def _execute_direct_api_path( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + openai_client: OpenAI = None, + **kwargs: Any, + ) -> List[Dict]: + from openai import OpenAI + + oai_client = openai_client or OpenAI() + if not oai_client.api_key: + raise ValueError( + "OpenAI API key is required to generate embeddings for retrieval queries." ) + signature = inspect.signature(self._index.similarity_search) + kwargs = {**kwargs, **(self.model_extra or {})} + kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} -def test_kwargs_are_passed_through(execution_path) -> None: - """Test kwargs are passed through correctly on both paths.""" - tool = init_vector_search_tool(execution_path["index_name"], score_threshold=0.5) - setup_tool_for_path(execution_path, tool) - - tool.invoke({"query": "what cities are in Germany"}) - - if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what cities are in Germany" - assert call_args["score_threshold"] == 0.5 - else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what cities are in Germany" - assert call_args[1]["score_threshold"] == 0.5 - - -def test_kwargs_override_both_num_results_and_query_type(execution_path) -> None: - """Test kwargs can override num_results and query_type on both paths.""" - tool = init_vector_search_tool(execution_path["index_name"], num_results=10, query_type="ANN") - setup_tool_for_path(execution_path, tool) - - tool.invoke({"query": "what cities are in Germany", "k": 3, "query_type": "HYBRID"}) - - if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what cities are in Germany" - assert call_args["num_results"] == 3 - assert call_args["query_type"] == "HYBRID" - else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what cities are in Germany" - assert call_args[1]["k"] == 3 - assert call_args[1]["query_type"] == "HYBRID" - - -def test_enhanced_filter_description_with_column_metadata() -> None: - """Test that the tool args_schema includes enhanced filter descriptions with column metadata.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - # The LangChain implementation calls index.describe() to get column information - # and includes them in the filter description - args_schema = vector_search_tool.args_schema - filter_field = args_schema.model_fields["filters"] - - # Check that the filter description is enhanced with available columns - # Note: The actual columns will depend on the mocked index.describe() response - assert ( - "Available columns for filtering:" in filter_field.description - or "Optional filters" in filter_field.description - ) + # Allow kwargs to override the default values upon invocation + num_results = kwargs.pop("num_results", self.num_results) + query_type = kwargs.pop("query_type", self.query_type) + reranker = kwargs.pop("reranker", self.reranker) - # Should include comprehensive filter syntax - assert "Inclusion:" in filter_field.description - assert "Exclusion:" in filter_field.description - assert "Comparisons:" in filter_field.description - assert "Pattern match:" in filter_field.description - assert "OR logic:" in filter_field.description - - # Should include examples - assert "Examples:" in filter_field.description - assert "Filter by category:" in filter_field.description - assert "Filter by price range:" in filter_field.description - - -def test_enhanced_filter_description_fails_on_table_metadata_error() -> None: - """Test that tool initialization fails with clear error when table metadata cannot be retrieved.""" - # Mock WorkspaceClient to raise an exception when accessing table metadata - with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: - mock_ws_client = MagicMock() - mock_ws_client.tables.get.side_effect = Exception("Permission denied") - mock_ws_client_class.return_value = mock_ws_client - - # Try to initialize tool with dynamic_filter=True - # This should fail because we can't get table metadata - with pytest.raises( - ValueError, - match="Failed to retrieve table metadata for index.*Permission denied", - ): - init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - -def test_enhanced_filter_description_fails_on_empty_columns() -> None: - """Test that tool initialization fails when table has no valid columns.""" - # Mock WorkspaceClient to return a table with no valid columns (all start with __) - with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: - mock_ws_client = MagicMock() - mock_table = MagicMock() - mock_column = MagicMock() - mock_column.name = "__internal_column" - mock_column.type_name = MagicMock() - mock_column.type_name.name = "STRING" - mock_table.columns = [mock_column] - mock_ws_client.tables.get.return_value = mock_table - mock_ws_client_class.return_value = mock_ws_client - - # Try to initialize tool with dynamic_filter=True - # This should fail because there are no valid columns - with pytest.raises( - ValueError, - match="No valid columns found in table metadata for index", - ): - init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - -def test_cannot_use_both_dynamic_filter_and_predefined_filters() -> None: - """Test that using both dynamic_filter and predefined filters raises an error.""" - # Try to initialize tool with both dynamic_filter=True and predefined filters - with pytest.raises( - ValueError, match="Cannot use both dynamic_filter=True and predefined filters" - ): - init_vector_search_tool( - DELTA_SYNC_INDEX, - filters={"status": "active", "category": "electronics"}, - dynamic_filter=True, + query_text = query if query_type and query_type.upper() == "HYBRID" else None + query_vector = ( + oai_client.embeddings.create(input=query, model=self.embedding_model_name) + .data[0] + .embedding ) + if ( + index_embedding_dimension := self._index_details.embedding_vector_column.get( + "embedding_dimension" + ) + ) and len(query_vector) != index_embedding_dimension: + raise ValueError( + f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}" + ) + combined_filters = {**(self.filters or {}), **self._normalize_filters(filters)} -def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: - """Test that predefined filters work correctly when dynamic_filter is False on both paths.""" - tool = init_vector_search_tool( - execution_path["index_name"], filters={"status": "active", "category": "electronics"} - ) - setup_tool_for_path(execution_path, tool) - - # The filters parameter should NOT be exposed since dynamic_filter=False - args_schema = tool.args_schema - assert "filters" not in args_schema.model_fields - - tool.invoke({"query": "what electronics are available"}) - - expected_filters = {"status": "active", "category": "electronics"} - if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what electronics are available" - # MCP path: filters are JSON stringified - assert json.loads(call_args["filters"]) == expected_filters - else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what electronics are available" - assert call_args[1]["filter"] == expected_filters - - -def test_filter_item_serialization(execution_path) -> None: - """Test that FilterItem objects are properly converted to dictionaries on both paths.""" - tool = init_vector_search_tool(execution_path["index_name"]) - setup_tool_for_path(execution_path, tool) - - # Test various filter types - filters = [ - FilterItem(key="category", value="electronics"), - FilterItem(key="price >=", value=100), - FilterItem(key="status NOT", value="discontinued"), - FilterItem(key="tags", value=["wireless", "bluetooth"]), - ] - - tool.invoke({"query": "find products", "filters": filters}) - - expected_filters = { - "category": "electronics", - "price >=": 100, - "status NOT": "discontinued", - "tags": ["wireless", "bluetooth"], - } - - if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "find products" - # MCP path: filters are JSON stringified - assert json.loads(call_args["filters"]) == expected_filters - else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "find products" - assert call_args[1]["filter"] == expected_filters - - -# ============================================================================= -# MCP Path Specific Tests -# ============================================================================= - - -def test_mcp_path_is_used_for_databricks_managed_embeddings(mock_mcp_infrastructure) -> None: - """Test that MCP path is used for Databricks-managed embeddings indexes.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - # Invoke the tool (should use MCP path for DELTA_SYNC_INDEX which has managed embeddings) - result = vector_search_tool._run("test query") - - # Verify MCP server was created with correct parameters - mock_mcp_infrastructure["server_class"].from_vector_search.assert_called_once() - call_kwargs = mock_mcp_infrastructure["server_class"].from_vector_search.call_args[1] - assert call_kwargs["catalog"] == "test" - assert call_kwargs["schema"] == "delta_sync" - assert call_kwargs["index_name"] == "index" - - # Verify MCP client was used - mock_mcp_infrastructure["client_class"].assert_called_once() - - # Verify MCP tool was invoked - mock_mcp_infrastructure["tool"].invoke.assert_called_once() - - -def test_direct_api_path_is_used_for_self_managed_embeddings(mock_mcp_infrastructure) -> None: - """Test that direct API path is used for self-managed embeddings indexes.""" - # Use an index that requires self-managed embeddings - index_name = "test.direct_access.index" - vector_search_tool = init_vector_search_tool(index_name) - vector_search_tool._vector_store.similarity_search = MagicMock(return_value=[]) - - # Invoke the tool (should use direct API path) - result = vector_search_tool._run("test query") - - # Verify similarity_search was called directly - vector_search_tool._vector_store.similarity_search.assert_called_once() - - # Verify MCP was NOT used for self-managed embeddings - mock_mcp_infrastructure["tool"].invoke.assert_not_called() - - -def test_mcp_tool_is_cached(mock_mcp_infrastructure) -> None: - """Test that MCP tool is cached and not recreated on subsequent calls.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - # Call _run multiple times - vector_search_tool._run("query 1") - vector_search_tool._run("query 2") - vector_search_tool._run("query 3") - - # MCP server should only be created once - assert mock_mcp_infrastructure["server_class"].from_vector_search.call_count == 1 - - # MCP client should only be created once - assert mock_mcp_infrastructure["client_class"].call_count == 1 - - # But MCP tool should be invoked 3 times - assert mock_mcp_infrastructure["tool"].invoke.call_count == 3 - - -def test_mcp_response_parsing_json_array() -> None: - """Test that MCP JSON array response is parsed correctly into Documents.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - json_response = json.dumps( - [ - {"id": "doc1", "text": "content1", "score": 0.9}, - {"id": "doc2", "text": "content2", "score": 0.8}, - ] - ) - - docs = vector_search_tool._parse_mcp_response(json_response) - - assert len(docs) == 2 - assert all(isinstance(doc, Document) for doc in docs) - assert docs[0].page_content == "content1" - assert docs[0].metadata == {"id": "doc1", "score": 0.9} - assert docs[1].page_content == "content2" - assert docs[1].metadata == {"id": "doc2", "score": 0.8} - - -def test_mcp_response_parsing_non_json() -> None: - """Test that non-JSON MCP response is treated as a single document.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - plain_text_response = "This is a plain text response" - - docs = vector_search_tool._parse_mcp_response(plain_text_response) - - assert len(docs) == 1 - assert docs[0].page_content == plain_text_response - - -def test_mcp_response_parsing_non_list_json() -> None: - """Test that non-list JSON is converted to a single document.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - json_response = json.dumps({"message": "single object response"}) - - docs = vector_search_tool._parse_mcp_response(json_response) - - assert len(docs) == 1 - assert docs[0].page_content == "{'message': 'single object response'}" - - -def test_normalize_filters_with_filter_items() -> None: - """Test that FilterItem list is normalized to dict.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - filters = [ - FilterItem(key="category", value="electronics"), - FilterItem(key="price >=", value=100), - ] - - result = vector_search_tool._normalize_filters(filters) - - assert result == {"category": "electronics", "price >=": 100} - - -def test_normalize_filters_with_dict() -> None: - """Test that dict filters are passed through unchanged.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - filters = {"category": "electronics", "price >=": 100} - - result = vector_search_tool._normalize_filters(filters) - - assert result == filters - - -def test_normalize_filters_with_none() -> None: - """Test that None filters return empty dict.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - result = vector_search_tool._normalize_filters(None) - - assert result == {} - - -def test_build_mcp_input() -> None: - """Test MCP input building with various parameters.""" - from databricks.vector_search.reranker import DatabricksReranker - - # Basic parameters - tool = init_vector_search_tool(DELTA_SYNC_INDEX) - mcp_input = tool._build_mcp_input("test query") - assert mcp_input["query"] == "test query" - assert mcp_input["num_results"] == tool.num_results - assert mcp_input["query_type"] == tool.query_type - assert mcp_input["include_score"] == "false" # Default - - # With filters (JSON stringified for MCP - parse back to compare) - filters = [FilterItem(key="category", value="electronics")] - mcp_input = tool._build_mcp_input("test query", filters=filters) - assert json.loads(mcp_input["filters"]) == {"category": "electronics"} - - # Combines predefined and runtime filters - tool_with_filters = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"status": "active"}) - runtime_filters = [FilterItem(key="category", value="electronics")] - mcp_input = tool_with_filters._build_mcp_input("test query", filters=runtime_filters) - expected_filters = {"status": "active", "category": "electronics"} - assert json.loads(mcp_input["filters"]) == expected_filters - - # kwargs override defaults - tool_with_defaults = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN") - mcp_input = tool_with_defaults._build_mcp_input( - "test query", num_results=5, query_type="HYBRID" - ) - assert mcp_input["num_results"] == 5 - assert mcp_input["query_type"] == "HYBRID" - - # With columns (comma-separated for MCP) - tool_with_columns = init_vector_search_tool(DELTA_SYNC_INDEX, columns=["id", "text", "score"]) - mcp_input = tool_with_columns._build_mcp_input("test query") - assert mcp_input["columns"] == "id,text,score" - - # With score_threshold (converted to float) - mcp_input = tool._build_mcp_input("test query", score_threshold=0.7) - assert mcp_input["score_threshold"] == 0.7 - assert isinstance(mcp_input["score_threshold"], float) - - # With include_score=True - tool_with_score = init_vector_search_tool(DELTA_SYNC_INDEX, include_score=True) - mcp_input = tool_with_score._build_mcp_input("test query") - assert mcp_input["include_score"] == "true" - - # With reranker - reranker = DatabricksReranker(columns_to_rerank=["text", "title"]) - tool_with_reranker = init_vector_search_tool(DELTA_SYNC_INDEX, reranker=reranker) - mcp_input = tool_with_reranker._build_mcp_input("test query") - assert mcp_input["columns_to_rerank"] == "text,title" + kwargs.update( + { + "query_text": query_text, + "query_vector": query_vector, + "columns": self.columns, + "filters": combined_filters, + "num_results": num_results, + "query_type": query_type, + "reranker": reranker, + } + ) + search_resp = self._index.similarity_search(**kwargs) + docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( + search_resp=search_resp, + retriever_schema=self._retriever_schema, + document_class=dict, + include_score=self.include_score, + ) + return [doc for doc, _ in docs_with_score] + + @vector_search_retriever_tool_trace + def execute( + self, + query: str, + filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, + openai_client: OpenAI = None, + **kwargs: Any, + ) -> List[Dict]: + """ + Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the + self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages. + + Execute vector search with automatic routing: + - MCP path: Used for Databricks-managed embeddings (no embedding model configuration needed) + - Direct API path: Used for self-managed embeddings (requires openai_client) + + Args: + query: The query text to use for the retrieval. + filters: Optional filters to refine vector search results. + openai_client: The OpenAI client object used to generate embeddings for retrieval queries. + Only used for self-managed embeddings. If not provided, the default OpenAI + client in the current environment will be used. + **kwargs: Additional search parameters (e.g., num_results, query_type, score_threshold, reranker). + For Databricks-managed embeddings, these are passed as MCP metadata. + For self-managed embeddings, these are passed to similarity_search(). + + Returns: + A list of document dictionaries. Format may vary between MCP and Direct API paths. + """ + if self._index_details.is_databricks_managed_embeddings(): + return self._execute_mcp_path(query, filters, **kwargs) + else: + return self._execute_direct_api_path(query, filters, openai_client, **kwargs) From ebfc4b911ac6ee1d0d802b448eabb8d1c533dc95 Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Tue, 27 Jan 2026 10:12:54 -0800 Subject: [PATCH 04/11] revert --- .../test_vector_search_retriever_tool.py | 1096 ++++++++++++----- 1 file changed, 787 insertions(+), 309 deletions(-) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index ff6d641a1..ff87c9079 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -1,346 +1,824 @@ -import inspect -import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -from databricks.vector_search.client import VectorSearchIndex -from databricks_ai_bridge.utils.vector_search import ( - IndexDetails, - RetrieverSchema, - parse_vector_search_response, - validate_and_get_return_columns, - validate_and_get_text_column, +import json +import os +import threading +import uuid +from typing import Any, Dict, List, Optional +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch + +import mlflow +import pytest +from databricks.sdk import WorkspaceClient +from databricks.sdk.credentials_provider import ModelServingUserCredentials +from databricks.vector_search.utils import CredentialStrategy +from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401 + ALL_INDEX_NAMES, + DELTA_SYNC_INDEX, + DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME, + DIRECT_ACCESS_INDEX, + INPUT_TEXTS, + _get_index, + mock_vs_client, + mock_workspace_client, ) -from databricks_ai_bridge.vector_search_retriever_tool import ( - FilterItem, - VectorSearchRetrieverToolMixin, - vector_search_retriever_tool_trace, +from databricks_ai_bridge.vector_search_retriever_tool import FilterItem +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.tools import BaseTool +from mlflow.entities import SpanType +from mlflow.models.resources import ( + DatabricksServingEndpoint, + DatabricksVectorSearchIndex, ) -from openai import OpenAI, pydantic_function_tool -from openai.types.chat import ChatCompletionToolParam -from pydantic import Field, PrivateAttr, model_validator -from databricks_openai.mcp_server_toolkit import McpServerToolkit +from databricks_langchain import ( + ChatDatabricks, + VectorSearchRetrieverTool, +) +from tests.utils.chat_models import llm, mock_client # noqa: F401 +from tests.utils.vector_search import ( + EMBEDDING_MODEL, + embeddings, # noqa: F401 +) +from tests.utils.vector_search import ( + mock_client as mock_embeddings_client, # noqa: F401 +) -_logger = logging.getLogger(__name__) +def _create_mcp_response_json(texts: List[str] = None) -> str: + """Create a mock MCP response in JSON format.""" + texts = texts or INPUT_TEXTS + return json.dumps( + [ + {"id": str(uuid.uuid4()), "text": text, "score": 0.85 - (i * 0.1)} + for i, text in enumerate(texts) + ] + ) -class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): - """ - A utility class to create a vector search-based retrieval tool for querying indexed embeddings. - This class integrates with Databricks Vector Search and provides a convenient interface - for tool calling using the OpenAI SDK. - Example: - Step 1: Call model with VectorSearchRetrieverTool defined +@pytest.fixture +def mock_mcp_infrastructure(): + """Mock MCP infrastructure for tests that need it.""" + # Create mock MCP tool that returns JSON response + mock_tool = MagicMock() + mock_tool.invoke = MagicMock(return_value=_create_mcp_response_json()) + + # Create mock MCP client + mock_client_instance = MagicMock() + mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool]) + + # Create mock MCP server + mock_server_instance = MagicMock() + + with ( + patch( + "databricks_langchain.vector_search_retriever_tool.DatabricksMultiServerMCPClient" + ) as mock_client_class, + patch( + "databricks_langchain.vector_search_retriever_tool.DatabricksMCPServer" + ) as mock_server_class, + ): + mock_client_class.return_value = mock_client_instance + mock_server_class.from_vector_search.return_value = mock_server_instance + yield { + "client_class": mock_client_class, + "client_instance": mock_client_instance, + "server_class": mock_server_class, + "server_instance": mock_server_instance, + "tool": mock_tool, + } - .. code-block:: python - dbvs_tool = VectorSearchRetrieverTool(index_name="catalog.schema.my_index_name") - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": "Using the Databricks documentation, answer what are AI Gateway inference tables?", - }, - ] - first_response = client.chat.completions.create( - model="gpt-4o", messages=messages, tools=[dbvs_tool.tool] - ) +@pytest.fixture(params=["mcp", "direct_api"]) +def execution_path(request, mock_mcp_infrastructure): + """Parametrized fixture that sets up mocks for MCP or Direct API path.""" + if request.param == "mcp": + yield { + "path": "mcp", + "index_name": DELTA_SYNC_INDEX, + "mock_tool": mock_mcp_infrastructure["tool"], + "mock_mcp": mock_mcp_infrastructure, + } + else: + # For direct API, use an index that requires self-managed embeddings + yield { + "path": "direct_api", + "index_name": DIRECT_ACCESS_INDEX, + "mock_tool": None, + "mock_mcp": mock_mcp_infrastructure, + } - Step 2: Execute function code – parse the model's response and handle function calls. - .. code-block:: python +def setup_tool_for_path(execution_path, tool): + """Set up mock for the tool based on execution path.""" + if execution_path["path"] == "direct_api": + tool._vector_store.similarity_search = MagicMock(return_value=[]) + + +def init_vector_search_tool( + index_name: str, + columns: Optional[List[str]] = None, + tool_name: Optional[str] = None, + tool_description: Optional[str] = None, + embedding: Optional[Embeddings] = None, + text_column: Optional[str] = None, + doc_uri: Optional[str] = None, + primary_key: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> VectorSearchRetrieverTool: + kwargs.update( + { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "embedding": embedding, + "text_column": text_column, + "doc_uri": doc_uri, + "primary_key": primary_key, + "filters": filters, + } + ) + if index_name != DELTA_SYNC_INDEX: + kwargs.update( + { + "embedding": EMBEDDING_MODEL, + "text_column": "text", + } + ) + return VectorSearchRetrieverTool(**kwargs) # type: ignore[arg-type] - tool_call = first_response.choices[0].message.tool_calls[0] - args = json.loads(tool_call.function.arguments) - result = dbvs_tool.execute( - query=args["query"], - filters={"category": "governance", "status": "general"}, - num_results=5, - score_threshold=0.7, - ) - Step 3: Supply model with results – so it can incorporate them into its final response. +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_init(index_name: str) -> None: + vector_search_tool = init_vector_search_tool(index_name) + assert isinstance(vector_search_tool, BaseTool) + assert "'additionalProperties': true" not in str(vector_search_tool.args) - .. code-block:: python - messages.append(first_response.choices[0].message) - messages.append( - {"role": "tool", "tool_call_id": tool_call.id, "content": json.dumps(result)} - ) - second_response = client.chat.completions.create( - model="gpt-4o", messages=messages, tools=tools - ) +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_chat_model_bind_tools(llm: ChatDatabricks, index_name: str) -> None: + from langchain_core.messages import AIMessage + + vector_search_tool = init_vector_search_tool(index_name) + llm_with_tools = llm.bind_tools([vector_search_tool]) + response = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") + assert isinstance(response, AIMessage) + + +def test_filters_are_passed_through(execution_path) -> None: + """Test filters are passed through correctly on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"]) + setup_tool_for_path(execution_path, tool) - **Note**: Any additional keyword arguments passed to the constructor will be passed along when executing the tool. - The ``execute()`` method supports meta parameters such as ``num_results``, ``score_threshold``, ``query_type``, - ``filters``, ``columns``, and ``columns_to_rerank``. See - :class:`~databricks_ai_bridge.vector_search_retriever_tool.VectorSearchRetrieverToolMixin` for additional supported - constructor arguments. - - WorkspaceClient instances with auth types PAT, OAuth-M2M (client ID and client secret), or model serving credential - strategy will be used to instantiate the underlying VectorSearchClient. - """ - - text_column: Optional[str] = Field( - None, - description="The name of the text column to use for the embeddings. " - "Required for direct-access index or delta-sync index with " - "self-managed embeddings.", + tool.invoke( + { + "query": "what cities are in Germany", + "filters": [FilterItem(key="country", value="Germany")], + } ) - embedding_model_name: Optional[str] = Field( - None, - description="The name of the embedding model to use for embedding the query text." - "Required for direct-access index or delta-sync index with " - "self-managed embeddings.", + + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what cities are in Germany" + # MCP path: filters are JSON stringified + assert json.loads(call_args["filters"]) == {"country": "Germany"} + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what cities are in Germany" + assert call_args[1]["filter"] == {"country": "Germany"} + + +def test_filters_are_combined(execution_path) -> None: + """Test filters are combined correctly (predefined + runtime) on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], filters={"city LIKE": "Berlin"}) + setup_tool_for_path(execution_path, tool) + + tool.invoke( + { + "query": "what cities are in Germany", + "filters": [FilterItem(key="country", value="Germany")], + } ) - tool: ChatCompletionToolParam = Field( - None, description="The tool input used in the OpenAI chat completion SDK" + expected_filters = {"city LIKE": "Berlin", "country": "Germany"} + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what cities are in Germany" + # MCP path: filters are JSON stringified + assert json.loads(call_args["filters"]) == expected_filters + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what cities are in Germany" + assert call_args[1]["filter"] == expected_filters + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +@pytest.mark.parametrize("columns", [None, ["id", "text"]]) +@pytest.mark.parametrize("tool_name", [None, "test_tool"]) +@pytest.mark.parametrize("tool_description", [None, "Test tool for vector search"]) +@pytest.mark.parametrize("embedding", [None, EMBEDDING_MODEL]) +@pytest.mark.parametrize("text_column", [None, "text"]) +def test_vector_search_retriever_tool_combinations( + mock_mcp_infrastructure, + index_name: str, + columns: Optional[List[str]], + tool_name: Optional[str], + tool_description: Optional[str], + embedding: Optional[Any], + text_column: Optional[str], +) -> None: + if index_name == DELTA_SYNC_INDEX: + embedding = None + text_column = None + + vector_search_tool = init_vector_search_tool( + index_name=index_name, + columns=columns, + tool_name=tool_name, + tool_description=tool_description, + embedding=embedding, + text_column=text_column, ) - _index: VectorSearchIndex = PrivateAttr() - _index_details: IndexDetails = PrivateAttr() - _mcp_toolkit: Optional[McpServerToolkit] = PrivateAttr(default=None) - _mcp_tool_execute: Optional[Callable] = PrivateAttr(default=None) - - @model_validator(mode="after") - def _validate_tool_inputs(self): - from databricks.vector_search.client import ( - VectorSearchClient, # import here so we can mock in tests - ) - from databricks.vector_search.utils import CredentialStrategy + assert isinstance(vector_search_tool, BaseTool) + result = vector_search_tool.invoke("Databricks Agent Framework") + assert result is not None - splits = self.index_name.split(".") - if len(splits) != 3: - raise ValueError( - f"Index name {self.index_name} is not in the expected format 'catalog.schema.index'." - ) - client_args = { - "disable_notice": True, - } - if self.workspace_client is not None: - config = self.workspace_client.config - if config.auth_type == "model_serving_user_credentials": - client_args.setdefault( - "credential_strategy", CredentialStrategy.MODEL_SERVING_USER_CREDENTIALS + +def test_vector_search_retriever_tool_doc_uri_primary_key(mock_mcp_infrastructure) -> None: + """Test that doc_uri and primary_key work correctly with MCP path.""" + vector_search_tool = init_vector_search_tool( + index_name=DELTA_SYNC_INDEX, + doc_uri="uri", + primary_key="id", + ) + assert isinstance(vector_search_tool, BaseTool) + result = vector_search_tool.invoke("Databricks Agent Framework") + # With MCP path, results are parsed from mock JSON response + assert result is not None + assert len(result) > 0 + assert all(isinstance(doc, Document) for doc in result) + # Verify Documents have expected structure from mock response + assert all(doc.page_content for doc in result) + assert all("id" in doc.metadata for doc in result) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_vector_search_retriever_tool_description_generation(index_name: str) -> None: + vector_search_tool = init_vector_search_tool(index_name) + assert vector_search_tool.name != "" + assert vector_search_tool.description != "" + assert vector_search_tool.name == index_name.replace(".", "__") + assert ( + "A vector search-based retrieval tool for querying indexed embeddings." + in vector_search_tool.description + ) + assert vector_search_tool.args_schema.model_fields["query"] is not None + assert vector_search_tool.args_schema.model_fields["query"].description == ( + "The string used to query the index with and identify the most similar " + "vectors and return the associated documents." + ) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +@pytest.mark.parametrize("tool_name", [None, "test_tool"]) +def test_vs_tool_tracing( + mock_mcp_infrastructure, index_name: str, tool_name: Optional[str] +) -> None: + vector_search_tool = init_vector_search_tool(index_name, tool_name=tool_name) + vector_search_tool._run("Databricks Agent Framework") + + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) + spans = trace.search_spans(name=tool_name or index_name, span_type=SpanType.RETRIEVER) + assert len(spans) == 1 + inputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanInputs"]) + assert inputs["query"] == "Databricks Agent Framework" + outputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanOutputs"]) + # Verify outputs are Documents with page_content + assert len(outputs) > 0 + assert all("page_content" in d for d in outputs) + assert all(d["page_content"] for d in outputs) # page_content is not empty + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_vector_search_retriever_tool_resources( + mock_embeddings_client, + embeddings, + index_name: str, +) -> None: + text_column = "text" + if index_name == DELTA_SYNC_INDEX: + embeddings = None + text_column = None + + vector_search_tool = VectorSearchRetrieverTool( + index_name=index_name, embedding=embeddings, text_column=text_column + ) + expected_resources = ( + [DatabricksVectorSearchIndex(index_name=index_name)] + + ([DatabricksServingEndpoint(endpoint_name=embeddings.endpoint)] if embeddings else []) + + ( + [ + DatabricksServingEndpoint( + endpoint_name=DELTA_SYNC_INDEX_EMBEDDING_MODEL_ENDPOINT_NAME ) - elif config.auth_type == "pat": - client_args.setdefault("personal_access_token", config.token) - elif config.auth_type == "oauth-m2m": - client_args.setdefault("workspace_url", config.host) - client_args.setdefault("service_principal_client_id", config.client_id) - client_args.setdefault("service_principal_client_secret", config.client_secret) - self._index = VectorSearchClient(**client_args).get_index(index_name=self.index_name) - self._index_details = IndexDetails(self._index) - self.text_column = validate_and_get_text_column(self.text_column, self._index_details) - self.columns = validate_and_get_return_columns( - self.columns or [], - self.text_column, - self._index_details, - self.doc_uri, - self.primary_key, - ) - self._retriever_schema = RetrieverSchema( - text_column=self.text_column, - doc_uri=self.doc_uri, - primary_key=self.primary_key, - other_columns=self.columns, + ] + if index_name == DELTA_SYNC_INDEX + else [] ) + ) + assert [res.to_dict() for res in vector_search_tool.resources] == [ + res.to_dict() for res in expected_resources + ] + + +@pytest.mark.parametrize("tool_name", [None, "valid_tool_name", "test_tool"]) +def test_tool_name_validation_valid(tool_name: Optional[str]) -> None: + index_name = "catalog.schema.index" + tool = init_vector_search_tool(index_name, tool_name=tool_name) + assert tool.tool_name == tool_name + if tool_name: + assert tool.name == tool_name - if ( - not self._index_details.is_databricks_managed_embeddings() - and not self.embedding_model_name - ): - raise ValueError( - "The embedding model name is required for non-Databricks-managed " - "embeddings Vector Search indexes in order to generate embeddings for retrieval queries." - ) - tool_name = self._get_tool_name() +@pytest.mark.parametrize("tool_name", ["test.tool.name", "tool&name"]) +def test_tool_name_validation_invalid(tool_name: str) -> None: + index_name = "catalog.schema.index" + with pytest.raises(ValueError): + init_vector_search_tool(index_name, tool_name=tool_name) - # Create tool input model based on dynamic_filter setting - if self.dynamic_filter: - tool_input_class = self._create_enhanced_input_model() - else: - tool_input_class = self._create_basic_input_model() - self.tool = pydantic_function_tool( - tool_input_class, - name=tool_name, - description=self.tool_description - or self._get_default_tool_description(self._index_details), +@pytest.mark.parametrize( + "index_name,name", + [ + ("catalog.schema.index", "catalog__schema__index"), + ("cata_log.schema_.index", "cata_log__schema___index"), + ], +) +def test_index_name_to_tool_name(index_name: str, name: str) -> None: + vector_search_tool = init_vector_search_tool(index_name) + assert vector_search_tool.name == name + + +def test_vector_search_client_model_serving_environment(): + with patch("os.path.isfile", return_value=True): + # Simulate Model Serving Environment + os.environ["IS_IN_DB_MODEL_SERVING_ENV"] = "true" + + # Fake credential token + current_thread = threading.current_thread() + thread_data = current_thread.__dict__ + thread_data["invokers_token"] = "abc" + + w = WorkspaceClient( + host="testDogfod.com", credentials_strategy=ModelServingUserCredentials() ) - # We need to remove strict: True from the tool in order to support arbitrary filters - if "function" in self.tool and "strict" in self.tool["function"]: - del self.tool["function"]["strict"] - # We need to remove additionalProperties from the tool in order to support arbitrary kwargs - if ( - "function" in self.tool - and "parameters" in self.tool["function"] - and "additionalProperties" in self.tool["function"]["parameters"] - ): - del self.tool["function"]["parameters"]["additionalProperties"] - - try: - from databricks.sdk import WorkspaceClient - from databricks.sdk.errors.platform import ResourceDoesNotExist - - if self.workspace_client is not None: - self.workspace_client.serving_endpoints.get(self.embedding_model_name) - else: - WorkspaceClient().serving_endpoints.get(self.embedding_model_name) - self.resources = self._get_resources( - self.index_name, self.embedding_model_name, self._index_details - ) - except ResourceDoesNotExist: - self.resources = self._get_resources(self.index_name, None, self._index_details) - - return self - - def _create_or_get_mcp_toolkit(self) -> Callable: - """ - If it does not exist, create the MCP tool execution function for this index. - Otherwise, return the execution function. - - Uses McpServerToolkit.from_vector_search(catalog, schema, index_name, workspace_client) - to access tools for the specified vector search index. - """ - if self._mcp_tool_execute is not None: - return self._mcp_tool_execute - - catalog, schema, index = self._parse_index_name() - - try: - self._mcp_toolkit = McpServerToolkit.from_vector_search( - catalog=catalog, - schema=schema, - index_name=index, - workspace_client=self.workspace_client, - ) - except Exception as e: - self._handle_mcp_creation_error(e) - - tools = self._mcp_toolkit.get_tools() - self._validate_mcp_tools(tools) - - self._mcp_tool_execute = tools[0].execute - return self._mcp_tool_execute - - def _build_mcp_meta( - self, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, **kwargs: Any - ) -> Dict[str, Any]: - """Build metadata dict for MCP tool invocation.""" - return self._build_mcp_params(filters, **kwargs) - - def _parse_mcp_response(self, mcp_response: str) -> List[Dict]: - """Parse MCP JSON response and normalize to page_content/metadata format.""" - return self._parse_mcp_response_to_dicts(mcp_response, strict=True) - - def _execute_mcp_path( - self, - query: str, - filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, - **kwargs: Any, - ) -> List[Dict]: - try: - mcp_execute = self._create_or_get_mcp_toolkit() - meta = self._build_mcp_meta(filters, **kwargs) - mcp_response = mcp_execute(query=query, _meta=meta) - return self._parse_mcp_response(mcp_response) - except Exception as e: - self._handle_mcp_execution_error(e) - - def _execute_direct_api_path( - self, - query: str, - filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, - openai_client: OpenAI = None, - **kwargs: Any, - ) -> List[Dict]: - from openai import OpenAI - - oai_client = openai_client or OpenAI() - if not oai_client.api_key: - raise ValueError( - "OpenAI API key is required to generate embeddings for retrieval queries." - ) - signature = inspect.signature(self._index.similarity_search) - kwargs = {**kwargs, **(self.model_extra or {})} - kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} + with patch("databricks.vector_search.client.VectorSearchClient") as mockVSClient: + mock_instance = mockVSClient.return_value + mock_instance.get_index.side_effect = _get_index + with patch("databricks.sdk.service.serving.ServingEndpointsAPI.get", return_value=None): + vsTool = VectorSearchRetrieverTool( + index_name="test.delta_sync.index", + tool_description="desc", + workspace_client=w, + ) + mockVSClient.assert_called_once_with( + disable_notice=True, + credential_strategy=CredentialStrategy.MODEL_SERVING_USER_CREDENTIALS, + ) - # Allow kwargs to override the default values upon invocation - num_results = kwargs.pop("num_results", self.num_results) - query_type = kwargs.pop("query_type", self.query_type) - reranker = kwargs.pop("reranker", self.reranker) - query_text = query if query_type and query_type.upper() == "HYBRID" else None - query_vector = ( - oai_client.embeddings.create(input=query, model=self.embedding_model_name) - .data[0] - .embedding +def test_vector_search_client_non_model_serving_environment(): + with patch("databricks.vector_search.client.VectorSearchClient") as mockVSClient: + mock_instance = mockVSClient.return_value + mock_instance.get_index.side_effect = _get_index + vsTool = VectorSearchRetrieverTool( + index_name="test.delta_sync.index", + tool_description="desc", ) - if ( - index_embedding_dimension := self._index_details.embedding_vector_column.get( - "embedding_dimension" + mockVSClient.assert_called_once_with(disable_notice=True) + + +def test_vector_search_client_with_pat_workspace_client(): + w = WorkspaceClient(host="testDogfod.com", token="fakeToken") + with patch("databricks.vector_search.client.VectorSearchClient") as mockVSClient: + with patch("databricks.sdk.service.serving.ServingEndpointsAPI.get", return_value=None): + mock_instance = mockVSClient.return_value + mock_instance.get_index.side_effect = _get_index + VectorSearchRetrieverTool( + index_name="test.delta_sync.index", + tool_description="desc", + workspace_client=w, ) - ) and len(query_vector) != index_embedding_dimension: - raise ValueError( - f"Expected embedding dimension {index_embedding_dimension} but got {len(query_vector)}" + mockVSClient.assert_called_once_with( + disable_notice=True, personal_access_token="fakeToken" ) - combined_filters = {**(self.filters or {}), **self._normalize_filters(filters)} - kwargs.update( - { - "query_text": query_text, - "query_vector": query_vector, - "columns": self.columns, - "filters": combined_filters, - "num_results": num_results, - "query_type": query_type, - "reranker": reranker, - } - ) - search_resp = self._index.similarity_search(**kwargs) - docs_with_score: List[Tuple[Dict, float]] = parse_vector_search_response( - search_resp=search_resp, - retriever_schema=self._retriever_schema, - document_class=dict, - include_score=self.include_score, +def test_vector_search_client_with_sp_workspace_client(): + # Create a proper mock workspace client that passes isinstance check + w = create_autospec(WorkspaceClient, instance=True) + w.config.auth_type = "oauth-m2m" + w.config.host = "testDogfod.com" + w.config.client_id = "fakeClientId" + w.config.client_secret = "fakeClientSecret" + + with patch("databricks.vector_search.client.VectorSearchClient") as mockVSClient: + with patch("databricks.sdk.service.serving.ServingEndpointsAPI.get", return_value=None): + mock_instance = mockVSClient.return_value + mock_instance.get_index.side_effect = _get_index + VectorSearchRetrieverTool( + index_name="test.delta_sync.index", + tool_description="desc", + workspace_client=w, + ) + mockVSClient.assert_called_once_with( + disable_notice=True, + workspace_url="testDogfod.com", + service_principal_client_id="fakeClientId", + service_principal_client_secret="fakeClientSecret", + ) + + +def test_kwargs_are_passed_through(execution_path) -> None: + """Test kwargs are passed through correctly on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], score_threshold=0.5) + setup_tool_for_path(execution_path, tool) + + tool.invoke({"query": "what cities are in Germany"}) + + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what cities are in Germany" + assert call_args["score_threshold"] == 0.5 + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what cities are in Germany" + assert call_args[1]["score_threshold"] == 0.5 + + +def test_kwargs_override_both_num_results_and_query_type(execution_path) -> None: + """Test kwargs can override num_results and query_type on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"], num_results=10, query_type="ANN") + setup_tool_for_path(execution_path, tool) + + tool.invoke({"query": "what cities are in Germany", "k": 3, "query_type": "HYBRID"}) + + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what cities are in Germany" + assert call_args["num_results"] == 3 + assert call_args["query_type"] == "HYBRID" + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what cities are in Germany" + assert call_args[1]["k"] == 3 + assert call_args[1]["query_type"] == "HYBRID" + + +def test_enhanced_filter_description_with_column_metadata() -> None: + """Test that the tool args_schema includes enhanced filter descriptions with column metadata.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) + + # The LangChain implementation calls index.describe() to get column information + # and includes them in the filter description + args_schema = vector_search_tool.args_schema + filter_field = args_schema.model_fields["filters"] + + # Check that the filter description is enhanced with available columns + # Note: The actual columns will depend on the mocked index.describe() response + assert ( + "Available columns for filtering:" in filter_field.description + or "Optional filters" in filter_field.description + ) + + # Should include comprehensive filter syntax + assert "Inclusion:" in filter_field.description + assert "Exclusion:" in filter_field.description + assert "Comparisons:" in filter_field.description + assert "Pattern match:" in filter_field.description + assert "OR logic:" in filter_field.description + + # Should include examples + assert "Examples:" in filter_field.description + assert "Filter by category:" in filter_field.description + assert "Filter by price range:" in filter_field.description + + +def test_enhanced_filter_description_fails_on_table_metadata_error() -> None: + """Test that tool initialization fails with clear error when table metadata cannot be retrieved.""" + # Mock WorkspaceClient to raise an exception when accessing table metadata + with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: + mock_ws_client = MagicMock() + mock_ws_client.tables.get.side_effect = Exception("Permission denied") + mock_ws_client_class.return_value = mock_ws_client + + # Try to initialize tool with dynamic_filter=True + # This should fail because we can't get table metadata + with pytest.raises( + ValueError, + match="Failed to retrieve table metadata for index.*Permission denied", + ): + init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) + + +def test_enhanced_filter_description_fails_on_empty_columns() -> None: + """Test that tool initialization fails when table has no valid columns.""" + # Mock WorkspaceClient to return a table with no valid columns (all start with __) + with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: + mock_ws_client = MagicMock() + mock_table = MagicMock() + mock_column = MagicMock() + mock_column.name = "__internal_column" + mock_column.type_name = MagicMock() + mock_column.type_name.name = "STRING" + mock_table.columns = [mock_column] + mock_ws_client.tables.get.return_value = mock_table + mock_ws_client_class.return_value = mock_ws_client + + # Try to initialize tool with dynamic_filter=True + # This should fail because there are no valid columns + with pytest.raises( + ValueError, + match="No valid columns found in table metadata for index", + ): + init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) + + +def test_cannot_use_both_dynamic_filter_and_predefined_filters() -> None: + """Test that using both dynamic_filter and predefined filters raises an error.""" + # Try to initialize tool with both dynamic_filter=True and predefined filters + with pytest.raises( + ValueError, match="Cannot use both dynamic_filter=True and predefined filters" + ): + init_vector_search_tool( + DELTA_SYNC_INDEX, + filters={"status": "active", "category": "electronics"}, + dynamic_filter=True, ) - return [doc for doc, _ in docs_with_score] - - @vector_search_retriever_tool_trace - def execute( - self, - query: str, - filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, - openai_client: OpenAI = None, - **kwargs: Any, - ) -> List[Dict]: - """ - Execute the VectorSearchIndex tool calls from the ChatCompletions response that correspond to the - self.tool VectorSearchRetrieverToolInput and attach the retrieved documents into tool call messages. - - Execute vector search with automatic routing: - - MCP path: Used for Databricks-managed embeddings (no embedding model configuration needed) - - Direct API path: Used for self-managed embeddings (requires openai_client) - - Args: - query: The query text to use for the retrieval. - filters: Optional filters to refine vector search results. - openai_client: The OpenAI client object used to generate embeddings for retrieval queries. - Only used for self-managed embeddings. If not provided, the default OpenAI - client in the current environment will be used. - **kwargs: Additional search parameters (e.g., num_results, query_type, score_threshold, reranker). - For Databricks-managed embeddings, these are passed as MCP metadata. - For self-managed embeddings, these are passed to similarity_search(). - - Returns: - A list of document dictionaries. Format may vary between MCP and Direct API paths. - """ - if self._index_details.is_databricks_managed_embeddings(): - return self._execute_mcp_path(query, filters, **kwargs) - else: - return self._execute_direct_api_path(query, filters, openai_client, **kwargs) + + +def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: + """Test that predefined filters work correctly when dynamic_filter is False on both paths.""" + tool = init_vector_search_tool( + execution_path["index_name"], filters={"status": "active", "category": "electronics"} + ) + setup_tool_for_path(execution_path, tool) + + # The filters parameter should NOT be exposed since dynamic_filter=False + args_schema = tool.args_schema + assert "filters" not in args_schema.model_fields + + tool.invoke({"query": "what electronics are available"}) + + expected_filters = {"status": "active", "category": "electronics"} + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "what electronics are available" + # MCP path: filters are JSON stringified + assert json.loads(call_args["filters"]) == expected_filters + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "what electronics are available" + assert call_args[1]["filter"] == expected_filters + + +def test_filter_item_serialization(execution_path) -> None: + """Test that FilterItem objects are properly converted to dictionaries on both paths.""" + tool = init_vector_search_tool(execution_path["index_name"]) + setup_tool_for_path(execution_path, tool) + + # Test various filter types + filters = [ + FilterItem(key="category", value="electronics"), + FilterItem(key="price >=", value=100), + FilterItem(key="status NOT", value="discontinued"), + FilterItem(key="tags", value=["wireless", "bluetooth"]), + ] + + tool.invoke({"query": "find products", "filters": filters}) + + expected_filters = { + "category": "electronics", + "price >=": 100, + "status NOT": "discontinued", + "tags": ["wireless", "bluetooth"], + } + + if execution_path["path"] == "mcp": + execution_path["mock_tool"].invoke.assert_called_once() + call_args = execution_path["mock_tool"].invoke.call_args[0][0] + assert call_args["query"] == "find products" + # MCP path: filters are JSON stringified + assert json.loads(call_args["filters"]) == expected_filters + else: + tool._vector_store.similarity_search.assert_called_once() + call_args = tool._vector_store.similarity_search.call_args + assert call_args[1]["query"] == "find products" + assert call_args[1]["filter"] == expected_filters + + +# ============================================================================= +# MCP Path Specific Tests +# ============================================================================= + + +def test_mcp_path_is_used_for_databricks_managed_embeddings(mock_mcp_infrastructure) -> None: + """Test that MCP path is used for Databricks-managed embeddings indexes.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + # Invoke the tool (should use MCP path for DELTA_SYNC_INDEX which has managed embeddings) + result = vector_search_tool._run("test query") + + # Verify MCP server was created with correct parameters + mock_mcp_infrastructure["server_class"].from_vector_search.assert_called_once() + call_kwargs = mock_mcp_infrastructure["server_class"].from_vector_search.call_args[1] + assert call_kwargs["catalog"] == "test" + assert call_kwargs["schema"] == "delta_sync" + assert call_kwargs["index_name"] == "index" + + # Verify MCP client was used + mock_mcp_infrastructure["client_class"].assert_called_once() + + # Verify MCP tool was invoked + mock_mcp_infrastructure["tool"].invoke.assert_called_once() + + +def test_direct_api_path_is_used_for_self_managed_embeddings(mock_mcp_infrastructure) -> None: + """Test that direct API path is used for self-managed embeddings indexes.""" + # Use an index that requires self-managed embeddings + index_name = "test.direct_access.index" + vector_search_tool = init_vector_search_tool(index_name) + vector_search_tool._vector_store.similarity_search = MagicMock(return_value=[]) + + # Invoke the tool (should use direct API path) + result = vector_search_tool._run("test query") + + # Verify similarity_search was called directly + vector_search_tool._vector_store.similarity_search.assert_called_once() + + # Verify MCP was NOT used for self-managed embeddings + mock_mcp_infrastructure["tool"].invoke.assert_not_called() + + +def test_mcp_tool_is_cached(mock_mcp_infrastructure) -> None: + """Test that MCP tool is cached and not recreated on subsequent calls.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + # Call _run multiple times + vector_search_tool._run("query 1") + vector_search_tool._run("query 2") + vector_search_tool._run("query 3") + + # MCP server should only be created once + assert mock_mcp_infrastructure["server_class"].from_vector_search.call_count == 1 + + # MCP client should only be created once + assert mock_mcp_infrastructure["client_class"].call_count == 1 + + # But MCP tool should be invoked 3 times + assert mock_mcp_infrastructure["tool"].invoke.call_count == 3 + + +def test_mcp_response_parsing_json_array() -> None: + """Test that MCP JSON array response is parsed correctly into Documents.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + json_response = json.dumps( + [ + {"id": "doc1", "text": "content1", "score": 0.9}, + {"id": "doc2", "text": "content2", "score": 0.8}, + ] + ) + + docs = vector_search_tool._parse_mcp_response(json_response) + + assert len(docs) == 2 + assert all(isinstance(doc, Document) for doc in docs) + assert docs[0].page_content == "content1" + assert docs[0].metadata == {"id": "doc1", "score": 0.9} + assert docs[1].page_content == "content2" + assert docs[1].metadata == {"id": "doc2", "score": 0.8} + + +def test_mcp_response_parsing_non_json() -> None: + """Test that non-JSON MCP response is treated as a single document.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + plain_text_response = "This is a plain text response" + + docs = vector_search_tool._parse_mcp_response(plain_text_response) + + assert len(docs) == 1 + assert docs[0].page_content == plain_text_response + + +def test_mcp_response_parsing_non_list_json() -> None: + """Test that non-list JSON is converted to a single document.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + json_response = json.dumps({"message": "single object response"}) + + docs = vector_search_tool._parse_mcp_response(json_response) + + assert len(docs) == 1 + assert docs[0].page_content == "{'message': 'single object response'}" + + +def test_normalize_filters_with_filter_items() -> None: + """Test that FilterItem list is normalized to dict.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + filters = [ + FilterItem(key="category", value="electronics"), + FilterItem(key="price >=", value=100), + ] + + result = vector_search_tool._normalize_filters(filters) + + assert result == {"category": "electronics", "price >=": 100} + + +def test_normalize_filters_with_dict() -> None: + """Test that dict filters are passed through unchanged.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + filters = {"category": "electronics", "price >=": 100} + + result = vector_search_tool._normalize_filters(filters) + + assert result == filters + + +def test_normalize_filters_with_none() -> None: + """Test that None filters return empty dict.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + + result = vector_search_tool._normalize_filters(None) + + assert result == {} + + +def test_build_mcp_input() -> None: + """Test MCP input building with various parameters.""" + from databricks.vector_search.reranker import DatabricksReranker + + # Basic parameters + tool = init_vector_search_tool(DELTA_SYNC_INDEX) + mcp_input = tool._build_mcp_input("test query") + assert mcp_input["query"] == "test query" + assert mcp_input["num_results"] == tool.num_results + assert mcp_input["query_type"] == tool.query_type + assert mcp_input["include_score"] == "false" # Default + + # With filters (JSON stringified for MCP - parse back to compare) + filters = [FilterItem(key="category", value="electronics")] + mcp_input = tool._build_mcp_input("test query", filters=filters) + assert json.loads(mcp_input["filters"]) == {"category": "electronics"} + + # Combines predefined and runtime filters + tool_with_filters = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"status": "active"}) + runtime_filters = [FilterItem(key="category", value="electronics")] + mcp_input = tool_with_filters._build_mcp_input("test query", filters=runtime_filters) + expected_filters = {"status": "active", "category": "electronics"} + assert json.loads(mcp_input["filters"]) == expected_filters + + # kwargs override defaults + tool_with_defaults = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN") + mcp_input = tool_with_defaults._build_mcp_input( + "test query", num_results=5, query_type="HYBRID" + ) + assert mcp_input["num_results"] == 5 + assert mcp_input["query_type"] == "HYBRID" + + # With columns (comma-separated for MCP) + tool_with_columns = init_vector_search_tool(DELTA_SYNC_INDEX, columns=["id", "text", "score"]) + mcp_input = tool_with_columns._build_mcp_input("test query") + assert mcp_input["columns"] == "id,text,score" + + # With score_threshold (converted to float) + mcp_input = tool._build_mcp_input("test query", score_threshold=0.7) + assert mcp_input["score_threshold"] == 0.7 + assert isinstance(mcp_input["score_threshold"], float) + + # With include_score=True + tool_with_score = init_vector_search_tool(DELTA_SYNC_INDEX, include_score=True) + mcp_input = tool_with_score._build_mcp_input("test query") + assert mcp_input["include_score"] == "true" + + # With reranker + reranker = DatabricksReranker(columns_to_rerank=["text", "title"]) + tool_with_reranker = init_vector_search_tool(DELTA_SYNC_INDEX, reranker=reranker) + mcp_input = tool_with_reranker._build_mcp_input("test query") + assert mcp_input["columns_to_rerank"] == "text,title" From ce9745fa47d56c7026fc7137124248291ae4593b Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Tue, 27 Jan 2026 13:32:13 -0800 Subject: [PATCH 05/11] refactor tests --- .../vector_search_retriever_tool.py | 11 +- .../test_vector_search_retriever_tool.py | 270 ++++++------------ .../test_vector_search_retriever_tool.py | 242 +++++++++++++++- 3 files changed, 327 insertions(+), 196 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 666fc1e93..9f40bebc9 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -3,8 +3,6 @@ from typing import Any, Dict, List, Optional, Type, Union from databricks_ai_bridge.utils.vector_search import IndexDetails - -_logger = logging.getLogger(__name__) from databricks_ai_bridge.vector_search_retriever_tool import ( FilterItem, VectorSearchRetrieverToolInput, @@ -14,7 +12,6 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.tools import BaseTool -from langchain_core.tools import BaseTool as LangChainBaseTool from pydantic import BaseModel, Field, PrivateAttr, model_validator from databricks_langchain import DatabricksEmbeddings @@ -24,6 +21,8 @@ ) from databricks_langchain.vectorstores import DatabricksVectorSearch +_logger = logging.getLogger(__name__) + class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): """ @@ -58,7 +57,7 @@ class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput _vector_store: DatabricksVectorSearch = PrivateAttr() - _mcp_tool: Optional[LangChainBaseTool] = PrivateAttr(default=None) + _mcp_tool: Optional[BaseTool] = PrivateAttr(default=None) @model_validator(mode="after") def _validate_tool_inputs(self): @@ -94,7 +93,7 @@ def _validate_tool_inputs(self): return self - def _create_or_get_mcp_tool(self) -> LangChainBaseTool: + def _create_or_get_mcp_tool(self) -> BaseTool: """Create or return existing MCP tool using LangChain MCP Server.""" if self._mcp_tool is not None: return self._mcp_tool @@ -132,7 +131,7 @@ def _build_mcp_input( def _parse_mcp_response(self, mcp_response: str) -> List[Document]: """Parse MCP tool response into LangChain Documents.""" - dicts = self._parse_mcp_response_to_dicts(mcp_response, strict=False) + dicts = self._parse_mcp_response_to_dicts(mcp_response, strict=True) return [Document(page_content=d["page_content"], metadata=d["metadata"]) for d in dicts] def _execute_mcp_path( diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index ff87c9079..74dd98cd4 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -55,6 +55,17 @@ def _create_mcp_response_json(texts: List[str] = None) -> str: ) +def assert_mcp_tool_called_with(mock_tool, expected_args: Dict[str, Any]): + """Assert MCP tool was called with expected args, handling JSON-stringified filters.""" + mock_tool.invoke.assert_called_once() + call_args = mock_tool.invoke.call_args[0][0] + for key, value in expected_args.items(): + if key == "filters": + assert json.loads(call_args["filters"]) == value + else: + assert call_args[key] == value + + @pytest.fixture def mock_mcp_infrastructure(): """Mock MCP infrastructure for tests that need it.""" @@ -179,16 +190,17 @@ def test_filters_are_passed_through(execution_path) -> None: ) if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what cities are in Germany" - # MCP path: filters are JSON stringified - assert json.loads(call_args["filters"]) == {"country": "Germany"} + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what cities are in Germany", "filters": {"country": "Germany"}}, + ) else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what cities are in Germany" - assert call_args[1]["filter"] == {"country": "Germany"} + tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=tool.num_results, + filter={"country": "Germany"}, + query_type=tool.query_type, + ) def test_filters_are_combined(execution_path) -> None: @@ -205,16 +217,17 @@ def test_filters_are_combined(execution_path) -> None: expected_filters = {"city LIKE": "Berlin", "country": "Germany"} if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what cities are in Germany" - # MCP path: filters are JSON stringified - assert json.loads(call_args["filters"]) == expected_filters + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what cities are in Germany", "filters": expected_filters}, + ) else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what cities are in Germany" - assert call_args[1]["filter"] == expected_filters + tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=tool.num_results, + filter=expected_filters, + query_type=tool.query_type, + ) @pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) @@ -453,15 +466,18 @@ def test_kwargs_are_passed_through(execution_path) -> None: tool.invoke({"query": "what cities are in Germany"}) if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what cities are in Germany" - assert call_args["score_threshold"] == 0.5 + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what cities are in Germany", "score_threshold": 0.5}, + ) else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what cities are in Germany" - assert call_args[1]["score_threshold"] == 0.5 + tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=tool.num_results, + filter={}, + query_type=tool.query_type, + score_threshold=0.5, + ) def test_kwargs_override_both_num_results_and_query_type(execution_path) -> None: @@ -472,17 +488,17 @@ def test_kwargs_override_both_num_results_and_query_type(execution_path) -> None tool.invoke({"query": "what cities are in Germany", "k": 3, "query_type": "HYBRID"}) if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what cities are in Germany" - assert call_args["num_results"] == 3 - assert call_args["query_type"] == "HYBRID" + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what cities are in Germany", "num_results": 3, "query_type": "HYBRID"}, + ) else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what cities are in Germany" - assert call_args[1]["k"] == 3 - assert call_args[1]["query_type"] == "HYBRID" + tool._vector_store.similarity_search.assert_called_once_with( + query="what cities are in Germany", + k=3, + filter={}, + query_type="HYBRID", + ) def test_enhanced_filter_description_with_column_metadata() -> None: @@ -582,16 +598,17 @@ def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: expected_filters = {"status": "active", "category": "electronics"} if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "what electronics are available" - # MCP path: filters are JSON stringified - assert json.loads(call_args["filters"]) == expected_filters + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "what electronics are available", "filters": expected_filters}, + ) else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "what electronics are available" - assert call_args[1]["filter"] == expected_filters + tool._vector_store.similarity_search.assert_called_once_with( + query="what electronics are available", + k=tool.num_results, + filter=expected_filters, + query_type=tool.query_type, + ) def test_filter_item_serialization(execution_path) -> None: @@ -617,16 +634,17 @@ def test_filter_item_serialization(execution_path) -> None: } if execution_path["path"] == "mcp": - execution_path["mock_tool"].invoke.assert_called_once() - call_args = execution_path["mock_tool"].invoke.call_args[0][0] - assert call_args["query"] == "find products" - # MCP path: filters are JSON stringified - assert json.loads(call_args["filters"]) == expected_filters + assert_mcp_tool_called_with( + execution_path["mock_tool"], + {"query": "find products", "filters": expected_filters}, + ) else: - tool._vector_store.similarity_search.assert_called_once() - call_args = tool._vector_store.similarity_search.call_args - assert call_args[1]["query"] == "find products" - assert call_args[1]["filter"] == expected_filters + tool._vector_store.similarity_search.assert_called_once_with( + query="find products", + k=tool.num_results, + filter=expected_filters, + query_type=tool.query_type, + ) # ============================================================================= @@ -651,8 +669,15 @@ def test_mcp_path_is_used_for_databricks_managed_embeddings(mock_mcp_infrastruct # Verify MCP client was used mock_mcp_infrastructure["client_class"].assert_called_once() - # Verify MCP tool was invoked - mock_mcp_infrastructure["tool"].invoke.assert_called_once() + # Verify MCP tool was invoked with expected query + mock_mcp_infrastructure["tool"].invoke.assert_called_once_with( + { + "query": "test query", + "num_results": vector_search_tool.num_results, + "query_type": vector_search_tool.query_type, + "include_score": "false", + } + ) def test_direct_api_path_is_used_for_self_managed_embeddings(mock_mcp_infrastructure) -> None: @@ -689,136 +714,3 @@ def test_mcp_tool_is_cached(mock_mcp_infrastructure) -> None: # But MCP tool should be invoked 3 times assert mock_mcp_infrastructure["tool"].invoke.call_count == 3 - - -def test_mcp_response_parsing_json_array() -> None: - """Test that MCP JSON array response is parsed correctly into Documents.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - json_response = json.dumps( - [ - {"id": "doc1", "text": "content1", "score": 0.9}, - {"id": "doc2", "text": "content2", "score": 0.8}, - ] - ) - - docs = vector_search_tool._parse_mcp_response(json_response) - - assert len(docs) == 2 - assert all(isinstance(doc, Document) for doc in docs) - assert docs[0].page_content == "content1" - assert docs[0].metadata == {"id": "doc1", "score": 0.9} - assert docs[1].page_content == "content2" - assert docs[1].metadata == {"id": "doc2", "score": 0.8} - - -def test_mcp_response_parsing_non_json() -> None: - """Test that non-JSON MCP response is treated as a single document.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - plain_text_response = "This is a plain text response" - - docs = vector_search_tool._parse_mcp_response(plain_text_response) - - assert len(docs) == 1 - assert docs[0].page_content == plain_text_response - - -def test_mcp_response_parsing_non_list_json() -> None: - """Test that non-list JSON is converted to a single document.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - json_response = json.dumps({"message": "single object response"}) - - docs = vector_search_tool._parse_mcp_response(json_response) - - assert len(docs) == 1 - assert docs[0].page_content == "{'message': 'single object response'}" - - -def test_normalize_filters_with_filter_items() -> None: - """Test that FilterItem list is normalized to dict.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - filters = [ - FilterItem(key="category", value="electronics"), - FilterItem(key="price >=", value=100), - ] - - result = vector_search_tool._normalize_filters(filters) - - assert result == {"category": "electronics", "price >=": 100} - - -def test_normalize_filters_with_dict() -> None: - """Test that dict filters are passed through unchanged.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - filters = {"category": "electronics", "price >=": 100} - - result = vector_search_tool._normalize_filters(filters) - - assert result == filters - - -def test_normalize_filters_with_none() -> None: - """Test that None filters return empty dict.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - result = vector_search_tool._normalize_filters(None) - - assert result == {} - - -def test_build_mcp_input() -> None: - """Test MCP input building with various parameters.""" - from databricks.vector_search.reranker import DatabricksReranker - - # Basic parameters - tool = init_vector_search_tool(DELTA_SYNC_INDEX) - mcp_input = tool._build_mcp_input("test query") - assert mcp_input["query"] == "test query" - assert mcp_input["num_results"] == tool.num_results - assert mcp_input["query_type"] == tool.query_type - assert mcp_input["include_score"] == "false" # Default - - # With filters (JSON stringified for MCP - parse back to compare) - filters = [FilterItem(key="category", value="electronics")] - mcp_input = tool._build_mcp_input("test query", filters=filters) - assert json.loads(mcp_input["filters"]) == {"category": "electronics"} - - # Combines predefined and runtime filters - tool_with_filters = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"status": "active"}) - runtime_filters = [FilterItem(key="category", value="electronics")] - mcp_input = tool_with_filters._build_mcp_input("test query", filters=runtime_filters) - expected_filters = {"status": "active", "category": "electronics"} - assert json.loads(mcp_input["filters"]) == expected_filters - - # kwargs override defaults - tool_with_defaults = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10, query_type="ANN") - mcp_input = tool_with_defaults._build_mcp_input( - "test query", num_results=5, query_type="HYBRID" - ) - assert mcp_input["num_results"] == 5 - assert mcp_input["query_type"] == "HYBRID" - - # With columns (comma-separated for MCP) - tool_with_columns = init_vector_search_tool(DELTA_SYNC_INDEX, columns=["id", "text", "score"]) - mcp_input = tool_with_columns._build_mcp_input("test query") - assert mcp_input["columns"] == "id,text,score" - - # With score_threshold (converted to float) - mcp_input = tool._build_mcp_input("test query", score_threshold=0.7) - assert mcp_input["score_threshold"] == 0.7 - assert isinstance(mcp_input["score_threshold"], float) - - # With include_score=True - tool_with_score = init_vector_search_tool(DELTA_SYNC_INDEX, include_score=True) - mcp_input = tool_with_score._build_mcp_input("test query") - assert mcp_input["include_score"] == "true" - - # With reranker - reranker = DatabricksReranker(columns_to_rerank=["text", "title"]) - tool_with_reranker = init_vector_search_tool(DELTA_SYNC_INDEX, reranker=reranker) - mcp_input = tool_with_reranker._build_mcp_input("test query") - assert mcp_input["columns_to_rerank"] == "text,title" diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py index dbf72a458..20badabdb 100644 --- a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -1,3 +1,4 @@ +import json from unittest.mock import MagicMock import pytest @@ -5,7 +6,10 @@ from databricks_ai_bridge.test_utils.vector_search import mock_workspace_client # noqa: F401 from databricks_ai_bridge.utils.vector_search import IndexDetails -from databricks_ai_bridge.vector_search_retriever_tool import VectorSearchRetrieverToolMixin +from databricks_ai_bridge.vector_search_retriever_tool import ( + FilterItem, + VectorSearchRetrieverToolMixin, +) class DummyVectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): @@ -78,3 +82,239 @@ def test_describe_columns(): "country (STRING): Name of the country\n" "description (STRING): Detailed description of the city" ) + + +# ============================================================================= +# Tests for _normalize_filters +# ============================================================================= + + +def test_normalize_filters_with_filter_items(): + """Test that FilterItem list is normalized to dict.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + filters = [ + FilterItem(key="category", value="electronics"), + FilterItem(key="price >=", value=100), + ] + + result = tool._normalize_filters(filters) + + assert result == {"category": "electronics", "price >=": 100} + + +# ============================================================================= +# Tests for _parse_mcp_response_to_dicts +# ============================================================================= + + +def test_parse_mcp_response_to_dicts_json_array(): + """Test that JSON array response is parsed correctly into dicts.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps( + [ + {"id": "doc1", "text": "content1", "score": 0.9}, + {"id": "doc2", "text": "content2", "score": 0.8}, + ] + ) + + dicts = tool._parse_mcp_response_to_dicts(json_response) + + assert len(dicts) == 2 + assert dicts[0]["page_content"] == "content1" + assert dicts[0]["metadata"] == {"id": "doc1", "score": 0.9} + assert dicts[1]["page_content"] == "content2" + assert dicts[1]["metadata"] == {"id": "doc2", "score": 0.8} + + +def test_parse_mcp_response_to_dicts_non_json_strict(): + """Test that non-JSON response raises ValueError when strict=True.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + plain_text_response = "This is a plain text response" + + with pytest.raises(ValueError, match="Unable to parse MCP response"): + tool._parse_mcp_response_to_dicts(plain_text_response, strict=True) + + +def test_parse_mcp_response_to_dicts_non_json_non_strict(): + """Test that non-JSON response is treated as single document when strict=False.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + plain_text_response = "This is a plain text response" + + dicts = tool._parse_mcp_response_to_dicts(plain_text_response, strict=False) + + assert len(dicts) == 1 + assert dicts[0]["page_content"] == plain_text_response + assert dicts[0]["metadata"] == {} + + +def test_parse_mcp_response_to_dicts_non_list_json_strict(): + """Test that non-list JSON raises ValueError when strict=True.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps({"message": "single object response"}) + + with pytest.raises(ValueError, match="Expected JSON array, got"): + tool._parse_mcp_response_to_dicts(json_response, strict=True) + + +def test_parse_mcp_response_to_dicts_non_list_json_non_strict(): + """Test that non-list JSON is converted to single document when strict=False.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps({"message": "single object response"}) + + dicts = tool._parse_mcp_response_to_dicts(json_response, strict=False) + + assert len(dicts) == 1 + assert dicts[0]["page_content"] == "{'message': 'single object response'}" + assert dicts[0]["metadata"] == {} + + +def test_parse_mcp_response_to_dicts_empty_list(): + """Test parsing empty list response.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps([]) + + dicts = tool._parse_mcp_response_to_dicts(json_response) + + assert dicts == [] + + +def test_parse_mcp_response_to_dicts_custom_text_column(): + """Test that custom text column is used for page_content.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + json_response = json.dumps( + [ + {"id": "doc1", "content": "custom content", "score": 0.9}, + ] + ) + + dicts = tool._parse_mcp_response_to_dicts(json_response, text_column="content") + + assert len(dicts) == 1 + assert dicts[0]["page_content"] == "custom content" + assert dicts[0]["metadata"] == {"id": "doc1", "score": 0.9} + + +# ============================================================================= +# Tests for _build_mcp_params +# ============================================================================= + + +def test_build_mcp_params_basic(): + """Test basic MCP params building.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + params = tool._build_mcp_params(None) + + assert params["num_results"] == tool.num_results + assert params["query_type"] == tool.query_type + assert params["include_score"] == "false" + assert "filters" not in params + + +def test_build_mcp_params_with_filters(): + """Test MCP params building with filters.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + filters = [FilterItem(key="category", value="electronics")] + params = tool._build_mcp_params(filters) + + assert json.loads(params["filters"]) == {"category": "electronics"} + + +def test_build_mcp_params_combines_filters(): + """Test MCP params building combines predefined and runtime filters.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, filters={"status": "active"}) + + runtime_filters = [FilterItem(key="category", value="electronics")] + params = tool._build_mcp_params(runtime_filters) + + expected_filters = {"status": "active", "category": "electronics"} + assert json.loads(params["filters"]) == expected_filters + + +def test_build_mcp_params_kwargs_override_defaults(): + """Test that kwargs override default values.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, num_results=10, query_type="ANN") + + params = tool._build_mcp_params(None, num_results=5, query_type="HYBRID") + + assert params["num_results"] == 5 + assert params["query_type"] == "HYBRID" + + +def test_build_mcp_params_with_columns(): + """Test MCP params building with columns.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, columns=["id", "text", "score"]) + + params = tool._build_mcp_params(None) + + assert params["columns"] == "id,text,score" + + +def test_build_mcp_params_with_include_score(): + """Test MCP params building with include_score=True.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, include_score=True) + + params = tool._build_mcp_params(None) + + assert params["include_score"] == "true" + + +def test_build_mcp_params_k_alias_for_num_results(): + """Test that 'k' kwarg is treated as alias for num_results.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, num_results=10) + + params = tool._build_mcp_params(None, k=3) + + assert params["num_results"] == 3 + + +def test_build_mcp_params_with_reranker(): + """Test MCP params building with reranker.""" + from databricks.vector_search.reranker import DatabricksReranker + + reranker = DatabricksReranker(columns_to_rerank=["text", "title"]) + tool = DummyVectorSearchRetrieverTool(index_name=index_name, reranker=reranker) + + params = tool._build_mcp_params(None) + + assert params["columns_to_rerank"] == "text,title" + + +# ============================================================================= +# Tests for _parse_index_name +# ============================================================================= + + +def test_parse_index_name_invalid(): + """Test parsing invalid index name raises ValueError.""" + tool = DummyVectorSearchRetrieverTool(index_name="invalid_index_name") + + with pytest.raises(ValueError, match="Invalid index name format"): + tool._parse_index_name() + + +# ============================================================================= +# Tests for validate_filter_configuration +# ============================================================================= + + +def test_cannot_use_both_dynamic_filter_and_predefined_filters(): + """Test that using both dynamic_filter and predefined filters raises an error.""" + # Try to initialize tool with both dynamic_filter=True and predefined filters + with pytest.raises( + ValueError, match="Cannot use both dynamic_filter=True and predefined filters" + ): + DummyVectorSearchRetrieverTool( + index_name=index_name, + filters={"status": "active", "category": "electronics"}, + dynamic_filter=True, + ) From acb745b2dee1aa933eff5220eedb3d26db85ea6b Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Tue, 27 Jan 2026 15:22:59 -0800 Subject: [PATCH 06/11] async fix attempt --- .../vector_search_retriever_tool.py | 3 ++- .../test_vector_search_retriever_tool.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 9f40bebc9..c92cb5f17 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -144,7 +144,8 @@ def _execute_mcp_path( try: mcp_tool = self._create_or_get_mcp_tool() mcp_input = self._build_mcp_input(query, filters, **kwargs) - result = mcp_tool.invoke(mcp_input) + # MCP tools only support async invocation + result = asyncio.run(mcp_tool.ainvoke(mcp_input)) return self._parse_mcp_response(result) except Exception as e: self._handle_mcp_execution_error(e) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index 74dd98cd4..d2bf53c5d 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -57,8 +57,8 @@ def _create_mcp_response_json(texts: List[str] = None) -> str: def assert_mcp_tool_called_with(mock_tool, expected_args: Dict[str, Any]): """Assert MCP tool was called with expected args, handling JSON-stringified filters.""" - mock_tool.invoke.assert_called_once() - call_args = mock_tool.invoke.call_args[0][0] + mock_tool.ainvoke.assert_called_once() + call_args = mock_tool.ainvoke.call_args[0][0] for key, value in expected_args.items(): if key == "filters": assert json.loads(call_args["filters"]) == value @@ -70,8 +70,9 @@ def assert_mcp_tool_called_with(mock_tool, expected_args: Dict[str, Any]): def mock_mcp_infrastructure(): """Mock MCP infrastructure for tests that need it.""" # Create mock MCP tool that returns JSON response + # MCP tools are async-only, so we mock ainvoke mock_tool = MagicMock() - mock_tool.invoke = MagicMock(return_value=_create_mcp_response_json()) + mock_tool.ainvoke = AsyncMock(return_value=_create_mcp_response_json()) # Create mock MCP client mock_client_instance = MagicMock() @@ -669,8 +670,8 @@ def test_mcp_path_is_used_for_databricks_managed_embeddings(mock_mcp_infrastruct # Verify MCP client was used mock_mcp_infrastructure["client_class"].assert_called_once() - # Verify MCP tool was invoked with expected query - mock_mcp_infrastructure["tool"].invoke.assert_called_once_with( + # Verify MCP tool was invoked with expected query (ainvoke since MCP tools are async-only) + mock_mcp_infrastructure["tool"].ainvoke.assert_called_once_with( { "query": "test query", "num_results": vector_search_tool.num_results, @@ -694,7 +695,7 @@ def test_direct_api_path_is_used_for_self_managed_embeddings(mock_mcp_infrastruc vector_search_tool._vector_store.similarity_search.assert_called_once() # Verify MCP was NOT used for self-managed embeddings - mock_mcp_infrastructure["tool"].invoke.assert_not_called() + mock_mcp_infrastructure["tool"].ainvoke.assert_not_called() def test_mcp_tool_is_cached(mock_mcp_infrastructure) -> None: @@ -713,4 +714,4 @@ def test_mcp_tool_is_cached(mock_mcp_infrastructure) -> None: assert mock_mcp_infrastructure["client_class"].call_count == 1 # But MCP tool should be invoked 3 times - assert mock_mcp_infrastructure["tool"].invoke.call_count == 3 + assert mock_mcp_infrastructure["tool"].ainvoke.call_count == 3 From edf8b3fc44656906947048e060ee78efd5291f9c Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Tue, 27 Jan 2026 16:23:45 -0800 Subject: [PATCH 07/11] debug --- .../databricks_langchain/vector_search_retriever_tool.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index c92cb5f17..a7144480c 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -146,6 +146,14 @@ def _execute_mcp_path( mcp_input = self._build_mcp_input(query, filters, **kwargs) # MCP tools only support async invocation result = asyncio.run(mcp_tool.ainvoke(mcp_input)) + + # DEBUG: Print what we get from MCP + print(f"DEBUG MCP result type: {type(result)}") + print(f"DEBUG MCP result repr: {repr(result)[:500]}") + if isinstance(result, list) and result: + print(f"DEBUG First item type: {type(result[0])}") + print(f"DEBUG First item: {result[0]}") + return self._parse_mcp_response(result) except Exception as e: self._handle_mcp_execution_error(e) From b0fd83c2b245786a8f487d7c868de6876271a453 Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Tue, 27 Jan 2026 19:05:55 -0800 Subject: [PATCH 08/11] fix langchain format --- .../vector_search_retriever_tool.py | 25 ++++++++++------- .../test_vector_search_retriever_tool.py | 28 ++++++++++++------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index a7144480c..1b33b613c 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -129,8 +129,21 @@ def _build_mcp_input( mcp_input["query"] = query return mcp_input - def _parse_mcp_response(self, mcp_response: str) -> List[Document]: - """Parse MCP tool response into LangChain Documents.""" + def _parse_mcp_response(self, mcp_response: Any) -> List[Document]: + """Parse MCP tool response into LangChain Documents. + + LangChain MCP adapters return content blocks in the format: + [{'type': 'text', 'text': '', 'id': '...'}] + + We need to extract the JSON string from the 'text' field. + """ + # Handle LangChain MCP adapter content block format + if isinstance(mcp_response, list) and mcp_response: + first_item = mcp_response[0] + if isinstance(first_item, dict) and first_item.get("type") == "text": + # Extract the actual JSON string from the content block + mcp_response = first_item.get("text", "") + dicts = self._parse_mcp_response_to_dicts(mcp_response, strict=True) return [Document(page_content=d["page_content"], metadata=d["metadata"]) for d in dicts] @@ -146,14 +159,6 @@ def _execute_mcp_path( mcp_input = self._build_mcp_input(query, filters, **kwargs) # MCP tools only support async invocation result = asyncio.run(mcp_tool.ainvoke(mcp_input)) - - # DEBUG: Print what we get from MCP - print(f"DEBUG MCP result type: {type(result)}") - print(f"DEBUG MCP result repr: {repr(result)[:500]}") - if isinstance(result, list) and result: - print(f"DEBUG First item type: {type(result[0])}") - print(f"DEBUG First item: {result[0]}") - return self._parse_mcp_response(result) except Exception as e: self._handle_mcp_execution_error(e) diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index d2bf53c5d..52f41ef06 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -44,15 +44,22 @@ ) -def _create_mcp_response_json(texts: List[str] = None) -> str: - """Create a mock MCP response in JSON format.""" +def _create_mcp_response(texts: List[str] = None) -> List[Dict[str, Any]]: + """Create a mock MCP response in LangChain MCP adapter content block format. + + The langchain-mcp-adapters library returns content blocks like: + [{'type': 'text', 'text': '', 'id': 'lc_...'}] + + This matches the actual response format observed during integration testing. + """ texts = texts or INPUT_TEXTS - return json.dumps( - [ - {"id": str(uuid.uuid4()), "text": text, "score": 0.85 - (i * 0.1)} - for i, text in enumerate(texts) - ] - ) + # The actual search results as JSON string (what the MCP server returns) + search_results = [ + {"id": str(uuid.uuid4()), "text": text, "score": 0.85 - (i * 0.1)} + for i, text in enumerate(texts) + ] + # Wrapped in LangChain MCP adapter content block format + return [{"type": "text", "text": json.dumps(search_results), "id": f"lc_{uuid.uuid4()}"}] def assert_mcp_tool_called_with(mock_tool, expected_args: Dict[str, Any]): @@ -69,10 +76,11 @@ def assert_mcp_tool_called_with(mock_tool, expected_args: Dict[str, Any]): @pytest.fixture def mock_mcp_infrastructure(): """Mock MCP infrastructure for tests that need it.""" - # Create mock MCP tool that returns JSON response + # Create mock MCP tool that returns content block format + # (matching what langchain-mcp-adapters actually returns) # MCP tools are async-only, so we mock ainvoke mock_tool = MagicMock() - mock_tool.ainvoke = AsyncMock(return_value=_create_mcp_response_json()) + mock_tool.ainvoke = AsyncMock(return_value=_create_mcp_response()) # Create mock MCP client mock_client_instance = MagicMock() From 0015beb2a2b46194260d13f3a673afd258dca619 Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Tue, 27 Jan 2026 19:52:44 -0800 Subject: [PATCH 09/11] cleanup comments --- .../vector_search_retriever_tool.py | 9 +-------- .../unit_tests/test_vector_search_retriever_tool.py | 10 +--------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 1b33b613c..6ed435f31 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -130,14 +130,7 @@ def _build_mcp_input( return mcp_input def _parse_mcp_response(self, mcp_response: Any) -> List[Document]: - """Parse MCP tool response into LangChain Documents. - - LangChain MCP adapters return content blocks in the format: - [{'type': 'text', 'text': '', 'id': '...'}] - - We need to extract the JSON string from the 'text' field. - """ - # Handle LangChain MCP adapter content block format + """Parse MCP tool response into LangChain Documents.""" if isinstance(mcp_response, list) and mcp_response: first_item = mcp_response[0] if isinstance(first_item, dict) and first_item.get("type") == "text": diff --git a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py index 52f41ef06..80dc7f492 100644 --- a/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py @@ -45,20 +45,12 @@ def _create_mcp_response(texts: List[str] = None) -> List[Dict[str, Any]]: - """Create a mock MCP response in LangChain MCP adapter content block format. - - The langchain-mcp-adapters library returns content blocks like: - [{'type': 'text', 'text': '', 'id': 'lc_...'}] - - This matches the actual response format observed during integration testing. - """ + """Create a mock MCP response in LangChain MCP adapter content block format.""" texts = texts or INPUT_TEXTS - # The actual search results as JSON string (what the MCP server returns) search_results = [ {"id": str(uuid.uuid4()), "text": text, "score": 0.85 - (i * 0.1)} for i, text in enumerate(texts) ] - # Wrapped in LangChain MCP adapter content block format return [{"type": "text", "text": json.dumps(search_results), "id": f"lc_{uuid.uuid4()}"}] From 13d31cd3efeafe68d2b84562aca45924bfe62a8c Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Wed, 28 Jan 2026 16:44:53 -0800 Subject: [PATCH 10/11] review comments --- .../vector_search_retriever_tool.py | 13 +- .../vector_search_retriever_tool.py | 26 +- .../test_vector_search_retriever_tool.py | 273 ------------------ .../vector_search_retriever_tool.py | 12 +- .../test_vector_search_retriever_tool.py | 138 +++++++++ 5 files changed, 155 insertions(+), 307 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index 6ed435f31..c63bb17a2 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -118,17 +118,6 @@ def _create_or_get_mcp_tool(self) -> BaseTool: self._mcp_tool = tools[0] return self._mcp_tool - def _build_mcp_input( - self, - query: str, - filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, - **kwargs: Any, - ) -> Dict[str, Any]: - """Build input for MCP tool invocation.""" - mcp_input = self._build_mcp_params(filters, **kwargs) - mcp_input["query"] = query - return mcp_input - def _parse_mcp_response(self, mcp_response: Any) -> List[Document]: """Parse MCP tool response into LangChain Documents.""" if isinstance(mcp_response, list) and mcp_response: @@ -149,7 +138,7 @@ def _execute_mcp_path( """Execute vector search via LangChain MCP infrastructure.""" try: mcp_tool = self._create_or_get_mcp_tool() - mcp_input = self._build_mcp_input(query, filters, **kwargs) + mcp_input = self._build_mcp_params(filters, query=query, **kwargs) # MCP tools only support async invocation result = asyncio.run(mcp_tool.ainvoke(mcp_input)) return self._parse_mcp_response(result) diff --git a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py index 0802261d9..713d72f65 100644 --- a/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py +++ b/integrations/openai/src/databricks_openai/vector_search_retriever_tool.py @@ -226,16 +226,6 @@ def _create_or_get_mcp_toolkit(self) -> Callable: self._mcp_tool_execute = tools[0].execute return self._mcp_tool_execute - def _build_mcp_meta( - self, filters: Optional[Union[Dict[str, Any], List[FilterItem]]] = None, **kwargs: Any - ) -> Dict[str, Any]: - """Build metadata dict for MCP tool invocation.""" - return self._build_mcp_params(filters, **kwargs) - - def _parse_mcp_response(self, mcp_response: str) -> List[Dict]: - """Parse MCP JSON response and normalize to page_content/metadata format.""" - return self._parse_mcp_response_to_dicts(mcp_response, strict=True) - def _execute_mcp_path( self, query: str, @@ -244,9 +234,9 @@ def _execute_mcp_path( ) -> List[Dict]: try: mcp_execute = self._create_or_get_mcp_toolkit() - meta = self._build_mcp_meta(filters, **kwargs) + meta = self._build_mcp_params(filters, **kwargs) mcp_response = mcp_execute(query=query, _meta=meta) - return self._parse_mcp_response(mcp_response) + return self._parse_mcp_response_to_dicts(mcp_response, strict=True) except Exception as e: self._handle_mcp_execution_error(e) @@ -257,13 +247,9 @@ def _execute_direct_api_path( openai_client: OpenAI = None, **kwargs: Any, ) -> List[Dict]: - from openai import OpenAI + from databricks_openai import DatabricksOpenAI - oai_client = openai_client or OpenAI() - if not oai_client.api_key: - raise ValueError( - "OpenAI API key is required to generate embeddings for retrieval queries." - ) + oai_client = openai_client or DatabricksOpenAI(workspace_client=self.workspace_client) signature = inspect.signature(self._index.similarity_search) kwargs = {**kwargs, **(self.model_extra or {})} @@ -331,8 +317,8 @@ def execute( query: The query text to use for the retrieval. filters: Optional filters to refine vector search results. openai_client: The OpenAI client object used to generate embeddings for retrieval queries. - Only used for self-managed embeddings. If not provided, the default OpenAI - client in the current environment will be used. + Only used for self-managed embeddings. If not provided, a DatabricksOpenAI + client will be created using the workspace_client for authentication. **kwargs: Additional search parameters (e.g., num_results, query_type, score_threshold, reranker). For Databricks-managed embeddings, these are passed as MCP metadata. For self-managed embeddings, these are passed to similarity_search(). diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index e2f03008e..daa33bb6c 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -289,38 +289,6 @@ def test_open_ai_client_from_env( assert all(["id" in d["metadata"] for d in docs]) -@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) -def test_vector_search_retriever_index_name_rewrite( - index_name: str, -) -> None: - if index_name == DELTA_SYNC_INDEX: - self_managed_embeddings_test = SelfManagedEmbeddingsTest() - else: - from openai import OpenAI - - self_managed_embeddings_test = SelfManagedEmbeddingsTest( - "text", "text-embedding-3-small", OpenAI(api_key="your-api-key") - ) - - vector_search_tool = init_vector_search_tool( - index_name=index_name, - text_column=self_managed_embeddings_test.text_column, - embedding_model_name=self_managed_embeddings_test.embedding_model_name, - ) - assert vector_search_tool.tool["function"]["name"] == index_name.replace(".", "__") - - -@pytest.mark.parametrize( - "index_name", - ["catalog.schema.really_really_really_long_tool_name_that_should_be_truncated_to_64_chars"], -) -def test_vector_search_retriever_long_index_name( - index_name: str, -) -> None: - vector_search_tool = init_vector_search_tool(index_name=index_name) - assert len(vector_search_tool.tool["function"]["name"]) <= 64 - - def test_vector_search_client_model_serving_environment(): with patch("os.path.isfile", return_value=True): # Simulate Model Serving Environment @@ -578,123 +546,6 @@ def test_include_score_always_sent_in_meta(mock_mcp_toolkit) -> None: assert call_kwargs["_meta"]["include_score"] == "false" -def test_get_filter_param_description_with_column_metadata() -> None: - """Test that _get_filter_param_description includes column metadata when available.""" - # Mock table info with column metadata - mock_column1 = Mock() - mock_column1.name = "category" - mock_column1.type_name.name = "STRING" - - mock_column2 = Mock() - mock_column2.name = "price" - mock_column2.type_name.name = "FLOAT" - - mock_column3 = Mock() - mock_column3.name = "__internal_column" # Should be excluded - mock_column3.type_name.name = "STRING" - - mock_table_info = Mock() - mock_table_info.columns = [mock_column1, mock_column2, mock_column3] - - with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: - mock_ws_client = Mock() - mock_ws_client.tables.get.return_value = mock_table_info - mock_ws_client_class.return_value = mock_ws_client - - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - # Test the _get_filter_param_description method directly - description = vector_search_tool._get_filter_param_description() - - # Should include available columns in description - assert "Available columns for filtering: category (STRING), price (FLOAT)" in description - - # Should include comprehensive filter syntax - assert "Inclusion:" in description - assert "Exclusion:" in description - assert "Comparisons:" in description - assert "Pattern match:" in description - assert "OR logic:" in description - - # Should include examples - assert "Examples:" in description - assert "Filter by category:" in description - assert "Filter by price range:" in description - - -def test_enhanced_filter_description_used_in_tool_schema() -> None: - """Test that the tool schema includes comprehensive filter descriptions.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - # Check that the tool schema includes enhanced filter description - tool_schema = vector_search_tool.tool - filter_param = tool_schema["function"]["parameters"]["properties"]["filters"] - - # Check that it includes the comprehensive filter syntax - assert "Inclusion:" in filter_param["description"] - assert "Exclusion:" in filter_param["description"] - assert "Comparisons:" in filter_param["description"] - assert "Pattern match:" in filter_param["description"] - assert "OR logic:" in filter_param["description"] - - # Check that it includes useful filter information - assert "array of key-value pairs" in filter_param["description"] - assert "column" in filter_param["description"] - - -def test_enhanced_filter_description_fails_on_table_metadata_error() -> None: - """Test that tool initialization fails with clear error when table metadata cannot be retrieved.""" - # Mock WorkspaceClient to raise an exception when accessing table metadata - with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: - mock_ws_client = MagicMock() - mock_ws_client.tables.get.side_effect = Exception("Permission denied") - mock_ws_client_class.return_value = mock_ws_client - - # Try to initialize tool with dynamic_filter=True - # This should fail because we can't get table metadata - with pytest.raises( - ValueError, - match="Failed to retrieve table metadata for index.*Permission denied", - ): - init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - -def test_enhanced_filter_description_fails_on_empty_columns() -> None: - """Test that tool initialization fails when table has no valid columns.""" - # Mock WorkspaceClient to return a table with no valid columns (all start with __) - with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: - mock_ws_client = MagicMock() - mock_table = MagicMock() - mock_column = MagicMock() - mock_column.name = "__internal_column" - mock_column.type_name = MagicMock() - mock_column.type_name.name = "STRING" - mock_table.columns = [mock_column] - mock_ws_client.tables.get.return_value = mock_table - mock_ws_client_class.return_value = mock_ws_client - - # Try to initialize tool with dynamic_filter=True - # This should fail because there are no valid columns - with pytest.raises( - ValueError, - match="No valid columns found in table metadata for index", - ): - init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) - - -def test_cannot_use_both_dynamic_filter_and_predefined_filters() -> None: - """Test that using both dynamic_filter and predefined filters raises an error.""" - # Try to initialize tool with both dynamic_filter=True and predefined filters - with pytest.raises( - ValueError, match="Cannot use both dynamic_filter=True and predefined filters" - ): - init_vector_search_tool( - DELTA_SYNC_INDEX, - filters={"status": "active", "category": "electronics"}, - dynamic_filter=True, - ) - - def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: """Test that predefined filters work correctly when dynamic_filter is False.""" predefined_filters = {"status": "active", "category": "electronics"} @@ -732,51 +583,6 @@ def test_predefined_filters_work_without_dynamic_filter(execution_path) -> None: assert call_kwargs["query_type"] == vector_search_tool.query_type -def test_filter_item_serialization(execution_path) -> None: - """Test that FilterItem objects are properly converted to dictionaries.""" - vector_search_tool = init_vector_search_tool(execution_path["index_name"]) - setup_tool_for_path(execution_path, vector_search_tool) - - # Test various filter types - filters = [ - FilterItem(key="category", value="electronics"), - FilterItem(key="price >=", value=100), - FilterItem(key="status NOT", value="discontinued"), - FilterItem(key="tags", value=["wireless", "bluetooth"]), - ] - - vector_search_tool.execute("find products", filters=filters) - - expected_filters = { - "category": "electronics", - "price >=": 100, - "status NOT": "discontinued", - "tags": ["wireless", "bluetooth"], - } - - if execution_path["path"] == "mcp": - mock_tool = execution_path["mock_tool"] - mock_tool.execute.assert_called_once() - call_kwargs = mock_tool.execute.call_args.kwargs - - assert call_kwargs["query"] == "find products" - - meta = call_kwargs["_meta"] - # Filters should be serialized as JSON - assert json.loads(meta["filters"]) == expected_filters - assert meta["num_results"] == vector_search_tool.num_results - assert meta["query_type"] == vector_search_tool.query_type - assert meta["columns"] == ",".join(vector_search_tool.columns) - else: - vector_search_tool._index.similarity_search.assert_called_once() - call_kwargs = vector_search_tool._index.similarity_search.call_args.kwargs - - assert call_kwargs["filters"] == expected_filters - assert call_kwargs["num_results"] == vector_search_tool.num_results - assert call_kwargs["query_type"] == vector_search_tool.query_type - assert call_kwargs["columns"] == vector_search_tool.columns - - def test_reranker_is_passed_through(execution_path) -> None: reranker = DatabricksReranker(columns_to_rerank=["country"]) vector_search_tool = init_vector_search_tool(execution_path["index_name"], reranker=reranker) @@ -844,82 +650,3 @@ def test_reranker_is_overriden(execution_path) -> None: assert call_kwargs["filters"] == {"country": "Germany"} assert call_kwargs["num_results"] == vector_search_tool.num_results assert call_kwargs["query_type"] == vector_search_tool.query_type - - -# ============================================================================ -# Response Format Normalization Tests -# ============================================================================ - - -class TestMCPResponseNormalization: - """Test that MCP responses are normalized to match Direct API format.""" - - def test_parse_mcp_response_basic_normalization(self) -> None: - """Test basic normalization of MCP results via _parse_mcp_response.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - mcp_response = json.dumps( - [ - { - "id": "doc-123", - "text": "This is the document content", - "score": 0.95, - } - ] - ) - - results = vector_search_tool._parse_mcp_response(mcp_response) - - assert len(results) == 1 - assert results[0]["page_content"] == "This is the document content" - assert results[0]["metadata"]["id"] == "doc-123" - assert results[0]["metadata"]["score"] == 0.95 - assert "text" not in results[0]["metadata"] # text column moved to page_content - - def test_parse_mcp_response_missing_text_column(self) -> None: - """Test normalization handles missing text column gracefully.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - mcp_response = json.dumps( - [ - { - "id": "doc-789", - "score": 0.75, - # "text" column is missing - } - ] - ) - - results = vector_search_tool._parse_mcp_response(mcp_response) - - assert len(results) == 1 - # When text column is missing, the dict is converted to string - assert results[0]["metadata"]["id"] == "doc-789" - assert results[0]["metadata"]["score"] == 0.75 - - def test_parse_mcp_response_empty_list(self) -> None: - """Test parsing empty MCP response.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - mcp_response = json.dumps([]) - - results = vector_search_tool._parse_mcp_response(mcp_response) - - assert results == [] - - def test_parse_mcp_response_invalid_json(self) -> None: - """Test parsing invalid JSON raises ValueError.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - with pytest.raises(ValueError, match="Unable to parse MCP response"): - vector_search_tool._parse_mcp_response("not valid json {") - - def test_parse_mcp_response_not_a_list(self) -> None: - """Test parsing non-list JSON raises ValueError.""" - vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) - - # MCP should return a list, not a dict - mcp_response = json.dumps({"error": "something went wrong"}) - - with pytest.raises(ValueError, match="Expected JSON array, got"): - vector_search_tool._parse_mcp_response(mcp_response) diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py index a975e9bc9..0fd56f71a 100644 --- a/src/databricks_ai_bridge/vector_search_retriever_tool.py +++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py @@ -328,9 +328,13 @@ def _handle_mcp_creation_error(self, error: Exception) -> None: ) from error def _validate_mcp_tools(self, tools: list) -> None: - """Validate that MCP tools were returned.""" + """Validate that exactly one MCP tool was returned.""" if not tools: raise ValueError(f"No MCP tools found for index {self.index_name}") + if len(tools) != 1: + raise ValueError( + f"Expected exactly 1 MCP tool for index {self.index_name}, but got {len(tools)}" + ) def _handle_mcp_execution_error(self, error: Exception) -> None: """Log and raise standardized error for MCP execution failures.""" @@ -342,12 +346,16 @@ def _handle_mcp_execution_error(self, error: Exception) -> None: def _build_mcp_params( self, filters: Optional[Union[Dict[str, Any], List["FilterItem"]]] = None, + query: Optional[str] = None, **kwargs: Any, ) -> Dict[str, Any]: - """Build common MCP parameters dict (excludes query).""" + """Build common MCP parameters dict.""" kwargs = {**(self.model_extra or {}), **kwargs} params: Dict[str, Any] = {} + if query is not None: + params["query"] = query + num_results = kwargs.pop("num_results", kwargs.pop("k", self.num_results)) if num_results: params["num_results"] = num_results diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py index 20badabdb..67e3fc61a 100644 --- a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -318,3 +318,141 @@ def test_cannot_use_both_dynamic_filter_and_predefined_filters(): filters={"status": "active", "category": "electronics"}, dynamic_filter=True, ) + + +# ============================================================================= +# Tests for _get_tool_name +# ============================================================================= + + +def test_get_tool_name_replaces_dots(): + """Test that dots in index name are replaced with underscores.""" + tool = DummyVectorSearchRetrieverTool(index_name="catalog.schema.my_index") + assert tool._get_tool_name() == "catalog__schema__my_index" + + +def test_get_tool_name_uses_custom_name(): + """Test that custom tool_name is used when provided.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name, tool_name="custom_tool") + assert tool._get_tool_name() == "custom_tool" + + +def test_get_tool_name_truncates_long_names(): + """Test that long tool names are truncated to 64 characters.""" + long_index = ( + "catalog.schema.really_really_really_long_tool_name_that_should_be_truncated_to_64_chars" + ) + tool = DummyVectorSearchRetrieverTool(index_name=long_index) + result = tool._get_tool_name() + assert len(result) <= 64 + + +# ============================================================================= +# Tests for _validate_mcp_tools +# ============================================================================= + + +def test_validate_mcp_tools_empty_list(): + """Test that empty tools list raises ValueError.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + with pytest.raises(ValueError, match="Expected exactly 1 MCP tool"): + tool._validate_mcp_tools([]) + + +def test_validate_mcp_tools_multiple_tools(): + """Test that multiple tools raises ValueError.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + with pytest.raises(ValueError, match="Expected exactly 1 MCP tool"): + tool._validate_mcp_tools([MagicMock(), MagicMock()]) + + +def test_validate_mcp_tools_single_tool(): + """Test that single tool passes validation.""" + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + # Should not raise + tool._validate_mcp_tools([MagicMock()]) + + +# ============================================================================= +# Tests for _get_filter_param_description +# ============================================================================= + + +def test_get_filter_param_description_includes_column_metadata(): + """Test that _get_filter_param_description includes column metadata when available.""" + from unittest.mock import Mock, patch + + mock_column1 = Mock() + mock_column1.name = "category" + mock_column1.type_name.name = "STRING" + + mock_column2 = Mock() + mock_column2.name = "price" + mock_column2.type_name.name = "FLOAT" + + mock_column3 = Mock() + mock_column3.name = "__internal_column" # Should be excluded + mock_column3.type_name.name = "STRING" + + mock_table_info = Mock() + mock_table_info.columns = [mock_column1, mock_column2, mock_column3] + + with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: + mock_ws_client = Mock() + mock_ws_client.tables.get.return_value = mock_table_info + mock_ws_client_class.return_value = mock_ws_client + + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + description = tool._get_filter_param_description() + + # Should include available columns in description + assert "Available columns for filtering: category (STRING), price (FLOAT)" in description + + # Should include comprehensive filter syntax + assert "Inclusion:" in description + assert "Exclusion:" in description + assert "Comparisons:" in description + assert "Pattern match:" in description + assert "OR logic:" in description + + +def test_get_filter_param_description_fails_on_table_metadata_error(): + """Test that _get_filter_param_description fails with clear error when table metadata cannot be retrieved.""" + from unittest.mock import patch + + with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: + mock_ws_client = MagicMock() + mock_ws_client.tables.get.side_effect = Exception("Permission denied") + mock_ws_client_class.return_value = mock_ws_client + + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + with pytest.raises( + ValueError, + match="Failed to retrieve table metadata for index.*Permission denied", + ): + tool._get_filter_param_description() + + +def test_get_filter_param_description_fails_on_empty_columns(): + """Test that _get_filter_param_description fails when table has no valid columns.""" + from unittest.mock import patch + + with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class: + mock_ws_client = MagicMock() + mock_table = MagicMock() + mock_column = MagicMock() + mock_column.name = "__internal_column" + mock_column.type_name = MagicMock() + mock_column.type_name.name = "STRING" + mock_table.columns = [mock_column] + mock_ws_client.tables.get.return_value = mock_table + mock_ws_client_class.return_value = mock_ws_client + + tool = DummyVectorSearchRetrieverTool(index_name=index_name) + + with pytest.raises( + ValueError, + match="No valid columns found in table metadata for index", + ): + tool._get_filter_param_description() From 9194af1b72f6428888285ee3697161d41539a9ef Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Wed, 28 Jan 2026 16:51:00 -0800 Subject: [PATCH 11/11] test fixes --- .../tests/unit_tests/test_vector_search_retriever_tool.py | 3 ++- .../databricks_ai_bridge/test_vector_search_retriever_tool.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py index daa33bb6c..2cb0f5a99 100644 --- a/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py +++ b/integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py @@ -45,7 +45,8 @@ def mock_openai_client(): mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3, 0.4])] mock_client.embeddings.create.return_value = mock_response with patch("openai.OpenAI", return_value=mock_client): - yield mock_client + with patch("databricks_openai.DatabricksOpenAI", return_value=mock_client): + yield mock_client @pytest.fixture diff --git a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py index 67e3fc61a..5a90c1cf4 100644 --- a/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py +++ b/tests/databricks_ai_bridge/test_vector_search_retriever_tool.py @@ -355,7 +355,7 @@ def test_get_tool_name_truncates_long_names(): def test_validate_mcp_tools_empty_list(): """Test that empty tools list raises ValueError.""" tool = DummyVectorSearchRetrieverTool(index_name=index_name) - with pytest.raises(ValueError, match="Expected exactly 1 MCP tool"): + with pytest.raises(ValueError, match="No MCP tools found for index"): tool._validate_mcp_tools([])