diff --git a/server-https/app.py b/server-https/app.py
index 694af8a..ef05b7b 100644
--- a/server-https/app.py
+++ b/server-https/app.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import json
 import httpx
 from fastapi import FastAPI, HTTPException
@@ -9,6 +10,13 @@
 from typing import Dict, Any, List, Optional, AsyncGenerator
 from sentence_transformers import SentenceTransformer
 from pymilvus import connections, Collection
+import logging
+
+# Allow importing shared utilities from the repo root.
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from shared.retry import DEGRADED_RESULT, with_retry  # noqa: E402
+
+logger = logging.getLogger(__name__)
 
 # Config
 KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
@@ -111,17 +119,35 @@ class ChatRequest(BaseModel):
     message: str
     stream: Optional[bool] = True
 
 
+# Load the embedding model once at module startup to avoid per-query overhead.
+_encoder: SentenceTransformer | None = None
+
+
+def _get_encoder() -> SentenceTransformer:
+    """Return the module-level encoder singleton, loading it on first call."""
+    global _encoder
+    if _encoder is None:
+        logger.info("[STARTUP] Loading embedding model: %s", EMBEDDING_MODEL)
+        _encoder = SentenceTransformer(EMBEDDING_MODEL)
+    return _encoder
+
+
+@with_retry(
+    max_attempts=3,
+    base_delay=1.0,
+    max_delay=10.0,
+    backoff_factor=2.0,
+    jitter=True,
+    exceptions=(Exception,),
+)
 def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
-    """Execute a semantic search in Milvus and return structured JSON serializable results."""
+    """Execute a semantic search in Milvus; raises on failure to allow retry."""
     try:
-        # Connect to Milvus
         connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
         collection = Collection(MILVUS_COLLECTION)
         collection.load()
-        # Encoder (same model as pipeline)
-        encoder = SentenceTransformer(EMBEDDING_MODEL)
-        query_vec = encoder.encode(query).tolist()
+        query_vec = _get_encoder().encode(query).tolist()
 
         search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
         results = collection.search(
@@ -140,16 +166,15 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
             content_text = entity.get("content_text") or ""
             if isinstance(content_text, str) and len(content_text) > 400:
                 content_text = content_text[:400] + "..."
-            hits.append({
-                "similarity": similarity,
-                "file_path": entity.get("file_path"),
-                "citation_url": entity.get("citation_url"),
-                "content_text": content_text,
-            })
+            hits.append(
+                {
+                    "similarity": similarity,
+                    "file_path": entity.get("file_path"),
+                    "citation_url": entity.get("citation_url"),
+                    "content_text": content_text,
+                }
+            )
         return {"results": hits}
-    except Exception as e:
-        print(f"[ERROR] Milvus search failed: {e}")
-        return {"results": []}
     finally:
         try:
             connections.disconnect(alias="default")
@@ -157,41 +182,49 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
             pass
 
 async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
-    """Execute a tool call and return the result and citations"""
+    """Execute a tool call; offloads Milvus search to a thread and degrades gracefully on failure."""
+    from fastapi.concurrency import run_in_threadpool
+
     try:
         function_name = tool_call.get("function", {}).get("name")
         arguments = json.loads(tool_call.get("function", {}).get("arguments", "{}"))
-
+
         if function_name == "search_kubeflow_docs":
             query = arguments.get("query", "")
             top_k = arguments.get("top_k", 5)
-
-            print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})")
-            result = milvus_search(query, top_k)
-
-            # Collect citations
-            citations = []
+
+            logger.info("[TOOL] Executing Milvus search: '%s' (top_k=%d)", query, top_k)
+            try:
+                result = await run_in_threadpool(milvus_search, query, top_k)
+            except Exception as exc:
+                # milvus_search exhausted all @with_retry attempts.
+                logger.error("[TOOL] milvus_search failed after all retries: %s", exc)
+                return DEGRADED_RESULT, []
+
+            citations: List[str] = []
             formatted_results = []
-
+
             for hit in result.get("results", []):
-                citation_url = hit.get('citation_url', '')
+                citation_url = hit.get("citation_url", "")
                 if citation_url and citation_url not in citations:
                     citations.append(citation_url)
-
+
                 formatted_results.append(
                     f"File: {hit.get('file_path', 'Unknown')}\n"
                     f"Content: {hit.get('content_text', '')}\n"
                     f"URL: {citation_url}\n"
                     f"Similarity: {hit.get('similarity', 0):.3f}\n"
                 )
-
-            formatted_text = "\n".join(formatted_results) if formatted_results else "No relevant results found."
+
+            formatted_text = (
+                "\n".join(formatted_results) if formatted_results else "No relevant results found."
+            )
             return formatted_text, citations
-
+
         return f"Unknown tool: {function_name}", []
-
+
     except Exception as e:
-        print(f"[ERROR] Tool execution failed: {e}")
+        logger.error("[TOOL] Tool execution failed: %s", e)
         return f"Tool execution failed: {e}", []
 
 async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
diff --git a/server/app.py b/server/app.py
index 96b277c..ee517e7 100644
--- a/server/app.py
+++ b/server/app.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import json
 import asyncio
 import httpx
@@ -10,6 +11,14 @@
 from sentence_transformers import SentenceTransformer
 from pymilvus import connections, Collection
 
+import logging
+
+# Allow importing shared utilities from the repo root.
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from shared.retry import DEGRADED_RESULT, with_retry  # noqa: E402
+
+logger = logging.getLogger(__name__)
+
 # Config
 KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions")
 MODEL = os.getenv("MODEL", "llama3.1-8B")
@@ -65,17 +74,35 @@
 
 
+# Load the embedding model once at module startup to avoid per-query overhead.
+_encoder: SentenceTransformer | None = None
+
+
+def _get_encoder() -> SentenceTransformer:
+    """Return the module-level encoder singleton, loading it on first call."""
+    global _encoder
+    if _encoder is None:
+        logger.info("[STARTUP] Loading embedding model: %s", EMBEDDING_MODEL)
+        _encoder = SentenceTransformer(EMBEDDING_MODEL)
+    return _encoder
+
+
+@with_retry(
+    max_attempts=3,
+    base_delay=1.0,
+    max_delay=10.0,
+    backoff_factor=2.0,
+    jitter=True,
+    exceptions=(Exception,),
+)
 def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
-    """Execute a semantic search in Milvus and return structured JSON serializable results."""
+    """Execute a semantic search in Milvus; raises on failure to allow retry."""
     try:
-        # Connect to Milvus
        connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
         collection = Collection(MILVUS_COLLECTION)
         collection.load()
-        # Encoder (same model as pipeline)
-        encoder = SentenceTransformer(EMBEDDING_MODEL)
-        query_vec = encoder.encode(query).tolist()
+        query_vec = _get_encoder().encode(query).tolist()
 
         search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}}
         results = collection.search(
             data=[query_vec],
@@ -94,16 +121,15 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
             content_text = entity.get("content_text") or ""
             if isinstance(content_text, str) and len(content_text) > 400:
                 content_text = content_text[:400] + "..."
-            hits.append({
-                "similarity": similarity,
-                "file_path": entity.get("file_path"),
-                "citation_url": entity.get("citation_url"),
-                "content_text": content_text,
-            })
+            hits.append(
+                {
+                    "similarity": similarity,
+                    "file_path": entity.get("file_path"),
+                    "citation_url": entity.get("citation_url"),
+                    "content_text": content_text,
+                }
+            )
         return {"results": hits}
-    except Exception as e:
-        print(f"[ERROR] Milvus search failed: {e}")
-        return {"results": []}
     finally:
         try:
             connections.disconnect(alias="default")
@@ -144,41 +170,47 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]:
 
 
 async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]:
-    """Execute a tool call and return the result and citations"""
+    """Execute a tool call; offloads Milvus search to a thread and degrades gracefully on failure."""
     try:
         function_name = tool_call.get("function", {}).get("name")
         arguments = json.loads(tool_call.get("function", {}).get("arguments", "{}"))
-
+
         if function_name == "search_kubeflow_docs":
             query = arguments.get("query", "")
             top_k = arguments.get("top_k", 5)
-
-            print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})")
-            result = milvus_search(query, top_k)
-
-            # Collect citations
-            citations = []
+
+            logger.info("[TOOL] Executing Milvus search: '%s' (top_k=%d)", query, top_k)
+            try:
+                result = await asyncio.to_thread(milvus_search, query, top_k)
+            except Exception as exc:
+                # milvus_search exhausted all @with_retry attempts.
+                logger.error("[TOOL] milvus_search failed after all retries: %s", exc)
+                return DEGRADED_RESULT, []
+
+            citations: List[str] = []
             formatted_results = []
-
+
             for hit in result.get("results", []):
-                citation_url = hit.get('citation_url', '')
+                citation_url = hit.get("citation_url", "")
                 if citation_url and citation_url not in citations:
                     citations.append(citation_url)
-
+
                 formatted_results.append(
                     f"File: {hit.get('file_path', 'Unknown')}\n"
                     f"Content: {hit.get('content_text', '')}\n"
                     f"URL: {citation_url}\n"
                     f"Similarity: {hit.get('similarity', 0):.3f}\n"
                 )
-
-            formatted_text = "\n".join(formatted_results) if formatted_results else "No relevant results found."
+
+            formatted_text = (
+                "\n".join(formatted_results) if formatted_results else "No relevant results found."
+            )
             return formatted_text, citations
-
+
         return f"Unknown tool: {function_name}", []
-
+
     except Exception as e:
-        print(f"[ERROR] Tool execution failed: {e}")
+        logger.error("[TOOL] Tool execution failed: %s", e)
         return f"Tool execution failed: {e}", []
 
 async def stream_llm_response(payload: Dict[str, Any], websocket, citations_collector: List[str] = None) -> None:
diff --git a/shared/__init__.py b/shared/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/shared/retry.py b/shared/retry.py
new file mode 100644
index 0000000..562cdf0
--- /dev/null
+++ b/shared/retry.py
@@ -0,0 +1,143 @@
+"""
+Exponential backoff retry with jitter and graceful degradation.
+
+Implements the retry and fault-tolerance requirements from the GSoC 2026
+Agentic RAG spec (Hardened System Considerations, Requirement #5):
+
+    "Robust retry logic is a must for all tools. The agent implements
+    exponential backoff with jitter for Vector DB retrievals and LLM API
+    timeouts. If tools strictly fail, the agent is configured to
+    transparently degrade, informing the user that 'Live code context is
+    currently unreachable.'"
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import random
+import time
+from functools import wraps
+from typing import Callable, Tuple, Type, TypeVar
+
+logger = logging.getLogger(__name__)
+
+F = TypeVar("F", bound=Callable)
+
+# Sentinel returned by execute_tool when all retries are exhausted, so the
+# LLM receives an explicit degradation message instead of an empty result set.
+DEGRADED_RESULT = (
+    "The documentation search service is temporarily unreachable after "
+    "multiple retries. Please try again in a moment. If the problem "
+    "persists, the vector database or embedding service may be offline."
+)
+
+
+def with_retry(
+    max_attempts: int = 3,
+    base_delay: float = 1.0,
+    max_delay: float = 30.0,
+    backoff_factor: float = 2.0,
+    jitter: bool = True,
+    exceptions: Tuple[Type[Exception], ...] = (Exception,),
+) -> Callable[[F], F]:
+    """Decorator that retries a sync or async callable with exponential backoff.
+
+    Args:
+        max_attempts: Total number of attempts (first try + retries).
+        base_delay: Initial sleep duration in seconds before the first retry.
+        max_delay: Upper bound on the computed sleep duration.
+        backoff_factor: Multiplier applied to the delay after each failure.
+        jitter: When True, applies full jitter (delay drawn uniformly from [0, delay]).
+        exceptions: Tuple of exception types that trigger a retry. Other
+            exceptions propagate immediately.
+
+    Usage (sync)::
+
+        @with_retry(max_attempts=3, exceptions=(ConnectionError, TimeoutError))
+        def milvus_search(query: str) -> dict: ...
+
+    Usage (async)::
+
+        @with_retry(max_attempts=3, base_delay=0.5)
+        async def call_kserve(payload: dict) -> dict: ...
+    """
+
+    def _compute_delay(attempt: int) -> float:
+        delay = min(base_delay * (backoff_factor**attempt), max_delay)
+        if jitter:
+            # Full-jitter strategy: uniform in [0, delay] avoids correlation
+            # between retrying clients (see AWS "Exponential Backoff and Jitter").
+            delay = random.uniform(0, delay)
+        return delay
+
+    def decorator(func: F) -> F:
+        if asyncio.iscoroutinefunction(func):
+
+            @wraps(func)
+            async def async_wrapper(*args, **kwargs):
+                last_exc: Exception = RuntimeError("unreachable")
+                for attempt in range(max_attempts):
+                    try:
+                        return await func(*args, **kwargs)
+                    except exceptions as exc:
+                        last_exc = exc
+                        if attempt == max_attempts - 1:
+                            break
+                        delay = _compute_delay(attempt)
+                        logger.warning(
+                            "[RETRY] %s attempt %d/%d failed: %s. "
+                            "Retrying in %.2fs...",
+                            func.__name__,
+                            attempt + 1,
+                            max_attempts,
+                            exc,
+                            delay,
+                        )
+                        await asyncio.sleep(delay)
+
+                logger.error(
+                    "[RETRY] %s exhausted all %d attempts. Last error: %s",
+                    func.__name__,
+                    max_attempts,
+                    last_exc,
+                )
+                raise last_exc
+
+            return async_wrapper  # type: ignore[return-value]
+
+        else:
+
+            @wraps(func)
+            def sync_wrapper(*args, **kwargs):
+                last_exc: Exception = RuntimeError("unreachable")
+                for attempt in range(max_attempts):
+                    try:
+                        return func(*args, **kwargs)
+                    except exceptions as exc:
+                        last_exc = exc
+                        if attempt == max_attempts - 1:
+                            break
+                        delay = _compute_delay(attempt)
+                        logger.warning(
+                            "[RETRY] %s attempt %d/%d failed: %s. "
+                            "Retrying in %.2fs...",
+                            func.__name__,
+                            attempt + 1,
+                            max_attempts,
+                            exc,
+                            delay,
+                        )
+                        time.sleep(delay)
+
+                logger.error(
+                    "[RETRY] %s exhausted all %d attempts. Last error: %s",
+                    func.__name__,
+                    max_attempts,
+                    last_exc,
+                )
+                raise last_exc
+
+            return sync_wrapper  # type: ignore[return-value]
+
+    return decorator  # type: ignore[return-value]