Skip to content

Commit f1c0cd0

Browse files
committed
fix(performance): avoid repeated SentenceTransformer loading in milvus_search
Implemented thread-safe lazy-loading for SentenceTransformer to eliminate redundant loading within milvus_search. Signed-off-by: Antigravity <antigravity@google.com>
1 parent f05614a commit f1c0cd0

2 files changed

Lines changed: 77 additions & 7 deletions

File tree

server-https/app.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,15 +9,23 @@
99
from typing import Dict, Any, List, Optional, AsyncGenerator
1010
from sentence_transformers import SentenceTransformer
1111
from pymilvus import connections, Collection
12+
import threading
13+
import time
1214

13-
14-
embedding_model = None
15-
# Load embedding model once to avoid repeated initialization overhead
15+
# Process-wide cache for the embedding model and the lock guarding its creation.
_embedding_model = None
_model_lock = threading.Lock()


def get_embedding_model():
    """Return the shared SentenceTransformer, constructing it on first use.

    The model named by EMBEDDING_MODEL is built at most once; later calls
    return the cached instance. Construction happens under ``_model_lock``
    and the load time is printed.
    """
    global _embedding_model

    # Fast path: the model already exists, so the lock is not needed.
    if _embedding_model is not None:
        return _embedding_model

    with _model_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting to acquire it.
        if _embedding_model is None:
            began = time.perf_counter()
            print(f"[INFO] Lazy loading SentenceTransformer model '{EMBEDDING_MODEL}'...")
            _embedding_model = SentenceTransformer(EMBEDDING_MODEL)
            print(f"[INFO] Model loaded in {time.perf_counter() - began:.3f} seconds.")
        return _embedding_model
2129

2230
# Config
2331
KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")

server/app.py

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
from typing import Dict, Any, List
1010
from sentence_transformers import SentenceTransformer
1111
from pymilvus import connections, Collection
12+
import threading
13+
import time
1214

1315
# Config
1416
KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
@@ -66,6 +68,66 @@
6668

6769

6870

71+
# Process-wide cache for the embedding model and the lock guarding its creation.
_embedding_model = None
_model_lock = threading.Lock()


def get_embedding_model():
    """Return the shared SentenceTransformer, constructing it on first use.

    The model named by EMBEDDING_MODEL is built at most once; later calls
    return the cached instance. Construction happens under ``_model_lock``
    and the load time is printed.
    """
    global _embedding_model

    # Fast path: the model already exists, so the lock is not needed.
    if _embedding_model is not None:
        return _embedding_model

    with _model_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting to acquire it.
        if _embedding_model is None:
            began = time.perf_counter()
            print(f"[INFO] Lazy loading SentenceTransformer model '{EMBEDDING_MODEL}'...")
            _embedding_model = SentenceTransformer(EMBEDDING_MODEL)
            print(f"[INFO] Model loaded in {time.perf_counter() - began:.3f} seconds.")
        return _embedding_model
86+
87+
def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
    """Execute a semantic search in Milvus and return JSON-serializable results.

    Args:
        query: Text to search for; it is embedded with the model returned by
            ``get_embedding_model()``.
        top_k: Maximum number of hits to request (coerced with ``int``).

    Returns:
        ``{"results": [...]}`` where each item carries ``similarity``,
        ``file_path``, ``citation_url`` and ``content_text`` (truncated to
        400 characters plus ``"..."`` when longer). If anything fails, the
        error is printed and ``{"results": []}`` is returned instead of
        raising.
    """
    try:
        # Connect to Milvus and make sure the collection is loaded for search.
        connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
        collection = Collection(MILVUS_COLLECTION)
        collection.load()

        # Reuse the cached encoder instead of constructing a model per call.
        model = get_embedding_model()
        query_vec = model.encode(query).tolist()

        search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
        results = collection.search(
            data=[query_vec],
            anns_field=MILVUS_VECTOR_FIELD,
            param=search_params,
            limit=int(top_k),
            output_fields=["file_path", "content_text", "citation_url"],
        )

        hits = []
        # One query vector was sent, so only the first result set is relevant.
        for hit in results[0]:
            # With metric_type="COSINE" Milvus reports the cosine similarity
            # itself in `distance` (larger means more similar), so it is used
            # as-is. Computing `1.0 - distance` would yield the cosine
            # distance and rank the best matches lowest.
            similarity = float(hit.distance)
            entity = hit.entity
            content_text = entity.get("content_text") or ""
            # Keep payloads small: cap long passages at 400 characters.
            if isinstance(content_text, str) and len(content_text) > 400:
                content_text = content_text[:400] + "..."
            hits.append({
                "similarity": similarity,
                "file_path": entity.get("file_path"),
                "citation_url": entity.get("citation_url"),
                "content_text": content_text,
            })
        return {"results": hits}
    except Exception as e:
        # Search is best-effort for callers: report the failure and return an
        # empty result set rather than propagating the exception.
        print(f"[ERROR] Milvus search failed: {e}")
        return {"results": []}
    finally:
        # Best-effort cleanup; a failed disconnect must not mask the result.
        try:
            connections.disconnect(alias="default")
        except Exception:
            pass
130+
69131
TOOLS = [
70132
{
71133
"type": "function",

0 commit comments

Comments
 (0)