Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 63 additions & 30 deletions server-https/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import json
import httpx
from fastapi import FastAPI, HTTPException
Expand All @@ -9,6 +10,13 @@
from typing import Dict, Any, List, Optional, AsyncGenerator
from sentence_transformers import SentenceTransformer
from pymilvus import connections, Collection
import logging

# Allow importing shared utilities from the repo root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.retry import DEGRADED_RESULT, with_retry # noqa: E402

logger = logging.getLogger(__name__)

# Config
KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
Expand Down Expand Up @@ -111,17 +119,35 @@ class ChatRequest(BaseModel):
message: str
stream: Optional[bool] = True

# Load the embedding model once at module startup to avoid per-query overhead.
import threading

_encoder: SentenceTransformer | None = None
# Guards lazy initialization: _get_encoder is invoked from worker threads
# (the search path runs milvus_search via run_in_threadpool), so without a
# lock two concurrent first requests could each load the large model.
_encoder_lock = threading.Lock()


def _get_encoder() -> SentenceTransformer:
    """Return the module-level encoder singleton, loading it on first call.

    Thread-safe: uses double-checked locking so the model is loaded exactly
    once even when several requests race on the first call. Subsequent calls
    take the fast path and never touch the lock.
    """
    global _encoder
    if _encoder is None:
        with _encoder_lock:
            # Re-check under the lock: another thread may have finished
            # loading while we were waiting to acquire it.
            if _encoder is None:
                logger.info("[STARTUP] Loading embedding model: %s", EMBEDDING_MODEL)
                _encoder = SentenceTransformer(EMBEDDING_MODEL)
    return _encoder


@with_retry(
max_attempts=3,
base_delay=1.0,
max_delay=10.0,
backoff_factor=2.0,
jitter=True,
exceptions=(Exception,),
)
def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
"""Execute a semantic search in Milvus and return structured JSON serializable results."""
"""Execute a semantic search in Milvus; raises on failure to allow retry."""
try:
# Connect to Milvus
connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
collection = Collection(MILVUS_COLLECTION)
collection.load()

# Encoder (same model as pipeline)
encoder = SentenceTransformer(EMBEDDING_MODEL)
query_vec = encoder.encode(query).tolist()
query_vec = _get_encoder().encode(query).tolist()

search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
results = collection.search(
Expand All @@ -140,58 +166,65 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
content_text = entity.get("content_text") or ""
if isinstance(content_text, str) and len(content_text) > 400:
content_text = content_text[:400] + "..."
hits.append({
"similarity": similarity,
"file_path": entity.get("file_path"),
"citation_url": entity.get("citation_url"),
"content_text": content_text,
})
hits.append(
{
"similarity": similarity,
"file_path": entity.get("file_path"),
"citation_url": entity.get("citation_url"),
"content_text": content_text,
}
)
return {"results": hits}
except Exception as e:
print(f"[ERROR] Milvus search failed: {e}")
return {"results": []}
finally:
try:
connections.disconnect(alias="default")
except Exception:
pass

async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
"""Execute a tool call and return the result and citations"""
"""Execute a tool call; offloads Milvus search to a thread and degrades gracefully on failure."""
from fastapi.concurrency import run_in_threadpool

try:
function_name = tool_call.get("function", {}).get("name")
arguments = json.loads(tool_call.get("function", {}).get("arguments", "{}"))

if function_name == "search_kubeflow_docs":
query = arguments.get("query", "")
top_k = arguments.get("top_k", 5)

print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})")
result = milvus_search(query, top_k)

# Collect citations
citations = []

logger.info("[TOOL] Executing Milvus search: '%s' (top_k=%d)", query, top_k)
try:
result = await run_in_threadpool(milvus_search, query, top_k)
except Exception as exc:
# milvus_search exhausted all @with_retry attempts.
logger.error("[TOOL] milvus_search failed after all retries: %s", exc)
return DEGRADED_RESULT, []

citations: List[str] = []
formatted_results = []

for hit in result.get("results", []):
citation_url = hit.get('citation_url', '')
citation_url = hit.get("citation_url", "")
if citation_url and citation_url not in citations:
citations.append(citation_url)

formatted_results.append(
f"File: {hit.get('file_path', 'Unknown')}\n"
f"Content: {hit.get('content_text', '')}\n"
f"URL: {citation_url}\n"
f"Similarity: {hit.get('similarity', 0):.3f}\n"
)

formatted_text = "\n".join(formatted_results) if formatted_results else "No relevant results found."

formatted_text = (
"\n".join(formatted_results) if formatted_results else "No relevant results found."
)
return formatted_text, citations

return f"Unknown tool: {function_name}", []

except Exception as e:
print(f"[ERROR] Tool execution failed: {e}")
logger.error("[TOOL] Tool execution failed: %s", e)
return f"Tool execution failed: {e}", []

async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
Expand Down
90 changes: 60 additions & 30 deletions server/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import json
import asyncio
import httpx
Expand All @@ -10,6 +11,12 @@
from sentence_transformers import SentenceTransformer
from pymilvus import connections, Collection

# Allow importing shared utilities from the repo root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.retry import DEGRADED_RESULT, with_retry # noqa: E402

logger = logging.getLogger(__name__)

# Config
KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
MODEL = os.getenv("MODEL", "llama3.1-8B")
Expand Down Expand Up @@ -65,17 +72,35 @@



# Load the embedding model once at module startup to avoid per-query overhead.
import threading

_encoder: SentenceTransformer | None = None
# Guards lazy initialization: _get_encoder is invoked from worker threads
# (the search path runs milvus_search via asyncio.to_thread), so without a
# lock two concurrent first requests could each load the large model.
_encoder_lock = threading.Lock()


def _get_encoder() -> SentenceTransformer:
    """Return the module-level encoder singleton, loading it on first call.

    Thread-safe: uses double-checked locking so the model is loaded exactly
    once even when several requests race on the first call. Subsequent calls
    take the fast path and never touch the lock.
    """
    global _encoder
    if _encoder is None:
        with _encoder_lock:
            # Re-check under the lock: another thread may have finished
            # loading while we were waiting to acquire it.
            if _encoder is None:
                logger.info("[STARTUP] Loading embedding model: %s", EMBEDDING_MODEL)
                _encoder = SentenceTransformer(EMBEDDING_MODEL)
    return _encoder


@with_retry(
max_attempts=3,
base_delay=1.0,
max_delay=10.0,
backoff_factor=2.0,
jitter=True,
exceptions=(Exception,),
)
def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
"""Execute a semantic search in Milvus and return structured JSON serializable results."""
"""Execute a semantic search in Milvus; raises on failure to allow retry."""
try:
# Connect to Milvus
connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
collection = Collection(MILVUS_COLLECTION)
collection.load()

# Encoder (same model as pipeline)
encoder = SentenceTransformer(EMBEDDING_MODEL)
query_vec = encoder.encode(query).tolist()
query_vec = _get_encoder().encode(query).tolist()

search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
results = collection.search(
Expand All @@ -94,16 +119,15 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
content_text = entity.get("content_text") or ""
if isinstance(content_text, str) and len(content_text) > 400:
content_text = content_text[:400] + "..."
hits.append({
"similarity": similarity,
"file_path": entity.get("file_path"),
"citation_url": entity.get("citation_url"),
"content_text": content_text,
})
hits.append(
{
"similarity": similarity,
"file_path": entity.get("file_path"),
"citation_url": entity.get("citation_url"),
"content_text": content_text,
}
)
return {"results": hits}
except Exception as e:
print(f"[ERROR] Milvus search failed: {e}")
return {"results": []}
finally:
try:
connections.disconnect(alias="default")
Expand Down Expand Up @@ -144,41 +168,47 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:


async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
"""Execute a tool call and return the result and citations"""
"""Execute a tool call; offloads Milvus search to a thread and degrades gracefully on failure."""
try:
function_name = tool_call.get("function", {}).get("name")
arguments = json.loads(tool_call.get("function", {}).get("arguments", "{}"))

if function_name == "search_kubeflow_docs":
query = arguments.get("query", "")
top_k = arguments.get("top_k", 5)

print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})")
result = milvus_search(query, top_k)

# Collect citations
citations = []

logger.info("[TOOL] Executing Milvus search: '%s' (top_k=%d)", query, top_k)
try:
result = await asyncio.to_thread(milvus_search, query, top_k)
except Exception as exc:
# milvus_search exhausted all @with_retry attempts.
logger.error("[TOOL] milvus_search failed after all retries: %s", exc)
return DEGRADED_RESULT, []

citations: List[str] = []
formatted_results = []

for hit in result.get("results", []):
citation_url = hit.get('citation_url', '')
citation_url = hit.get("citation_url", "")
if citation_url and citation_url not in citations:
citations.append(citation_url)

formatted_results.append(
f"File: {hit.get('file_path', 'Unknown')}\n"
f"Content: {hit.get('content_text', '')}\n"
f"URL: {citation_url}\n"
f"Similarity: {hit.get('similarity', 0):.3f}\n"
)

formatted_text = "\n".join(formatted_results) if formatted_results else "No relevant results found."

formatted_text = (
"\n".join(formatted_results) if formatted_results else "No relevant results found."
)
return formatted_text, citations

return f"Unknown tool: {function_name}", []

except Exception as e:
print(f"[ERROR] Tool execution failed: {e}")
logger.error("[TOOL] Tool execution failed: %s", e)
return f"Tool execution failed: {e}", []

async def stream_llm_response(payload: Dict[str, Any], websocket, citations_collector: List[str] = None) -> None:
Expand Down
Empty file added shared/__init__.py
Empty file.
Loading