|
9 | 9 | from typing import Dict, Any, List |
10 | 10 | from sentence_transformers import SentenceTransformer |
11 | 11 | from pymilvus import connections, Collection |
| 12 | +import threading |
| 13 | +import time |
12 | 14 |
|
# Config
# Chat-completions endpoint; the KSERVE_URL environment variable overrides
# the default URL given here.
KSERVE_URL = os.environ.get(
    "KSERVE_URL",
    "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions",
)
|
66 | 68 |
|
67 | 69 |
|
68 | 70 |
|
# Guards the one-time construction of the shared embedding model.
_model_lock = threading.Lock()
# Cached SentenceTransformer instance; stays None until the first request.
_embedding_model = None


def get_embedding_model():
    """Return the shared SentenceTransformer, constructing it on first use.

    The model named by ``EMBEDDING_MODEL`` is built at most once: callers that
    find the cache empty take ``_model_lock`` and re-check the cache before
    constructing, so only the first thread through the lock pays the load
    cost. Load start and elapsed time are printed.
    """
    global _embedding_model

    # Fast path: already loaded, no lock needed.
    if _embedding_model is not None:
        return _embedding_model

    with _model_lock:
        # Re-check under the lock; another thread may have finished loading
        # while this one was waiting.
        if _embedding_model is None:
            began = time.perf_counter()
            print(f"[INFO] Lazy loading SentenceTransformer model '{EMBEDDING_MODEL}'...")
            _embedding_model = SentenceTransformer(EMBEDDING_MODEL)
            elapsed = time.perf_counter() - began
            print(f"[INFO] Model loaded in {elapsed:.3f} seconds.")
    return _embedding_model
| 86 | + |
def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
    """Run a semantic search against Milvus and return JSON-serializable hits.

    Connects to Milvus under the "default" alias, loads ``MILVUS_COLLECTION``,
    embeds ``query`` with the cached SentenceTransformer and searches
    ``MILVUS_VECTOR_FIELD`` using the COSINE metric. The connection is
    disconnected again before returning.

    Args:
        query: Text to embed and search for.
        top_k: Maximum number of hits to request (coerced with ``int``).

    Returns:
        ``{"results": [...]}`` where every item carries ``similarity``,
        ``file_path``, ``citation_url`` and ``content_text`` (cut to 400
        characters plus ``"..."`` when longer). Any failure is printed and
        yields ``{"results": []}``; no exception propagates to the caller.
    """
    # NOTE(review): connect/disconnect use the shared "default" alias on every
    # call; if this function can run concurrently, one call's disconnect may
    # affect another in-flight search -- confirm against the serving model.
    try:
        connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
        collection = Collection(MILVUS_COLLECTION)
        collection.load()

        # Thread-safe cached encoder
        query_vec = get_embedding_model().encode(query).tolist()

        search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
        results = collection.search(
            data=[query_vec],
            anns_field=MILVUS_VECTOR_FIELD,
            param=search_params,
            limit=int(top_k),
            output_fields=["file_path", "content_text", "citation_url"],
        )

        hits = []
        # One query vector was sent, so only the first result set is relevant.
        for hit in results[0]:
            entity = hit.entity
            content_text = entity.get("content_text") or ""
            if isinstance(content_text, str) and len(content_text) > 400:
                content_text = content_text[:400] + "..."
            hits.append({
                # With the COSINE metric Milvus reports the cosine similarity
                # itself in ``distance`` (larger means more similar), so it is
                # passed through unchanged; ``1.0 - distance`` would invert
                # the ranking signal.
                "similarity": float(hit.distance),
                "file_path": entity.get("file_path"),
                "citation_url": entity.get("citation_url"),
                "content_text": content_text,
            })
        return {"results": hits}
    except Exception as e:
        # Boundary handler: callers always get a well-formed (empty) payload.
        print(f"[ERROR] Milvus search failed: {e}")
        return {"results": []}
    finally:
        # Best-effort cleanup: a failed disconnect must not mask the result,
        # but it is reported instead of being silently dropped.
        try:
            connections.disconnect(alias="default")
        except Exception as e:
            print(f"[WARN] Milvus disconnect failed: {e}")
| 130 | + |
69 | 131 | TOOLS = [ |
70 | 132 | { |
71 | 133 | "type": "function", |
|
0 commit comments