Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 53 additions & 13 deletions rag-engine/src/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import os
from pathlib import Path as FilePath
from fastembed import TextEmbedding
from fastembed import SparseTextEmbedding, TextEmbedding
from fastembed.common.model_description import ModelSource, PoolingType
from qdrant_client import QdrantClient, models
from qdrant_client.models import Distance, FieldCondition, MatchValue, VectorParams
from qdrant_client.conversions.common_types import SparseVectorParams
from fastembed.rerank.cross_encoder import TextCrossEncoder
import json
from fastapi import HTTPException, status
from qdrant_client.models import Optional
from qdrant_client.models import (
Distance,
FieldCondition,
MatchValue,
VectorParams,
)


CACHE_DIR = FilePath("./models_cache")
CACHE_DIR.mkdir(exist_ok=True)
Expand All @@ -18,10 +29,18 @@
dim=VECTOR_SIZE,
model_file="onnx/model.onnx",
)
embedding_model = TextEmbedding(
dense_embedding = TextEmbedding(
model_name="intfloat/multilingual-e5-small",
cache_dir=str(CACHE_DIR),
)
sparse_embedding = SparseTextEmbedding(
model_name="prithivida/Splade_PP_en_v1",
cache_dir=str(CACHE_DIR),
)
reranker = TextCrossEncoder(
model_name="Xenova/ms-marco-MiniLM-L-12-v2",
cache_dir=str(CACHE_DIR),
)

qclient = QdrantClient(
url=os.getenv("QDRANT_DB_URL"),
Expand All @@ -31,20 +50,25 @@
if COLLECTION_NAME not in [c.name for c in qclient.get_collections().collections]:
qclient.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(
size=VECTOR_SIZE,
distance=Distance.COSINE,
),
vectors_config={
"text-dense": VectorParams(
size=VECTOR_SIZE,
distance=Distance.COSINE,
),
},
sparse_vectors_config={
"text-sparse": SparseVectorParams(index=models.SparseIndexParams()),
},
)
qclient.create_payload_index(
collection_name=COLLECTION_NAME,
field_name="_file_hash",
field_schema=models.PayloadSchemaType.KEYWORD,
collection_name=COLLECTION_NAME,
field_name="_file_hash",
field_schema=models.PayloadSchemaType.KEYWORD,
)
qclient.create_payload_index(
collection_name=COLLECTION_NAME,
field_name="_user_id",
field_schema=models.PayloadSchemaType.KEYWORD,
collection_name=COLLECTION_NAME,
field_name="_user_id",
field_schema=models.PayloadSchemaType.KEYWORD,
)


Expand All @@ -67,3 +91,19 @@ def document_exists(user_id: str, file_hash: str) -> bool:
)

return len(results[0]) > 0





def parse_metadata(metadata_str: Optional[str]) -> dict:
"""Parse JSON metadata string, return empty dict if None."""
if metadata_str:
try:
return json.loads(metadata_str)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid JSON in metadata field",
)
return {}
Loading