Skip to content

Commit 4e3c06d

Browse files
committed
refactor(vector): decouple VectorRetriever.load() from FAISSVectorStore
Accept vectorstore: VectorStorePort as a parameter instead of instantiating FAISSVectorStore internally. components/ no longer imports from integrations/. Store restoration is the caller's responsibility.
1 parent 6de5d57 commit 4e3c06d

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

src/lang2sql/components/retrieval/vector.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,36 +200,39 @@ def load(
200200
cls,
201201
path: str,
202202
*,
203+
vectorstore: VectorStorePort,
203204
embedding: EmbeddingPort,
204205
top_n: int = 5,
205206
score_threshold: float = 0.0,
206207
name: Optional[str] = None,
207208
hook: Optional[TraceHook] = None,
208209
) -> "VectorRetriever":
209-
"""저장된 인덱스와 registry를 복원해 VectorRetriever를 반환.
210+
"""저장된 registry를 복원해 VectorRetriever를 반환.
210211
211-
save()로 저장한 path를 그대로 전달한다.
212-
embedding은 쿼리 시 embed_query()에 사용되므로 반드시 전달해야 한다.
212+
벡터 인덱스 복원은 호출자가 직접 수행한 뒤 vectorstore로 전달한다.
213+
이렇게 하면 VectorRetriever가 특정 store 구현체에 의존하지 않는다.
213214
214215
Args:
215-
path: save() 시 사용한 경로.
216-
embedding: EmbeddingPort 구현체.
217-
top_n: 최대 반환 스키마/컨텍스트 수. 기본 5.
216+
path: save() 시 사용한 경로 (registry 파일 위치 기준).
217+
vectorstore: 이미 로드된 VectorStorePort 구현체.
218+
embedding: EmbeddingPort 구현체.
219+
top_n: 최대 반환 스키마/컨텍스트 수. 기본 5.
218220
score_threshold: 이 점수 이하는 결과에서 제외. 기본 0.0.
221+
222+
Example:
223+
store = FAISSVectorStore.load(path)
224+
retriever = VectorRetriever.load(path, vectorstore=store, embedding=emb)
219225
"""
220226
import json
221227
import pathlib
222228

223-
from ...integrations.vectorstore.faiss_ import FAISSVectorStore
224-
225229
registry_path = pathlib.Path(path + ".registry")
226230
if not registry_path.exists():
227231
raise FileNotFoundError(f"Registry file not found: {registry_path}")
228232

229-
store = FAISSVectorStore.load(path)
230233
registry = json.loads(registry_path.read_text(encoding="utf-8"))
231234
return cls(
232-
vectorstore=store,
235+
vectorstore=vectorstore,
233236
embedding=embedding,
234237
registry=registry,
235238
top_n=top_n,

tests/test_components_vector_retriever.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,8 @@ def test_save_and_load_returns_same_results(tmp_path):
538538
original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store)
539539
original.save(path)
540540

541-
loaded = VectorRetriever.load(path, embedding=embedding)
541+
loaded_store = FAISSVectorStore.load(path)
542+
loaded = VectorRetriever.load(path, vectorstore=loaded_store, embedding=embedding)
542543
result = loaded.run("주문 정보")
543544

544545
assert len(result.schemas) > 0
@@ -557,7 +558,8 @@ def test_load_registry_intact(tmp_path):
557558
original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store)
558559
original.save(path)
559560

560-
loaded = VectorRetriever.load(path, embedding=embedding)
561+
loaded_store = FAISSVectorStore.load(path)
562+
loaded = VectorRetriever.load(path, vectorstore=loaded_store, embedding=embedding)
561563

562564
assert set(loaded._registry.keys()) == set(original._registry.keys())
563565
for chunk_id, chunk in original._registry.items():

0 commit comments

Comments
 (0)