From fb578d6fbc34aed0f61331361e78bd14eb9d96aa Mon Sep 17 00:00:00 2001
From: Wanbogang
Date: Sat, 28 Feb 2026 19:58:14 +0700
Subject: [PATCH] test(knowledge_base): add missing coverage for FAISSRetriever and KnowledgeBase

- FAISSRetriever: cover None index from faiss.read_index, dict metadata
  with qa_pair format, dict metadata with invalid keys, unsupported
  metadata type, search/batch_search with None index, and batch_search
  out-of-bounds index handling
- KnowledgeBase: cover format_context qa_pair branch that returns
  answer instead of text
---
NOTE(review): line breaks and indentation were restored from a
whitespace-collapsed copy of this patch; no code token was changed.
The following could not be recovered and are inferred - TODO confirm
against the target files before applying:
  * indentation of the context lines in both hunks, and which of the
    three context lines in each hunk is the blank one;
  * whether the trailing statements of
    test_batch_search_out_of_bounds_indices and
    test_format_context_qa_pair_type sit inside their `with` blocks
    (they are kept inside here, where the mocks are still active);
  * wrapping of the commit message body.
If only context whitespace differs, `git apply --ignore-whitespace`
may help.

 .../faiss/test_faiss_retriever.py            | 146 ++++++++++++++++++
 tests/fuser/knowledge_base/test_retriever.py |  33 ++++
 2 files changed, 179 insertions(+)

diff --git a/tests/fuser/knowledge_base/faiss/test_faiss_retriever.py b/tests/fuser/knowledge_base/faiss/test_faiss_retriever.py
index 7128b8a3d..009b9fd89 100644
--- a/tests/fuser/knowledge_base/faiss/test_faiss_retriever.py
+++ b/tests/fuser/knowledge_base/faiss/test_faiss_retriever.py
@@ -288,3 +288,149 @@ def test_batch_search_returns_independent_results(self, mock_faiss_index):
 
         all_results[0][0].metadata["modified"] = True
         assert "modified" not in all_results[1][0].metadata
+
+
+class TestFAISSRetrieverMissingCoverage:
+    """Tests to cover missing branches in FAISSRetriever."""
+
+    def test_load_faiss_index_returns_none(self, tmp_path):
+        """Test ValueError when faiss.read_index returns None."""
+        dimension = 384
+        index = faiss.IndexFlatL2(dimension)
+        index_path = tmp_path / "test.faiss"
+        faiss.write_index(index, str(index_path))
+
+        metadata_path = tmp_path / "test.pkl"
+        with open(metadata_path, "wb") as f:
+            pickle.dump([{"text": "doc", "metadata": {}}], f)
+
+        with patch(
+            "src.fuser.knowledge_base.faiss.faiss_retriever.faiss.read_index",
+            return_value=None,
+        ):
+            with pytest.raises(ValueError, match="Failed to load FAISS index"):
+                FAISSRetriever(index_path, metadata_path)
+
+    def test_load_metadata_dict_qa_format(self, tmp_path):
+        """Test loading metadata in dict format with questions/answers keys."""
+        dimension = 384
+        index = faiss.IndexFlatL2(dimension)
+        embeddings = np.random.randn(3, dimension).astype("float32")
+        index.add(x=embeddings)  # type: ignore
+        index_path = tmp_path / "test.faiss"
+        faiss.write_index(index, str(index_path))
+
+        metadata = {
+            "questions": ["Q1", "Q2", "Q3"],
+            "answers": ["A1", "A2", "A3"],
+        }
+        metadata_path = tmp_path / "test.pkl"
+        with open(metadata_path, "wb") as f:
+            pickle.dump(metadata, f)
+
+        retriever = FAISSRetriever(index_path, metadata_path)
+
+        assert len(retriever.documents) == 3
+        assert retriever.documents[0].text == "Q1"
+        assert retriever.documents[0].metadata["answer"] == "A1"
+        assert retriever.documents[0].metadata["type"] == "qa_pair"
+
+    def test_load_metadata_dict_invalid_keys(self, tmp_path):
+        """Test ValueError when dict metadata has unsupported keys."""
+        dimension = 384
+        index = faiss.IndexFlatL2(dimension)
+        index_path = tmp_path / "test.faiss"
+        faiss.write_index(index, str(index_path))
+
+        metadata = {"unsupported_key": "value"}
+        metadata_path = tmp_path / "test.pkl"
+        with open(metadata_path, "wb") as f:
+            pickle.dump(metadata, f)
+
+        with pytest.raises(ValueError, match="Unsupported metadata format"):
+            FAISSRetriever(index_path, metadata_path)
+
+    def test_load_metadata_unsupported_type(self, tmp_path):
+        """Test ValueError when metadata is neither list nor dict."""
+        dimension = 384
+        index = faiss.IndexFlatL2(dimension)
+        index_path = tmp_path / "test.faiss"
+        faiss.write_index(index, str(index_path))
+
+        metadata_path = tmp_path / "test.pkl"
+        with open(metadata_path, "wb") as f:
+            pickle.dump("invalid_string_metadata", f)
+
+        with pytest.raises(ValueError, match="Unsupported metadata format"):
+            FAISSRetriever(index_path, metadata_path)
+
+    def test_search_index_none(self, tmp_path):
+        """Test ValueError when index is None during search."""
+        dimension = 384
+        index = faiss.IndexFlatL2(dimension)
+        embeddings = np.random.randn(2, dimension).astype("float32")
+        index.add(x=embeddings)  # type: ignore
+        index_path = tmp_path / "test.faiss"
+        faiss.write_index(index, str(index_path))
+
+        metadata = [{"text": "doc", "metadata": {}} for _ in range(2)]
+        metadata_path = tmp_path / "test.pkl"
+        with open(metadata_path, "wb") as f:
+            pickle.dump(metadata, f)
+
+        retriever = FAISSRetriever(index_path, metadata_path)
+        retriever.index = None
+
+        query_embedding = np.random.randn(dimension).astype("float32")
+        with pytest.raises(ValueError, match="FAISS index not loaded"):
+            retriever.search(query_embedding, top_k=2)
+
+    def test_batch_search_index_none(self, tmp_path):
+        """Test ValueError when index is None during batch_search."""
+        dimension = 384
+        index = faiss.IndexFlatL2(dimension)
+        embeddings = np.random.randn(2, dimension).astype("float32")
+        index.add(x=embeddings)  # type: ignore
+        index_path = tmp_path / "test.faiss"
+        faiss.write_index(index, str(index_path))
+
+        metadata = [{"text": "doc", "metadata": {}} for _ in range(2)]
+        metadata_path = tmp_path / "test.pkl"
+        with open(metadata_path, "wb") as f:
+            pickle.dump(metadata, f)
+
+        retriever = FAISSRetriever(index_path, metadata_path)
+        retriever.index = None
+
+        query_embeddings = np.random.randn(2, dimension).astype("float32")
+        with pytest.raises(ValueError, match="FAISS index not loaded"):
+            retriever.batch_search(query_embeddings, top_k=2)
+
+    def test_batch_search_out_of_bounds_indices(self, tmp_path):
+        """Test that batch_search skips out-of-bounds indices."""
+        dimension = 384
+        index = faiss.IndexFlatL2(dimension)
+        embeddings = np.random.randn(2, dimension).astype("float32")
+        index.add(x=embeddings)  # type: ignore
+        index_path = tmp_path / "test.faiss"
+        faiss.write_index(index, str(index_path))
+
+        metadata = [{"text": f"Doc {i}", "metadata": {"id": i}} for i in range(2)]
+        metadata_path = tmp_path / "test.pkl"
+        with open(metadata_path, "wb") as f:
+            pickle.dump(metadata, f)
+
+        retriever = FAISSRetriever(index_path, metadata_path)
+
+        with patch.object(retriever.index, "search") as mock_search:
+            mock_search.return_value = (
+                np.array([[0.1, 0.2]], dtype="float32"),
+                np.array([[0, 999]], dtype="int64"),
+            )
+
+            query_embeddings = np.random.randn(1, dimension).astype("float32")
+            results = retriever.batch_search(query_embeddings, top_k=2)
+
+            assert len(results) == 1
+            assert len(results[0]) == 1
+            assert results[0][0].text == "Doc 0"
diff --git a/tests/fuser/knowledge_base/test_retriever.py b/tests/fuser/knowledge_base/test_retriever.py
index 3ed216c8a..c25d3616d 100644
--- a/tests/fuser/knowledge_base/test_retriever.py
+++ b/tests/fuser/knowledge_base/test_retriever.py
@@ -364,3 +364,36 @@ async def test_query_with_different_top_k(self, mock_kb_structure):
 
             call_args = mock_retriever.search.call_args
             assert call_args[1]["top_k"] == 10
+
+    def test_format_context_qa_pair_type(self, mock_kb_structure):
+        """Test format_context with qa_pair type uses answer instead of text."""
+        with (
+            patch("src.fuser.knowledge_base.retriever.EmbeddingClient"),
+            patch("src.fuser.knowledge_base.retriever.FAISSRetriever") as mock_ret,
+        ):
+            mock_retriever = MagicMock()
+            mock_retriever.num_documents = 10
+            mock_retriever.dimension = 384
+            mock_ret.return_value = mock_retriever
+
+            kb = KnowledgeBase(
+                knowledge_base_name="demo", knowledge_base_root=mock_kb_structure
+            )
+
+            docs = [
+                Document(
+                    text="What is the question?",
+                    metadata={
+                        "source": "qa.txt",
+                        "chunk_id": 0,
+                        "type": "qa_pair",
+                        "answer": "This is the answer.",
+                    },
+                    score=0.95,
+                )
+            ]
+
+            context = kb.format_context(docs)
+
+            assert "This is the answer." in context
+            assert "What is the question?" not in context