1111from pymilvus import connections , Collection
1212
1313
14- embedding_model = SentenceTransformer (EMBEDDING_MODEL )
14+ embedding_model = None
15+ # Load embedding model once to avoid repeated initialization overhead
def get_embedding_model():
    """Return the shared embedding model, building it on first use.

    The instance is stored in the module-level ``embedding_model`` variable.
    When that variable is still ``None``, a ``SentenceTransformer`` is
    constructed from ``EMBEDDING_MODEL`` and stored there; later calls
    return the stored instance without constructing a new one.
    """
    global embedding_model
    # Fast path: the model has already been constructed by an earlier call.
    if embedding_model is not None:
        return embedding_model
    embedding_model = SentenceTransformer(EMBEDDING_MODEL)
    return embedding_model
1521
1622# Config
1723KSERVE_URL = os .getenv ("KSERVE_URL" , "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions" )
@@ -124,7 +130,8 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
124130 collection .load ()
125131
126132 # Encoder (same model as pipeline)
127- query_vec = embedding_model .encode (query ).tolist ()
133+ model = get_embedding_model ()
134+ query_vec = model .encode (query ).tolist ()
128135
129136 search_params = {"metric_type" : "COSINE" , "params" : {"nprobe" : 32 }}
130137 results = collection .search (
@@ -197,9 +204,11 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
197204 print (f"[ERROR] Tool execution failed: { e } " )
198205 return f"Tool execution failed: { e } " , []
199206
200- async def stream_llm_response (payload : Dict [str , Any ]) -> AsyncGenerator [str , None ]:
207+ async def stream_llm_response (payload : Dict [str , Any ], citations_collector : Optional [ List [ str ]] = None ) -> AsyncGenerator [str , None ]:
201208 """Stream response from LLM and handle tool calls, yielding SSE events"""
202- citations_collector = []
209+ is_outermost = citations_collector is None
210+ if citations_collector is None :
211+ citations_collector = []
203212
204213 try :
205214 async with httpx .AsyncClient (timeout = 120 ) as client :
@@ -298,7 +307,7 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No
298307 continue
299308
300309 # Send citations if any were collected
301- if citations_collector :
310+ if is_outermost and citations_collector :
302311 # Remove duplicates while preserving order
303312 unique_citations = []
304313 for citation in citations_collector :
@@ -308,7 +317,8 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No
308317 yield f"data: { json .dumps ({'type' : 'citations' , 'citations' : unique_citations })} \n \n "
309318
310319 # Send completion signal
311- yield f"data: { json .dumps ({'type' : 'done' })} \n \n "
320+ if is_outermost :
321+ yield f"data: { json .dumps ({'type' : 'done' })} \n \n "
312322
313323 except Exception as e :
314324 print (f"[ERROR] Streaming failed: { e } " )
@@ -344,7 +354,7 @@ async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dic
344354 }
345355
346356 # Stream the follow-up response
347- async for chunk in stream_llm_response (follow_up_payload ):
357+ async for chunk in stream_llm_response (follow_up_payload , citations_collector = citations_collector ):
348358 yield chunk
349359
350360 except Exception as e :
0 commit comments