From 8b89d135e938f7e740a1cc26767214a2646edee4 Mon Sep 17 00:00:00 2001 From: Shengzhong Guan Date: Sun, 22 Mar 2026 05:53:15 -0400 Subject: [PATCH] feat(shared): exponential backoff retry and graceful degradation Implements GSoC 2026 Agentic RAG spec Requirement #5: 'Robust retry logic is a must for all tools. The agent implements exponential backoff with jitter for Vector DB retrievals and LLM API timeouts. If tools strictly fail, the agent is configured to transparently degrade, informing the user that Live code context is currently unreachable.' Changes: - shared/retry.py: reusable @with_retry decorator supporting both sync and async callables; uses AWS full-jitter strategy (random.uniform(0, delay)) to prevent thundering-herd on retry; exposes DEGRADED_RESULT sentinel string for LLM-visible outage messages - server/app.py, server-https/app.py: * milvus_search: remove silent exception swallow; add @with_retry (3 attempts, base 1s, max 10s, factor 2x + jitter); encoder loaded once at module level via _get_encoder() singleton * execute_tool: offload blocking milvus_search to asyncio.to_thread (websocket server) and run_in_threadpool (FastAPI server) so the async event loop stays responsive under concurrent load; on retry exhaustion return DEGRADED_RESULT so LLM communicates the outage to the user instead of silently hallucinating from empty context Signed-off-by: Shengzhong Guan Made-with: Cursor Signed-off-by: Shengzhong Guan --- server-https/app.py | 93 ++++++++++++++++++---------- server/app.py | 90 ++++++++++++++++++---------- shared/__init__.py | 0 shared/retry.py | 143 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 266 insertions(+), 60 deletions(-) create mode 100644 shared/__init__.py create mode 100644 shared/retry.py diff --git a/server-https/app.py b/server-https/app.py index 694af8a..ef05b7b 100644 --- a/server-https/app.py +++ b/server-https/app.py @@ -1,4 +1,5 @@ import os +import sys import json import httpx from fastapi import 
FastAPI, HTTPException @@ -9,6 +10,13 @@ from typing import Dict, Any, List, Optional, AsyncGenerator from sentence_transformers import SentenceTransformer from pymilvus import connections, Collection +import logging + +# Allow importing shared utilities from the repo root. +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from shared.retry import DEGRADED_RESULT, with_retry # noqa: E402 + +logger = logging.getLogger(__name__) # Config KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions") @@ -111,17 +119,35 @@ class ChatRequest(BaseModel): message: str stream: Optional[bool] = True +# Load the embedding model once at module startup to avoid per-query overhead. +_encoder: SentenceTransformer | None = None + + +def _get_encoder() -> SentenceTransformer: + """Return the module-level encoder singleton, loading it on first call.""" + global _encoder + if _encoder is None: + logger.info("[STARTUP] Loading embedding model: %s", EMBEDDING_MODEL) + _encoder = SentenceTransformer(EMBEDDING_MODEL) + return _encoder + + +@with_retry( + max_attempts=3, + base_delay=1.0, + max_delay=10.0, + backoff_factor=2.0, + jitter=True, + exceptions=(Exception,), +) def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: - """Execute a semantic search in Milvus and return structured JSON serializable results.""" + """Execute a semantic search in Milvus; raises on failure to allow retry.""" try: - # Connect to Milvus connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT) collection = Collection(MILVUS_COLLECTION) collection.load() - # Encoder (same model as pipeline) - encoder = SentenceTransformer(EMBEDDING_MODEL) - query_vec = encoder.encode(query).tolist() + query_vec = _get_encoder().encode(query).tolist() search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}} results = collection.search( @@ -140,16 +166,15 @@ def milvus_search(query: str, top_k: int = 5) -> 
Dict[str, Any]: content_text = entity.get("content_text") or "" if isinstance(content_text, str) and len(content_text) > 400: content_text = content_text[:400] + "..." - hits.append({ - "similarity": similarity, - "file_path": entity.get("file_path"), - "citation_url": entity.get("citation_url"), - "content_text": content_text, - }) + hits.append( + { + "similarity": similarity, + "file_path": entity.get("file_path"), + "citation_url": entity.get("citation_url"), + "content_text": content_text, + } + ) return {"results": hits} - except Exception as e: - print(f"[ERROR] Milvus search failed: {e}") - return {"results": []} finally: try: connections.disconnect(alias="default") @@ -157,41 +182,49 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: pass async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]: - """Execute a tool call and return the result and citations""" + """Execute a tool call; offloads Milvus search to a thread and degrades gracefully on failure.""" + from fastapi.concurrency import run_in_threadpool + try: function_name = tool_call.get("function", {}).get("name") arguments = json.loads(tool_call.get("function", {}).get("arguments", "{}")) - + if function_name == "search_kubeflow_docs": query = arguments.get("query", "") top_k = arguments.get("top_k", 5) - - print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})") - result = milvus_search(query, top_k) - - # Collect citations - citations = [] + + logger.info("[TOOL] Executing Milvus search: '%s' (top_k=%d)", query, top_k) + try: + result = await run_in_threadpool(milvus_search, query, top_k) + except Exception as exc: + # milvus_search exhausted all @with_retry attempts. 
+ logger.error("[TOOL] milvus_search failed after all retries: %s", exc) + return DEGRADED_RESULT, [] + + citations: List[str] = [] formatted_results = [] - + for hit in result.get("results", []): - citation_url = hit.get('citation_url', '') + citation_url = hit.get("citation_url", "") if citation_url and citation_url not in citations: citations.append(citation_url) - + formatted_results.append( f"File: {hit.get('file_path', 'Unknown')}\n" f"Content: {hit.get('content_text', '')}\n" f"URL: {citation_url}\n" f"Similarity: {hit.get('similarity', 0):.3f}\n" ) - - formatted_text = "\n".join(formatted_results) if formatted_results else "No relevant results found." + + formatted_text = ( + "\n".join(formatted_results) if formatted_results else "No relevant results found." + ) return formatted_text, citations - + return f"Unknown tool: {function_name}", [] - + except Exception as e: - print(f"[ERROR] Tool execution failed: {e}") + logger.error("[TOOL] Tool execution failed: %s", e) return f"Tool execution failed: {e}", [] async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, None]: diff --git a/server/app.py b/server/app.py index 96b277c..ee517e7 100644 --- a/server/app.py +++ b/server/app.py @@ -1,4 +1,5 @@ import os +import sys import json import asyncio import httpx @@ -10,6 +11,12 @@ from sentence_transformers import SentenceTransformer from pymilvus import connections, Collection +# Allow importing shared utilities from the repo root. +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from shared.retry import DEGRADED_RESULT, with_retry # noqa: E402 + +logger = logging.getLogger(__name__) + # Config KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions") MODEL = os.getenv("MODEL", "llama3.1-8B") @@ -65,17 +72,35 @@ +# Load the embedding model once at module startup to avoid per-query overhead. 
+_encoder: SentenceTransformer | None = None + + +def _get_encoder() -> SentenceTransformer: + """Return the module-level encoder singleton, loading it on first call.""" + global _encoder + if _encoder is None: + logger.info("[STARTUP] Loading embedding model: %s", EMBEDDING_MODEL) + _encoder = SentenceTransformer(EMBEDDING_MODEL) + return _encoder + + +@with_retry( + max_attempts=3, + base_delay=1.0, + max_delay=10.0, + backoff_factor=2.0, + jitter=True, + exceptions=(Exception,), +) def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: - """Execute a semantic search in Milvus and return structured JSON serializable results.""" + """Execute a semantic search in Milvus; raises on failure to allow retry.""" try: - # Connect to Milvus connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT) collection = Collection(MILVUS_COLLECTION) collection.load() - # Encoder (same model as pipeline) - encoder = SentenceTransformer(EMBEDDING_MODEL) - query_vec = encoder.encode(query).tolist() + query_vec = _get_encoder().encode(query).tolist() search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}} results = collection.search( @@ -94,16 +119,15 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: content_text = entity.get("content_text") or "" if isinstance(content_text, str) and len(content_text) > 400: content_text = content_text[:400] + "..." 
- hits.append({ - "similarity": similarity, - "file_path": entity.get("file_path"), - "citation_url": entity.get("citation_url"), - "content_text": content_text, - }) + hits.append( + { + "similarity": similarity, + "file_path": entity.get("file_path"), + "citation_url": entity.get("citation_url"), + "content_text": content_text, + } + ) return {"results": hits} - except Exception as e: - print(f"[ERROR] Milvus search failed: {e}") - return {"results": []} finally: try: connections.disconnect(alias="default") @@ -144,41 +168,47 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]: - """Execute a tool call and return the result and citations""" + """Execute a tool call; offloads Milvus search to a thread and degrades gracefully on failure.""" try: function_name = tool_call.get("function", {}).get("name") arguments = json.loads(tool_call.get("function", {}).get("arguments", "{}")) - + if function_name == "search_kubeflow_docs": query = arguments.get("query", "") top_k = arguments.get("top_k", 5) - - print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})") - result = milvus_search(query, top_k) - - # Collect citations - citations = [] + + logger.info("[TOOL] Executing Milvus search: '%s' (top_k=%d)", query, top_k) + try: + result = await asyncio.to_thread(milvus_search, query, top_k) + except Exception as exc: + # milvus_search exhausted all @with_retry attempts. 
+ logger.error("[TOOL] milvus_search failed after all retries: %s", exc) + return DEGRADED_RESULT, [] + + citations: List[str] = [] formatted_results = [] - + for hit in result.get("results", []): - citation_url = hit.get('citation_url', '') + citation_url = hit.get("citation_url", "") if citation_url and citation_url not in citations: citations.append(citation_url) - + formatted_results.append( f"File: {hit.get('file_path', 'Unknown')}\n" f"Content: {hit.get('content_text', '')}\n" f"URL: {citation_url}\n" f"Similarity: {hit.get('similarity', 0):.3f}\n" ) - - formatted_text = "\n".join(formatted_results) if formatted_results else "No relevant results found." + + formatted_text = ( + "\n".join(formatted_results) if formatted_results else "No relevant results found." + ) return formatted_text, citations - + return f"Unknown tool: {function_name}", [] - + except Exception as e: - print(f"[ERROR] Tool execution failed: {e}") + logger.error("[TOOL] Tool execution failed: %s", e) return f"Tool execution failed: {e}", [] async def stream_llm_response(payload: Dict[str, Any], websocket, citations_collector: List[str] = None) -> None: diff --git a/shared/__init__.py b/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shared/retry.py b/shared/retry.py new file mode 100644 index 0000000..562cdf0 --- /dev/null +++ b/shared/retry.py @@ -0,0 +1,143 @@ +""" +Exponential backoff retry with jitter and graceful degradation. + +Implements the retry and fault-tolerance requirements from the GSoC 2026 +Agentic RAG spec (Hardened System Considerations, Requirement #5): + + "Robust retry logic is a must for all tools. The agent implements + exponential backoff with jitter for Vector DB retrievals and LLM API + timeouts. 
If tools strictly fail, the agent is configured to + transparently degrade, informing the user that 'Live code context is + currently unreachable.'" +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import time +from functools import wraps +from typing import Callable, Tuple, Type, TypeVar + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable) + +# Sentinel returned by execute_tool when all retries are exhausted, so the +# LLM receives an explicit degradation message instead of an empty result set. +DEGRADED_RESULT = ( + "The documentation search service is temporarily unreachable after " + "multiple retries. Please try again in a moment. If the problem " + "persists, the vector database or embedding service may be offline." +) + + +def with_retry( + max_attempts: int = 3, + base_delay: float = 1.0, + max_delay: float = 30.0, + backoff_factor: float = 2.0, + jitter: bool = True, + exceptions: Tuple[Type[Exception], ...] = (Exception,), +) -> Callable[[F], F]: + """Decorator that retries a sync or async callable with exponential backoff. + + Args: + max_attempts: Total number of attempts (first try + retries). + base_delay: Initial sleep duration in seconds before the first retry. + max_delay: Upper bound on the computed sleep duration. + backoff_factor: Multiplier applied to the delay after each failure. + jitter: When True, applies full jitter — the delay is drawn uniformly + from [0, delay] — to prevent thundering herd. + exceptions: Tuple of exception types that trigger a retry. Other + exceptions propagate immediately. + + Usage (sync):: + + @with_retry(max_attempts=3, exceptions=(ConnectionError, TimeoutError)) + def milvus_search(query: str) -> dict: ... + + Usage (async):: + + @with_retry(max_attempts=3, base_delay=0.5) + async def call_kserve(payload: dict) -> dict: ... 
+ """ + + def _compute_delay(attempt: int) -> float: + delay = min(base_delay * (backoff_factor**attempt), max_delay) + if jitter: + # Full-jitter strategy: uniform in [0, delay] avoids correlation + # between retrying clients (see AWS "Exponential Backoff and Jitter"). + delay = random.uniform(0, delay) + return delay + + def decorator(func: F) -> F: + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + last_exc: Exception = RuntimeError("unreachable") + for attempt in range(max_attempts): + try: + return await func(*args, **kwargs) + except exceptions as exc: + last_exc = exc + if attempt == max_attempts - 1: + break + delay = _compute_delay(attempt) + logger.warning( + "[RETRY] %s attempt %d/%d failed: %s. " + "Retrying in %.2fs...", + func.__name__, + attempt + 1, + max_attempts, + exc, + delay, + ) + await asyncio.sleep(delay) + + logger.error( + "[RETRY] %s exhausted all %d attempts. Last error: %s", + func.__name__, + max_attempts, + last_exc, + ) + raise last_exc + + return async_wrapper # type: ignore[return-value] + + else: + + @wraps(func) + def sync_wrapper(*args, **kwargs): + last_exc: Exception = RuntimeError("unreachable") + for attempt in range(max_attempts): + try: + return func(*args, **kwargs) + except exceptions as exc: + last_exc = exc + if attempt == max_attempts - 1: + break + delay = _compute_delay(attempt) + logger.warning( + "[RETRY] %s attempt %d/%d failed: %s. " + "Retrying in %.2fs...", + func.__name__, + attempt + 1, + max_attempts, + exc, + delay, + ) + time.sleep(delay) + + logger.error( + "[RETRY] %s exhausted all %d attempts. Last error: %s", + func.__name__, + max_attempts, + last_exc, + ) + raise last_exc + + return sync_wrapper # type: ignore[return-value] + + return decorator # type: ignore[return-value]