Skip to content

Commit f1c478d

Browse files
committed
fix(server-https): preserve multi-hop citations in stream_llm_response
Signed-off-by: Ayush-kathil <kathilshiva@gmail.com>
1 parent 2629a40 commit f1c478d

1 file changed

Lines changed: 17 additions & 7 deletions

File tree

server-https/app.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
from pymilvus import connections, Collection
1212

1313

14-
embedding_model = SentenceTransformer(EMBEDDING_MODEL)
14+
embedding_model = None

def get_embedding_model():
    """Return the process-wide SentenceTransformer, creating it on first use.

    Instantiating the encoder is expensive, so the instance is built lazily
    and cached in the module-level ``embedding_model`` global; every later
    call returns the same object.

    NOTE(review): no lock around the lazy init — assumed safe because callers
    run on a single event loop; confirm if threads are ever introduced.
    """
    global embedding_model
    if embedding_model is None:
        embedding_model = SentenceTransformer(EMBEDDING_MODEL)
    return embedding_model
1521

1622
# Config
1723
KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
@@ -124,7 +130,8 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
124130
collection.load()
125131

126132
# Encoder (same model as pipeline)
127-
query_vec = embedding_model.encode(query).tolist()
133+
model = get_embedding_model()
134+
query_vec = model.encode(query).tolist()
128135

129136
search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
130137
results = collection.search(
@@ -197,9 +204,11 @@ async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
197204
print(f"[ERROR] Tool execution failed: {e}")
198205
return f"Tool execution failed: {e}", []
199206

200-
async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
207+
async def stream_llm_response(payload: Dict[str, Any], citations_collector: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
201208
"""Stream response from LLM and handle tool calls, yielding SSE events"""
202-
citations_collector = []
209+
is_outermost = citations_collector is None
210+
if citations_collector is None:
211+
citations_collector = []
203212

204213
try:
205214
async with httpx.AsyncClient(timeout=120) as client:
@@ -298,7 +307,7 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No
298307
continue
299308

300309
# Send citations if any were collected
301-
if citations_collector:
310+
if is_outermost and citations_collector:
302311
# Remove duplicates while preserving order
303312
unique_citations = []
304313
for citation in citations_collector:
@@ -308,7 +317,8 @@ async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, No
308317
yield f"data: {json.dumps({'type': 'citations', 'citations': unique_citations})}\n\n"
309318

310319
# Send completion signal
311-
yield f"data: {json.dumps({'type': 'done'})}\n\n"
320+
if is_outermost:
321+
yield f"data: {json.dumps({'type': 'done'})}\n\n"
312322

313323
except Exception as e:
314324
print(f"[ERROR] Streaming failed: {e}")
@@ -344,7 +354,7 @@ async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dic
344354
}
345355

346356
# Stream the follow-up response
347-
async for chunk in stream_llm_response(follow_up_payload):
357+
async for chunk in stream_llm_response(follow_up_payload, citations_collector=citations_collector):
348358
yield chunk
349359

350360
except Exception as e:

0 commit comments

Comments (0)