Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions tests/fuser/knowledge_base/faiss/test_faiss_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,149 @@ def test_batch_search_returns_independent_results(self, mock_faiss_index):
all_results[0][0].metadata["modified"] = True

assert "modified" not in all_results[1][0].metadata


class TestFAISSRetrieverMissingCoverage:
"""Tests to cover missing branches in FAISSRetriever."""

def test_load_faiss_index_returns_none(self, tmp_path):
"""Test ValueError when faiss.read_index returns None."""
dimension = 384
index = faiss.IndexFlatL2(dimension)
index_path = tmp_path / "test.faiss"
faiss.write_index(index, str(index_path))

metadata_path = tmp_path / "test.pkl"
with open(metadata_path, "wb") as f:
pickle.dump([{"text": "doc", "metadata": {}}], f)

with patch(
"src.fuser.knowledge_base.faiss.faiss_retriever.faiss.read_index",
return_value=None,
):
with pytest.raises(ValueError, match="Failed to load FAISS index"):
FAISSRetriever(index_path, metadata_path)

def test_load_metadata_dict_qa_format(self, tmp_path):
"""Test loading metadata in dict format with questions/answers keys."""
dimension = 384
index = faiss.IndexFlatL2(dimension)
embeddings = np.random.randn(3, dimension).astype("float32")
index.add(x=embeddings) # type: ignore
index_path = tmp_path / "test.faiss"
faiss.write_index(index, str(index_path))

metadata = {
"questions": ["Q1", "Q2", "Q3"],
"answers": ["A1", "A2", "A3"],
}
metadata_path = tmp_path / "test.pkl"
with open(metadata_path, "wb") as f:
pickle.dump(metadata, f)

retriever = FAISSRetriever(index_path, metadata_path)

assert len(retriever.documents) == 3
assert retriever.documents[0].text == "Q1"
assert retriever.documents[0].metadata["answer"] == "A1"
assert retriever.documents[0].metadata["type"] == "qa_pair"

def test_load_metadata_dict_invalid_keys(self, tmp_path):
"""Test ValueError when dict metadata has unsupported keys."""
dimension = 384
index = faiss.IndexFlatL2(dimension)
index_path = tmp_path / "test.faiss"
faiss.write_index(index, str(index_path))

metadata = {"unsupported_key": "value"}
metadata_path = tmp_path / "test.pkl"
with open(metadata_path, "wb") as f:
pickle.dump(metadata, f)

with pytest.raises(ValueError, match="Unsupported metadata format"):
FAISSRetriever(index_path, metadata_path)

def test_load_metadata_unsupported_type(self, tmp_path):
"""Test ValueError when metadata is neither list nor dict."""
dimension = 384
index = faiss.IndexFlatL2(dimension)
index_path = tmp_path / "test.faiss"
faiss.write_index(index, str(index_path))

metadata_path = tmp_path / "test.pkl"
with open(metadata_path, "wb") as f:
pickle.dump("invalid_string_metadata", f)

with pytest.raises(ValueError, match="Unsupported metadata format"):
FAISSRetriever(index_path, metadata_path)

def test_search_index_none(self, tmp_path):
"""Test ValueError when index is None during search."""
dimension = 384
index = faiss.IndexFlatL2(dimension)
embeddings = np.random.randn(2, dimension).astype("float32")
index.add(x=embeddings) # type: ignore
index_path = tmp_path / "test.faiss"
faiss.write_index(index, str(index_path))

metadata = [{"text": "doc", "metadata": {}} for _ in range(2)]
metadata_path = tmp_path / "test.pkl"
with open(metadata_path, "wb") as f:
pickle.dump(metadata, f)

retriever = FAISSRetriever(index_path, metadata_path)
retriever.index = None

query_embedding = np.random.randn(dimension).astype("float32")
with pytest.raises(ValueError, match="FAISS index not loaded"):
retriever.search(query_embedding, top_k=2)

def test_batch_search_index_none(self, tmp_path):
"""Test ValueError when index is None during batch_search."""
dimension = 384
index = faiss.IndexFlatL2(dimension)
embeddings = np.random.randn(2, dimension).astype("float32")
index.add(x=embeddings) # type: ignore
index_path = tmp_path / "test.faiss"
faiss.write_index(index, str(index_path))

metadata = [{"text": "doc", "metadata": {}} for _ in range(2)]
metadata_path = tmp_path / "test.pkl"
with open(metadata_path, "wb") as f:
pickle.dump(metadata, f)

retriever = FAISSRetriever(index_path, metadata_path)
retriever.index = None

query_embeddings = np.random.randn(2, dimension).astype("float32")
with pytest.raises(ValueError, match="FAISS index not loaded"):
retriever.batch_search(query_embeddings, top_k=2)

def test_batch_search_out_of_bounds_indices(self, tmp_path):
"""Test that batch_search skips out-of-bounds indices."""
dimension = 384
index = faiss.IndexFlatL2(dimension)
embeddings = np.random.randn(2, dimension).astype("float32")
index.add(x=embeddings) # type: ignore
index_path = tmp_path / "test.faiss"
faiss.write_index(index, str(index_path))

metadata = [{"text": f"Doc {i}", "metadata": {"id": i}} for i in range(2)]
metadata_path = tmp_path / "test.pkl"
with open(metadata_path, "wb") as f:
pickle.dump(metadata, f)

retriever = FAISSRetriever(index_path, metadata_path)

with patch.object(retriever.index, "search") as mock_search:
mock_search.return_value = (
np.array([[0.1, 0.2]], dtype="float32"),
np.array([[0, 999]], dtype="int64"),
)

query_embeddings = np.random.randn(1, dimension).astype("float32")
results = retriever.batch_search(query_embeddings, top_k=2)

assert len(results) == 1
assert len(results[0]) == 1
assert results[0][0].text == "Doc 0"
33 changes: 33 additions & 0 deletions tests/fuser/knowledge_base/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,36 @@ async def test_query_with_different_top_k(self, mock_kb_structure):

call_args = mock_retriever.search.call_args
assert call_args[1]["top_k"] == 10

def test_format_context_qa_pair_type(self, mock_kb_structure):
"""Test format_context with qa_pair type uses answer instead of text."""
with (
patch("src.fuser.knowledge_base.retriever.EmbeddingClient"),
patch("src.fuser.knowledge_base.retriever.FAISSRetriever") as mock_ret,
):
mock_retriever = MagicMock()
mock_retriever.num_documents = 10
mock_retriever.dimension = 384
mock_ret.return_value = mock_retriever

kb = KnowledgeBase(
knowledge_base_name="demo", knowledge_base_root=mock_kb_structure
)

docs = [
Document(
text="What is the question?",
metadata={
"source": "qa.txt",
"chunk_id": 0,
"type": "qa_pair",
"answer": "This is the answer.",
},
score=0.95,
)
]

context = kb.format_context(docs)

assert "This is the answer." in context
assert "What is the question?" not in context
Loading