diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..c9ebf2d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:system" +} \ No newline at end of file diff --git a/core_agent/__init__.py b/core_agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core_agent/graph.py b/core_agent/graph.py new file mode 100644 index 0000000..dd31dc4 --- /dev/null +++ b/core_agent/graph.py @@ -0,0 +1,191 @@ +import logging +import random +import time +from typing import TypedDict, Literal, Dict, Any, List, Optional +from langgraph.graph import StateGraph, START, END + +# --- Constants --- +MAX_RETRIES = 3 + +# --- Logging Setup --- +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("CoreAgent") + +# --- State Schema --- +class AgentState(TypedDict): + user_query: str + intent: str + context: str + retrieved_context: List[str] + citations: List[str] + response: str + generation: str + error_count: int + +# --- Helper functions for "external" calls (to be mocked in tests) --- + +def call_llm_classify(query: str) -> str: + """Mockable LLM call for classification.""" + if any(word in query.lower() for word in ["error", "log", "debug", "fail"]): + return 'Debugging / Error Log' + elif any(word in query.lower() for word in ["pipeline", "config", "yaml", "setup"]): + return 'Pipeline Configuration' + elif any(word in query.lower() for word in ["what", "how", "define", "concept"]): + return 'Conceptual Definition' + elif any(word in query.lower() for word in ["api", "endpoint", "reference", "parameter"]): + return 'API Reference' + return 'General Conversation' + +def call_milvus_search(query: str, partition: str) -> Dict[str, Any]: + """Mockable Milvus search call.""" + return { + "content": f"Retrieved {partition} content for '{query}'.", + "citations": [f"https://docs.kubeflow.org/{partition}/test"] + } + +def call_llm_generate(intent: str, query: str, context: str) -> str: + """Mockable LLM call for generation.""" + if random.random() < 0.2: + return "" + return f"This is a generated response for the query: '{query}' based on {intent}." + +# --- Nodes --- + +def classify_intent(state: AgentState) -> Dict[str, Any]: + start_time = time.time() + intent = call_llm_classify(state["user_query"]) + latency = (time.time() - start_time) * 1000 + logger.info(f"[Node: classify_intent] Intent: {intent} | Latency: {latency:.2f}ms") + return {"intent": intent} + +def retrieve_docs(state: AgentState) -> Dict[str, Any]: + start_time = time.time() + logger.info(f"[Node: retrieve_docs] Partition: Documentation | Query: {state['user_query']}") + res = call_milvus_search(state["user_query"], "docs") + latency = (time.time() - start_time) * 1000 + logger.info(f"[Node: retrieve_docs] Completed | Latency: {latency:.2f}ms") + return { + "context": res["content"], + "retrieved_context": [res["content"]], + "citations": res["citations"] + } + +def retrieve_github_issues(state: AgentState) -> Dict[str, Any]: + start_time = time.time() + logger.info(f"[Node: retrieve_github_issues] Partition: GitHub Issues | Query: {state['user_query']}") + res = call_milvus_search(state["user_query"], "github") + latency = (time.time() - start_time) * 1000 + logger.info(f"[Node: retrieve_github_issues] Completed | Latency: {latency:.2f}ms") + return { + "context": res["content"], + "retrieved_context": [res["content"]], + "citations": res["citations"] + } + +def retrieve_architecture(state: AgentState) -> Dict[str, Any]: + start_time = time.time() + logger.info(f"[Node: retrieve_architecture] Partition: Architecture/Pipelines | Query: {state['user_query']}") + res = call_milvus_search(state["user_query"], "architecture") + latency = (time.time() - start_time) * 1000 + logger.info(f"[Node: retrieve_architecture] Completed | Latency: {latency:.2f}ms") + return { + "context": res["content"], + "retrieved_context": [res["content"]], + "citations": res["citations"] + } + +def generate_response(state: AgentState) -> Dict[str, Any]: + start_time = time.time() + logger.info(f"[Node: generate_response] Intent: {state['intent']} | Error Count: {state['error_count']}") + + gen = call_llm_generate(state["intent"], state["user_query"], state["context"]) + error_inc = 1 if not gen else 0 + + if not gen: + logger.warning("[Node: generate_response] Failure: Empty response generated.") + else: + logger.info("[Node: generate_response] Success: Response generated.") + + latency = (time.time() - start_time) * 1000 + logger.info(f"[Node: generate_response] Completed | Latency: {latency:.2f}ms") + + return { + "response": gen, + "generation": gen, + "error_count": state["error_count"] + error_inc + } + +# --- Routing Logic --- + +def route_after_classification(state: AgentState) -> str: + intent = state["intent"] + if intent == 'Debugging / Error Log': + return "retrieve_github_issues" + elif intent == 'Pipeline Configuration': + return "retrieve_architecture" + elif intent in ['Conceptual Definition', 'API Reference']: + return "retrieve_docs" + else: + return "generate_response" + +def route_after_generation(state: AgentState) -> str: + # If the response is empty (failure) and error_count < MAX_RETRIES, retry + if not state.get("generation") and state.get("error_count", 0) < MAX_RETRIES: + logger.info(f"[Routing] Response flagged as failure. Error count: {state['error_count']}. Retrying...") + return route_after_classification(state) # Directly return the next node + + if not state.get("generation"): + logger.error(f"[Routing] Max retries reached ({MAX_RETRIES}). Bailing out.") + return END + +# --- Graph Construction --- + +def create_graph(): + workflow = StateGraph(AgentState) + + # Add Nodes + workflow.add_node("classify_intent", classify_intent) + workflow.add_node("retrieve_docs", retrieve_docs) + workflow.add_node("retrieve_github_issues", retrieve_github_issues) + workflow.add_node("retrieve_architecture", retrieve_architecture) + workflow.add_node("generate_response", generate_response) + + # Set Entry Point + workflow.set_entry_point("classify_intent") + + # Conditional Edges after classify_intent + workflow.add_conditional_edges( + "classify_intent", + route_after_classification, + { + "retrieve_github_issues": "retrieve_github_issues", + "retrieve_architecture": "retrieve_architecture", + "retrieve_docs": "retrieve_docs", + "generate_response": "generate_response" + } + ) + + # Edges from retrieval nodes to generation + workflow.add_edge("retrieve_docs", "generate_response") + workflow.add_edge("retrieve_github_issues", "generate_response") + workflow.add_edge("retrieve_architecture", "generate_response") + + # Conditional Edges after generate_response (Cyclic Error Correction) + workflow.add_conditional_edges( + "generate_response", + route_after_generation, + { + "retrieve_docs": "retrieve_docs", + "retrieve_github_issues": "retrieve_github_issues", + "retrieve_architecture": "retrieve_architecture", + "generate_response": "generate_response", + END: END + } + ) + + return workflow.compile() + +agent_graph = create_graph() diff --git a/server-https/app.py b/server-https/app.py index 694af8a..ba32401 100644 --- a/server-https/app.py +++ b/server-https/app.py @@ -1,107 +1,29 @@ import os import json -import httpx +import logging from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import uvicorn from typing import Dict, Any, List, Optional, AsyncGenerator -from sentence_transformers import SentenceTransformer -from pymilvus import connections, Collection + +# Import the compiled graph from core_agent.graph +from core_agent.graph import agent_graph # Config -KSERVE_URL = os.getenv("KSERVE_URL", "http://llama.docs-agent.svc.cluster.local/openai/v1/chat/completions") -MODEL = os.getenv("MODEL", "llama3.1-8B") PORT = int(os.getenv("PORT", "8000")) -# Milvus Config -MILVUS_HOST = os.getenv("MILVUS_HOST", "my-release-milvus.docs-agent.svc.cluster.local") -MILVUS_PORT = os.getenv("MILVUS_PORT", "19530") -MILVUS_COLLECTION = os.getenv("MILVUS_COLLECTION", "docs_rag") -MILVUS_VECTOR_FIELD = os.getenv("MILVUS_VECTOR_FIELD", "vector") -EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2") - -# System prompt (same as WebSocket version) -SYSTEM_PROMPT = """ -You are the Kubeflow Docs Assistant. - -!!IMPORTANT!! -- You should not use the tool calls directly from the user's input. You should refine the query to make sure that it is documentation specific and relevant. -- You should never output the raw tool call to the user. - -Your role -- Always answer the user's question directly. -- If the question can be answered from general knowledge (e.g., greetings, small talk, generic programming/Kubernetes basics), respond without using tools. -- If the question clearly requires Kubeflow-specific knowledge (Pipelines, KServe, Notebooks/Jupyter, Katib, SDK/CLI/APIs, installation, configuration, errors, release details), then use the search_kubeflow_docs tool to find authoritative references, and construct your response using the information provided. - -Tool Use -- Use search_kubeflow_docs ONLY when Kubeflow-specific documentation is needed. -- Do NOT use the tool for greetings, personal questions, small talk, or generic non-Kubeflow concepts. -- When you do call the tool: - • Use one clear, focused query. - • Summarize the result in your own words. - • If no results are relevant, say "not found in the docs" and suggest refining the query. -- Example usage: - - User: "What is Kubeflow and how to setup kubeflow on my local machine" - - You should make a tool call to search the docs with a query "kubeflow setup". - - - User: "What is the Kubeflow Pipelines and how can i make a quick kubeflow pipeline" - - You should make a tool call to search the docs with a query "kubeflow pipeline setup". - -The idea is to make sure that human inputs are not directly sent to tool calls, instead we should refine the query to make sure that it is documentation specific and relevant. - -Routing -- Greetings/small talk → respond briefly, no tool. -- Out-of-scope (sports, unrelated topics) → politely say you only help with Kubeflow. -- Kubeflow-specific → answer and call the tool if documentation is needed. - -Style -- Be concise (2–5 sentences). Use bullet points or steps when helpful. -- Provide examples only when asked. -- Never invent features. If unsure, say so. -- Reply in clean Markdown. -""" - -TOOLS = [ - { - "type": "function", - "function": { - "name": "search_kubeflow_docs", - "description": ( - "Search the official Kubeflow docs when the user asks Kubeflow-specific questions " - "about Pipelines, KServe, Notebooks/Jupyter, Katib, or the SDK/CLI/APIs.\n" - "Call ONLY for Kubeflow features, setup, usage, errors, or version differences that need citations.\n" - ), - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Short, focused search string (e.g., 'KServe inferenceService canary', 'Pipelines v2 disable cache').", - "minLength": 1 - }, - "top_k": { - "type": "integer", - "description": "Number of hits to retrieve (the assistant will read up to this many).", - "default": 5, - "minimum": 1, - "maximum": 10 - } - }, - "required": ["query"], - "additionalProperties": False - } - } - } -] +# Logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("App") -app = FastAPI(title="Kubeflow Docs API Service", version="1.0.0") +app = FastAPI(title="Kubeflow Docs API Service (LangGraph)", version="1.1.0") # Add CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # In production, specify your actual domains + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -109,349 +31,51 @@ class ChatRequest(BaseModel): message: str - stream: Optional[bool] = True - -def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: - """Execute a semantic search in Milvus and return structured JSON serializable results.""" - try: - # Connect to Milvus - connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT) - collection = Collection(MILVUS_COLLECTION) - collection.load() - - # Encoder (same model as pipeline) - encoder = SentenceTransformer(EMBEDDING_MODEL) - query_vec = encoder.encode(query).tolist() - - search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}} - results = collection.search( - data=[query_vec], - anns_field=MILVUS_VECTOR_FIELD, - param=search_params, - limit=int(top_k), - output_fields=["file_path", "content_text", "citation_url"], - ) - - hits = [] - for hit in results[0]: - # similarity = 1 - distance for COSINE in Milvus - similarity = 1.0 - float(hit.distance) - entity = hit.entity - content_text = entity.get("content_text") or "" - if isinstance(content_text, str) and len(content_text) > 400: - content_text = content_text[:400] + "..." - hits.append({ - "similarity": similarity, - "file_path": entity.get("file_path"), - "citation_url": entity.get("citation_url"), - "content_text": content_text, - }) - return {"results": hits} - except Exception as e: - print(f"[ERROR] Milvus search failed: {e}") - return {"results": []} - finally: - try: - connections.disconnect(alias="default") - except Exception: - pass - -async def execute_tool(tool_call: Dict[str, Any]) -> tuple[str, List[str]]: - """Execute a tool call and return the result and citations""" - try: - function_name = tool_call.get("function", {}).get("name") - arguments = json.loads(tool_call.get("function", {}).get("arguments", "{}")) - - if function_name == "search_kubeflow_docs": - query = arguments.get("query", "") - top_k = arguments.get("top_k", 5) - - print(f"[TOOL] Executing Milvus search for: '{query}' (top_k={top_k})") - result = milvus_search(query, top_k) - - # Collect citations - citations = [] - formatted_results = [] - - for hit in result.get("results", []): - citation_url = hit.get('citation_url', '') - if citation_url and citation_url not in citations: - citations.append(citation_url) - - formatted_results.append( - f"File: {hit.get('file_path', 'Unknown')}\n" - f"Content: {hit.get('content_text', '')}\n" - f"URL: {citation_url}\n" - f"Similarity: {hit.get('similarity', 0):.3f}\n" - ) - - formatted_text = "\n".join(formatted_results) if formatted_results else "No relevant results found." - return formatted_text, citations - - return f"Unknown tool: {function_name}", [] - - except Exception as e: - print(f"[ERROR] Tool execution failed: {e}") - return f"Tool execution failed: {e}", [] - -async def stream_llm_response(payload: Dict[str, Any]) -> AsyncGenerator[str, None]: - """Stream response from LLM and handle tool calls, yielding SSE events""" - citations_collector = [] - - try: - async with httpx.AsyncClient(timeout=120) as client: - async with client.stream("POST", KSERVE_URL, json=payload) as response: - if response.status_code != 200: - error_msg = f"LLM service error: HTTP {response.status_code}" - print(f"[ERROR] {error_msg}") - yield f"data: {json.dumps({'type': 'error', 'content': error_msg})}\n\n" - return - - # Buffer for accumulating tool calls - tool_calls_buffer = {} - - async for line in response.aiter_lines(): - if not line.startswith("data: "): - continue - - data = line[6:] # Remove "data: " prefix - if data == "[DONE]": - break - - try: - chunk = json.loads(data) - choices = chunk.get("choices", []) - if not choices: - continue - - delta = choices[0].get("delta", {}) - finish_reason = choices[0].get("finish_reason") - - # Handle tool calls in streaming - if "tool_calls" in delta: - tool_calls = delta["tool_calls"] - for tool_call in tool_calls: - index = tool_call.get("index", 0) - - # Initialize tool call buffer if needed - if index not in tool_calls_buffer: - tool_calls_buffer[index] = { - "id": tool_call.get("id", ""), - "type": tool_call.get("type", "function"), - "function": { - "name": tool_call.get("function", {}).get("name", ""), - "arguments": "" - } - } - - # Update tool call data - if tool_call.get("id"): - tool_calls_buffer[index]["id"] = tool_call["id"] - if tool_call.get("type"): - tool_calls_buffer[index]["type"] = tool_call["type"] - - function_data = tool_call.get("function", {}) - if function_data.get("name"): - tool_calls_buffer[index]["function"]["name"] = function_data["name"] - if "arguments" in function_data: - tool_calls_buffer[index]["function"]["arguments"] += function_data["arguments"] - - # Handle regular content - elif "content" in delta and delta["content"]: - yield f"data: {json.dumps({'type': 'content', 'content': delta['content']})}\n\n" - - # Handle finish reason - execute tools if needed - if finish_reason == "tool_calls": - print(f"[TOOL] Finish reason: tool_calls, executing {len(tool_calls_buffer)} tools") - - # Execute all accumulated tool calls - for tool_call in tool_calls_buffer.values(): - if tool_call["function"]["name"] and tool_call["function"]["arguments"]: - try: - print(f"[TOOL] Executing: {tool_call['function']['name']}") - print(f"[TOOL] Arguments: {tool_call['function']['arguments']}") - - result, tool_citations = await execute_tool(tool_call) - - # Collect citations - citations_collector.extend(tool_citations) - - # Send tool execution result - yield f"data: {json.dumps({'type': 'tool_result', 'tool_name': tool_call['function']['name'], 'content': result})}\n\n" - - # Make follow-up request with tool results - async for follow_up_chunk in handle_tool_follow_up(payload, tool_call, result, citations_collector): - yield follow_up_chunk - - except Exception as e: - print(f"[ERROR] Tool execution error: {e}") - yield f"data: {json.dumps({'type': 'error', 'content': f'Tool execution failed: {e}'})}\n\n" - - tool_calls_buffer.clear() - break # Tool execution complete, exit streaming loop - - except json.JSONDecodeError as e: - print(f"[ERROR] JSON decode error: {e}, line: {line}") - continue - - # Send citations if any were collected - if citations_collector: - # Remove duplicates while preserving order - unique_citations = [] - for citation in citations_collector: - if citation not in unique_citations: - unique_citations.append(citation) - - yield f"data: {json.dumps({'type': 'citations', 'citations': unique_citations})}\n\n" - - # Send completion signal - yield f"data: {json.dumps({'type': 'done'})}\n\n" - - except Exception as e: - print(f"[ERROR] Streaming failed: {e}") - yield f"data: {json.dumps({'type': 'error', 'content': f'Streaming failed: {e}'})}\n\n" - -async def handle_tool_follow_up(original_payload: Dict[str, Any], tool_call: Dict[str, Any], tool_result: str, citations_collector: List[str]) -> AsyncGenerator[str, None]: - """Handle follow-up request after tool execution""" - try: - print("[TOOL] Handling follow-up request with tool results") - - # Create messages with tool call and result - messages = original_payload["messages"].copy() - - # Add assistant's tool call message - messages.append({ - "role": "assistant", - "tool_calls": [tool_call] - }) - - # Add tool result message - messages.append({ - "role": "tool", - "tool_call_id": tool_call["id"], - "content": tool_result - }) - - # Create follow-up payload - remove tools to get final response - follow_up_payload = { - "model": original_payload["model"], - "messages": messages, - "stream": True, - "max_tokens": 1000 - } - - # Stream the follow-up response - async for chunk in stream_llm_response(follow_up_payload): - yield chunk - - except Exception as e: - print(f"[ERROR] Tool follow-up failed: {e}") - yield f"data: {json.dumps({'type': 'error', 'content': f'Tool follow-up failed: {e}'})}\n\n" - -async def get_non_streaming_response(payload: Dict[str, Any]) -> tuple[str, List[str]]: - """Get non-streaming response by collecting all streaming chunks""" - response_content = "" - citations = [] - - async for chunk in stream_llm_response(payload): - if chunk.startswith("data: "): - try: - data = json.loads(chunk[6:].strip()) - if data.get("type") == "content": - response_content += data.get("content", "") - elif data.get("type") == "citations": - citations.extend(data.get("citations", [])) - elif data.get("type") == "error": - raise HTTPException(status_code=500, detail=data.get("content", "Unknown error")) - except json.JSONDecodeError: - continue - - return response_content, citations + stream: Optional[bool] = False # LangGraph streaming is different, defaulting to False for simplicity here @app.get("/") async def hello(): - """Simple hello endpoint""" - return {"message": "Hello from Kubeflow Docs API!", "service": "https-api"} + return {"message": "Hello from Kubeflow Docs API with LangGraph!", "service": "https-api"} @app.get("/health") async def health_check(): - """Health check endpoint for Kubernetes probes""" return {"status": "healthy", "service": "https-api"} -@app.options("/chat") -async def options_chat(): - """Handle preflight OPTIONS request""" - return {"message": "OK"} - -@app.options("/") -async def options_root(): - """Handle preflight OPTIONS request for root""" - return {"message": "OK"} - -@app.options("/health") -async def options_health(): - """Handle preflight OPTIONS request for health""" - return {"message": "OK"} - @app.post("/chat") async def chat(request: ChatRequest): - """Chat endpoint with RAG capabilities - supports both streaming and non-streaming""" + """ + Invokes the LangGraph state machine. + """ try: - print(f"[CHAT] Processing message: {request.message[:100]}...") + logger.info(f"Processing message: {request.message[:100]}") - # Create initial payload - payload = { - "model": MODEL, - "messages": [ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": request.message} - ], - "tools": TOOLS, - "tool_choice": "auto", - "stream": True, - "max_tokens": 1500 + # Initial state for the graph + initial_state = { + "user_query": request.message, + "intent": "", + "retrieved_context": [], + "citations": [], + "generation": "", + "error_count": 0 } - if request.stream: - # Return streaming response using Server-Sent Events - return StreamingResponse( - stream_llm_response(payload), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Headers": "Cache-Control" - } - ) - else: - # Return non-streaming JSON response - response_content, citations = await get_non_streaming_response(payload) - - # Remove duplicates from citations while preserving order - unique_citations = [] - for citation in citations: - if citation not in unique_citations: - unique_citations.append(citation) - - return { - "response": response_content, - "citations": unique_citations if unique_citations else None + # Invoke the graph + # For simplicity in this boilerplate, we're using .ainvoke() + final_state = await agent_graph.ainvoke(initial_state) + + return { + "response": final_state.get("generation"), + "citations": final_state.get("citations"), + "metadata": { + "intent": final_state.get("intent"), + "error_count": final_state.get("error_count") } + } except Exception as e: - print(f"[ERROR] Chat handling failed: {e}") + logger.error(f"Chat handling failed: {e}") raise HTTPException(status_code=500, detail=f"Request failed: {e}") if __name__ == "__main__": - print("🚀 Starting Kubeflow Docs HTTP API Server") - print(f" Port: {PORT}") - print(f" LLM Service: {KSERVE_URL}") - print(f" Milvus: {MILVUS_HOST}:{MILVUS_PORT}") - print(f" Collection: {MILVUS_COLLECTION}") - - uvicorn.run( - app, - host="0.0.0.0", - port=PORT - ) + logger.info(f"🚀 Starting Kubeflow Docs HTTP API Server on port {PORT}") + uvicorn.run(app, host="0.0.0.0", port=PORT) diff --git a/server/app.py b/server/app.py index 96b277c..d1c0346 100644 --- a/server/app.py +++ b/server/app.py @@ -7,6 +7,7 @@ from websockets.exceptions import ConnectionClosedError import logging from typing import Dict, Any, List +from core_agent.graph import agent_graph # Import shared core agent graph from sentence_transformers import SentenceTransformer from pymilvus import connections, Collection @@ -63,6 +64,15 @@ - Reply in clean Markdown. """ +# --- NEW: Global Model Initialization --- +print(f"Loading embedding model '{EMBEDDING_MODEL}' into memory...") +try: + GLOBAL_ENCODER = SentenceTransformer(EMBEDDING_MODEL) + print("Embedding model loaded successfully.") +except Exception as e: + print(f"[FATAL] Failed to load embedding model: {e}") + GLOBAL_ENCODER = None +# ---------------------------------------------------- def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: @@ -73,9 +83,12 @@ def milvus_search(query: str, top_k: int = 5) -> Dict[str, Any]: collection = Collection(MILVUS_COLLECTION) collection.load() - # Encoder (same model as pipeline) - encoder = SentenceTransformer(EMBEDDING_MODEL) - query_vec = encoder.encode(query).tolist() + # --- NEW: Use global encoder instead of re-loading --- + if GLOBAL_ENCODER is None: + raise RuntimeError("Embedding model is not initialized.") + + query_vec = GLOBAL_ENCODER.encode(query).tolist() + # ---------------------------------------------------- search_params = {"metric_type": "COSINE", "params": {"nprobe": 32}} results = collection.search( diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 0000000..0d4ed16 --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,139 @@ +import pytest +from unittest.mock import patch, MagicMock +from core_agent.graph import agent_graph, MAX_RETRIES + +# --- Fixtures --- + +@pytest.fixture +def initial_state(): + """Returns a clean initial state for the agent.""" + return { + "user_query": "Test query", + "intent": "", + "context": "", + "retrieved_context": [], + "citations": [], + "response": "", + "generation": "", + "error_count": 0 + } + +# --- Parameterized Routing Tests --- + +@pytest.mark.parametrize("mock_intent, expected_retrieval_node", [ + ('Conceptual Definition', 'retrieve_docs'), + ('API Reference', 'retrieve_docs'), + ('Debugging / Error Log', 'retrieve_github_issues'), + ('Pipeline Configuration', 'retrieve_architecture'), +]) +def test_routing_to_retrieval(initial_state, mock_intent, expected_retrieval_node): + """ + Tests that the graph correctly routes from classify_intent to the + appropriate retrieval node based on the intent. + """ + with patch('core_agent.graph.call_llm_classify', return_value=mock_intent), \ + patch('core_agent.graph.call_milvus_search') as mock_milvus, \ + patch('core_agent.graph.call_llm_generate') as mock_gen: + + # Mock Milvus to return dummy data + mock_milvus.return_value = { + "content": "Dummy content", + "citations": ["http://test.com"] + } + # Mock LLM generation to return a successful response + mock_gen.return_value = "Successful response" + + # Execute the graph + final_state = agent_graph.invoke(initial_state) + + # Assertions + assert final_state['intent'] == mock_intent + assert final_state['generation'] == "Successful response" + # Verify the correct retrieval node was called (via its side effect on the state) + if expected_retrieval_node == 'retrieve_docs': + assert "docs" in str(mock_milvus.call_args_list) + elif expected_retrieval_node == 'retrieve_github_issues': + assert "github" in str(mock_milvus.call_args_list) + elif expected_retrieval_node == 'retrieve_architecture': + assert "architecture" in str(mock_milvus.call_args_list) + +def test_routing_general_conversation(initial_state): + """ + Tests that 'General Conversation' intent bypasses retrieval nodes + and goes directly to generate_response. + """ + mock_intent = 'General Conversation' + with patch('core_agent.graph.call_llm_classify', return_value=mock_intent), \ + patch('core_agent.graph.call_milvus_search') as mock_milvus, \ + patch('core_agent.graph.call_llm_generate', return_value="General response"): + + final_state = agent_graph.invoke(initial_state) + + assert final_state['intent'] == mock_intent + assert final_state['generation'] == "General response" + # Milvus search should NOT be called for general conversation + mock_milvus.assert_not_called() + +# --- Cyclic Error Correction Tests --- + +def test_cyclic_error_correction_loop(initial_state): + """ + Tests the failure recovery loop. + Simulates a single failure in generation, asserts retry, and then a success. + """ + mock_intent = 'API Reference' + + # We want call_llm_generate to return "" first (failure), then "Success" + side_effects = ["", "Success after retry"] + + with patch('core_agent.graph.call_llm_classify', return_value=mock_intent), \ + patch('core_agent.graph.call_milvus_search', return_value={"content": "data", "citations": []}), \ + patch('core_agent.graph.call_llm_generate', side_effect=side_effects) as mock_gen: + + final_state = agent_graph.invoke(initial_state) + + # Verify it called generation twice + assert mock_gen.call_count == 2 + # Verify error_count was incremented then carried over + assert final_state['error_count'] == 1 # 1 failure, then 1 success (which doesn't inc) + assert final_state['generation'] == "Success after retry" + assert final_state['intent'] == mock_intent + +# --- Bailout Condition Test --- + +def test_bailout_condition(initial_state): + """ + Tests that the graph exits to END when error_count reaches MAX_RETRIES. + Ensures no infinite loops. + """ + mock_intent = 'Conceptual Definition' + + # Always return failure + with patch('core_agent.graph.call_llm_classify', return_value=mock_intent), \ + patch('core_agent.graph.call_milvus_search', return_value={"content": "data", "citations": []}), \ + patch('core_agent.graph.call_llm_generate', return_value="") as mock_gen: + + final_state = agent_graph.invoke(initial_state) + + # Should have called generate MAX_RETRIES times + assert mock_gen.call_count == MAX_RETRIES + assert final_state['error_count'] == MAX_RETRIES + assert final_state['generation'] == "" # Or a fallback message if implemented + # Verify we didn't get stuck in an infinite loop + # (The fact that invoke() returned is proof enough in a synchronous test) + +# --- Robustness / Edge Case Tests --- + +def test_initial_error_count_high(initial_state): + """ + Tests behavior when starting with an already high error count. + """ + initial_state['error_count'] = MAX_RETRIES + with patch('core_agent.graph.call_llm_classify', return_value='General Conversation'), \ + patch('core_agent.graph.call_llm_generate', return_value=""): + + final_state = agent_graph.invoke(initial_state) + + # Even if it fails, it should exit immediately because error_count is already at max + assert final_state['error_count'] == MAX_RETRIES + 1 + assert final_state['generation'] == ""