diff --git a/requirements.txt b/requirements.txt index 9f4683e..0376bbc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -77,6 +77,7 @@ pinecone-client==3.0.0 propcache==0.4.1 proto-plus==1.26.1 protobuf==4.25.8 +PyJWT==2.8.0 pyasn1==0.6.1 pyasn1-modules==0.4.2 pydantic==2.12.4 diff --git a/services/agent_core.py b/services/agent_core.py index 390e939..cd7abcd 100644 --- a/services/agent_core.py +++ b/services/agent_core.py @@ -2,6 +2,7 @@ from langchain.agents import AgentExecutor, initialize_agent, AgentType from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.messages import HumanMessage, AIMessage from langchain_google_genai import ChatGoogleGenerativeAI from config.settings import settings from services.agent_tools import all_tools @@ -146,17 +147,72 @@ async def invoke_agent( # This ensures thread-safety for concurrent users token_context.set(user_token) - # 4. Invoke Agent Executor (use ainvoke for async tools) - result = await self.agent_executor.ainvoke({ - "input": user_query, - "chat_history": chat_history, - "user_context": user_context_str, # Injected into System Prompt - "rag_context": rag_context_str # Injected into System Prompt - }) + # 4. Convert chat history to LangChain message objects + # The MessagesPlaceholder expects HumanMessage and AIMessage objects, not plain dicts + langchain_history = [] + for msg in chat_history: + role = msg.get("role", "").lower() + content = msg.get("content", "") + if role == "user" or role == "human": + langchain_history.append(HumanMessage(content=content)) + elif role == "assistant" or role == "ai": + langchain_history.append(AIMessage(content=content)) - # 5. Determine Tool Execution Status + # 5. Invoke Agent Executor (try async, fallback to sync if async is not available) + raw_result = None + try: + # some AgentExecutor versions expose an async method named `ainvoke` + if hasattr(self.agent_executor, "ainvoke"): + raw_result = await self.agent_executor.ainvoke({ + "input": user_query, + "chat_history": langchain_history, + "user_context": user_context_str, # Injected into System Prompt + "rag_context": rag_context_str # Injected into System Prompt + }) + else: + # Fallback: call the synchronous `run` in a thread if async method missing + logger.info("AgentExecutor does not expose `ainvoke`, using sync `run` in an executor as fallback") + import asyncio as _asyncio + raw_result = await _asyncio.to_thread( + self.agent_executor.run, + { + "input": user_query, + "chat_history": langchain_history, + "user_context": user_context_str, + "rag_context": rag_context_str + } + ) + + except Exception as ex: + # Log exception with stack trace for easier debugging and re-raise + logger.exception("AgentExecutor invocation failed") + raise + + # 6. Determine Tool Execution Status tool_executed = None - intermediate_steps = result.get('intermediate_steps', []) + + # Normalize raw_result into the expected structure + intermediate_steps = [] + result = {} + + try: + if isinstance(raw_result, dict): + # When agent returns a dict-like response + result = raw_result + intermediate_steps = result.get('intermediate_steps', []) or [] + elif isinstance(raw_result, tuple) and len(raw_result) >= 2: + # Common return shape when return_intermediate_steps=True -> (output, intermediate_steps) + result = {"output": raw_result[0], "intermediate_steps": raw_result[1]} + intermediate_steps = raw_result[1] or [] + elif isinstance(raw_result, str): + # Simple string output + result = {"output": raw_result} + else: + # Any other shape - convert to string for output + result = {"output": str(raw_result)} + except Exception: + logger.exception("Failed to normalize agent executor output; converting to string") + result = {"output": str(raw_result)} if intermediate_steps: # intermediate_steps is a list of tuples: (AgentAction, tool_output) diff --git a/services/agent_tools.py b/services/agent_tools.py index 5c9a803..07eeebc 100644 --- a/services/agent_tools.py +++ b/services/agent_tools.py @@ -4,8 +4,14 @@ from services.token_context import token_context import json -# Get the singleton client instance -client = get_microservice_client() +# Attempt to get the singleton client instance - make import resilient (tests may not have httpx installed) +try: + client = get_microservice_client() +except Exception as e: + # Tests will patch `services.agent_tools.client` as needed; avoid hard failures during import + import logging + logging.getLogger(__name__).warning("Microservice client not available at import time: %s", e) + client = None # --- 1. Appointment Tools --- @@ -76,7 +82,14 @@ async def get_my_vehicles_tool() -> str: summary = "Your Vehicles:\n" for v in vehicles: - summary += f"- {v.get('make')} {v.get('model')} ({v.get('year')}) - Plate: {v.get('licensePlate')} - ID: {v.get('id')}\n" + # tolerate different JSON shapes from services (camelCase vs snake_case) + make = v.get('make') or v.get('Make') or '' + model = v.get('model') or v.get('Model') or '' + year = v.get('year') or v.get('Year') or '' + plate = v.get('licensePlate') or v.get('license_plate') or v.get('plate') or '' + vid = v.get('vehicleId') or v.get('id') or v.get('vehicle_id') or '' + + summary += f"- {make} {model} ({year}) - Plate: {plate} - ID: {vid}\n" return summary async def get_vehicle_details_tool(vehicle_id: str) -> str: diff --git a/services/microservice_client.py b/services/microservice_client.py index 8c13ec0..d18704c 100644 --- a/services/microservice_client.py +++ b/services/microservice_client.py @@ -2,7 +2,8 @@ import os import logging import asyncio -from typing import List, Dict, Any, Optional +import jwt +from typing import List, Dict, Any, Optional, Tuple from config.settings import settings from models.chat import UserContext, VehicleInfo @@ -19,25 +20,79 @@ class MicroserviceClient: def __init__(self): # Initialize an AsyncClient once per instance self._async_client = httpx.AsyncClient(timeout=5.0) - self.auth_url = settings.AUTHENTICATION_SERVICE_URL - self.vehicle_url = settings.VEHICLE_SERVICE_URL - self.project_url = settings.PROJECT_SERVICE_URL + # Normalize and sanitize URLs (strip whitespace and trailing slashes as needed) + self.auth_url = (settings.AUTHENTICATION_SERVICE_URL or "").strip() + self.vehicle_url = (settings.VEHICLE_SERVICE_URL or "").strip() + self.project_url = (settings.PROJECT_SERVICE_URL or "").strip() # FIX: Added required microservice URLs self.appointment_url = settings.APPOINTMENT_SERVICE_URL self.time_log_url = settings.TIME_LOGGING_SERVICE_URL + def _extract_user_from_token(self, token: str) -> Tuple[str, str]: + """ + Extract username and roles from JWT token. + Returns (username, roles_csv_string) + """ + try: + # Decode without verification (we trust our own tokens) + payload = jwt.decode(token, options={"verify_signature": False}) + username = payload.get("sub", "") + + # Extract roles - they might be in different formats + roles = payload.get("roles", []) + if isinstance(roles, list): + # Remove ROLE_ prefix if present + cleaned_roles = [r.replace("ROLE_", "") for r in roles] + roles_str = ",".join(cleaned_roles) + elif isinstance(roles, str): + roles_str = roles.replace("ROLE_", "") + else: + roles_str = "" + + logger.debug(f"Extracted from JWT - username: {username}, roles: {roles_str}") + return username, roles_str + except Exception as e: + logger.warning(f"Failed to extract user from token: {e}") + return "", "" + async def _make_get_request(self, url: str, token: str, params: Dict[str, Any] = None) -> Dict[str, Any]: """Internal helper for making async authenticated GET requests.""" headers = {"Authorization": f"Bearer {token}"} + + # Add X-User headers for direct service calls + username, roles = self._extract_user_from_token(token) + if username: + headers["X-User-Subject"] = username + headers["X-User-Roles"] = roles + + # defensive trimming - remove accidental spaces + url = (url or "").strip() + logger.debug(f"Making GET request to: {url} params={params}") try: - # FIX: Use async client and await response = await self._async_client.get(url, params=params, headers=headers) response.raise_for_status() return response.json() except httpx.HTTPStatusError as errh: - logger.error(f"HTTP Error {errh.response.status_code} from {url}: {errh.response.text}") - return {"error": f"HTTP Error {errh.response.status_code}", "status_code": errh.response.status_code} + # Detailed error body may be helpful for callers - attempt to parse JSON + status = errh.response.status_code + body = None + try: + body = errh.response.json() + except Exception: + body = errh.response.text or None + + logger.error(f"HTTP Error {status} from {url}: {body}") + + # Return underlying error body if available, but keep a consistent shape + result = {"status_code": status} + if isinstance(body, dict): + # merge error body and preserve status_code + result.update(body) + else: + result["error"] = body or f"HTTP Error {status}" + + return result except httpx.RequestError as errc: logger.error(f"Request Error to {url}: {errc}") return {"error": "Microservice Unreachable", "status_code": 503} @@ -48,6 +103,13 @@ async def _make_get_request(self, url: str, token: str, params: Dict[str, Any] = async def _make_post_request(self, url: str, token: str, data: Dict[str, Any] = None) -> Dict[str, Any]: """Internal helper for making async authenticated POST requests.""" headers = {"Authorization": f"Bearer {token}"} + + # Add X-User headers for direct service calls + username, roles = self._extract_user_from_token(token) + if username: + headers["X-User-Subject"] = username + headers["X-User-Roles"] = roles + try: response = await self._async_client.post(url, json=data, headers=headers) if response.is_success: @@ -63,6 +125,13 @@ async def _make_post_request(self, url: str, token: str, data: Dict[str, Any] = async def _make_put_request(self, url: str, token: str, data: Dict[str, Any] = None) -> Dict[str, Any]: """Internal helper for making async authenticated PUT requests.""" headers = {"Authorization": f"Bearer {token}"} + + # Add X-User headers for direct service calls + username, roles = self._extract_user_from_token(token) + if username: + headers["X-User-Subject"] = username + headers["X-User-Roles"] = roles + try: response = await self._async_client.put(url, json=data, headers=headers) if response.is_success: @@ -78,6 +147,13 @@ async def _make_put_request(self, url: str, token: str, data: Dict[str, Any] = N async def _make_delete_request(self, url: str, token: str) -> Dict[str, Any]: """Internal helper for making async authenticated DELETE requests.""" headers = {"Authorization": f"Bearer {token}"} + + # Add X-User headers for direct service calls + username, roles = self._extract_user_from_token(token) + if username: + headers["X-User-Subject"] = username + headers["X-User-Roles"] = roles + try: response = await self._async_client.delete(url, headers=headers) if response.is_success: @@ -99,13 +175,23 @@ async def get_user_context(self, token: str) -> UserContext: async def _async_get_user_context(self, token: str) -> UserContext: """Retrieves user profile and vehicles (ASYNC helper).""" - # 1. Get User Profile (/auth/me endpoint) - user_data = await self._make_get_request(f"{self.auth_url}/me", token) + # 1. Get User Profile (/users/me endpoint) + base_url = self.auth_url.strip().rstrip('/') + if base_url.endswith('/users'): + url = f"{base_url}/me" + else: + url = f"{base_url}/users/me" + + user_data = await self._make_get_request(url, token) if "error" in user_data: return UserContext(user_id="anonymous", full_name="Guest", role="PUBLIC", vehicles=[]) # 2. Get User Vehicles (/vehicles endpoint) - vehicle_data = await self._make_get_request(f"{self.vehicle_url}", token) + url = self.vehicle_url.strip().rstrip('/') + if not url.endswith("/vehicles"): + url = f"{url}/vehicles" + vehicle_data = await self._make_get_request(url, token) + vehicles = [] if isinstance(vehicle_data, list): vehicles = [ @@ -118,7 +204,7 @@ async def _async_get_user_context(self, token: str) -> UserContext: ] return UserContext( - user_id=user_data.get("id") or user_data.get("userId", "unknown"), + user_id=str(user_data.get("id") or user_data.get("userId") or "unknown"), full_name=user_data.get("fullName") or user_data.get("username", "unknown"), role=user_data.get("role", "CUSTOMER"), vehicles=vehicles @@ -188,19 +274,29 @@ async def cancel_appointment(self, appointment_id: str, token: str) -> Dict[str, # 2. Vehicles async def get_customer_vehicles(self, token: str) -> List[Dict[str, Any]]: """Get all vehicles for the current user.""" - result = await self._make_get_request(self.vehicle_url, token) + url = self.vehicle_url.strip().rstrip('/') + if not url.endswith("/vehicles"): + url = f"{url}/vehicles" + result = await self._make_get_request(url, token) if isinstance(result, list): return result return [] async def get_vehicle_details(self, vehicle_id: str, token: str) -> Dict[str, Any]: """Get details for a specific vehicle.""" - url = f"{self.vehicle_url}/{vehicle_id}" + base_url = self.vehicle_url.strip().rstrip('/') + if base_url.endswith("/vehicles"): + url = f"{base_url}/{vehicle_id}" + else: + url = f"{base_url}/vehicles/{vehicle_id}" return await self._make_get_request(url, token) async def register_vehicle(self, vehicle_data: Dict[str, Any], token: str) -> Dict[str, Any]: """Register a new vehicle.""" - return await self._make_post_request(self.vehicle_url, token, vehicle_data) + url = self.vehicle_url.strip().rstrip('/') + if not url.endswith("/vehicles"): + url = f"{url}/vehicles" + return await self._make_post_request(url, token, vehicle_data) # 3. Projects async def request_modification_project(self, project_data: Dict[str, Any], token: str) -> Dict[str, Any]: @@ -228,12 +324,20 @@ async def get_project_details(self, project_id: str, token: str) -> Dict[str, An # 4. Profile async def get_my_profile(self, token: str) -> Dict[str, Any]: """Get current user profile.""" - url = f"{self.auth_url}/users/me" + base_url = self.auth_url.strip().rstrip('/') + if base_url.endswith('/users'): + url = f"{base_url}/me" + else: + url = f"{base_url}/users/me" return await self._make_get_request(url, token) async def update_my_profile(self, profile_data: Dict[str, Any], token: str) -> Dict[str, Any]: """Update current user profile.""" - url = f"{self.auth_url}/users/profile" + base_url = self.auth_url.strip().rstrip('/') + if base_url.endswith('/users'): + url = f"{base_url}/profile" + else: + url = f"{base_url}/users/profile" return await self._make_put_request(url, token, profile_data) diff --git a/test_agent_core_invocation.py b/test_agent_core_invocation.py new file mode 100644 index 0000000..ace0a57 --- /dev/null +++ b/test_agent_core_invocation.py @@ -0,0 +1,96 @@ +import sys +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock + +# Some CI/test environments do not have langchain and google packages installed. +# Patch sys.modules with lightweight mocks so the module imports successfully and +# tests can instantiate AIAgentService via object.__new__ (skipping __init__). +import types + +# Add module stubs so the service module imports succeed in test environments +if 'langchain' not in sys.modules: + langchain_mod = types.ModuleType('langchain') + sys.modules['langchain'] = langchain_mod + +if 'langchain.tools' not in sys.modules: + sys.modules['langchain.tools'] = types.ModuleType('langchain.tools') + +if 'langchain.agents' not in sys.modules: + sys.modules['langchain.agents'] = types.ModuleType('langchain.agents') + +if 'langchain_core.prompts' not in sys.modules: + sys.modules['langchain_core.prompts'] = types.ModuleType('langchain_core.prompts') + +if 'langchain_google_genai' not in sys.modules: + sys.modules['langchain_google_genai'] = types.ModuleType('langchain_google_genai') + +# Populate required names so imports succeed +langchain_agents = sys.modules.get('langchain.agents') +setattr(langchain_agents, 'AgentExecutor', type('AgentExecutor', (), {})) +setattr(langchain_agents, 'initialize_agent', lambda *a, **k: MagicMock()) +setattr(langchain_agents, 'AgentType', type('AgentType', (), {})) + +langchain_tools = sys.modules.get('langchain.tools') +class _StructuredTool: + @classmethod + def from_function(cls, *args, **kwargs): + return None +setattr(langchain_tools, 'StructuredTool', _StructuredTool) + +langchain_prompts = sys.modules.get('langchain_core.prompts') +setattr(langchain_prompts, 'ChatPromptTemplate', type('ChatPromptTemplate', (), {'from_messages': classmethod(lambda cls, x: None)})) +setattr(langchain_prompts, 'MessagesPlaceholder', type('MessagesPlaceholder', (), {})) + +setattr(sys.modules.get('langchain_google_genai'), 'ChatGoogleGenerativeAI', type('ChatGoogleGenerativeAI', (), {})) + +# Import the class directly +from services.agent_core import AIAgentService + + +@pytest.mark.asyncio +async def test_invoke_agent_falls_back_to_sync_run(): + # Create an instance without running __init__ to avoid creating LLMs + agent = object.__new__(AIAgentService) + + # Prepare a synchronous agent_executor (no `ainvoke`) whose run returns (output, intermediate_steps) + class SyncExecutor: + def run(self, payload): + return ("sync output", [("action1", "tool-output")]) + + agent.agent_executor = SyncExecutor() + + # Mock microservice client and rag service + agent.ms_client = MagicMock() + agent.ms_client.get_user_context = AsyncMock(return_value={"id": "user-x"}) + + rag = MagicMock() + rag.retrieve_and_format = MagicMock(return_value={"context": "kb", "num_sources": 0}) + agent.rag_service = rag + + # Call invoke_agent + result = await agent.invoke_agent("hello", "s1", "tok", []) + + assert isinstance(result, dict) + assert result.get("output") == "sync output" + + +@pytest.mark.asyncio +async def test_invoke_agent_uses_ainvoke_when_present(): + agent = object.__new__(AIAgentService) + + class AsyncExecutor: + async def ainvoke(self, payload): + return {"output": "async output", "intermediate_steps": []} + + agent.agent_executor = AsyncExecutor() + agent.ms_client = MagicMock() + agent.ms_client.get_user_context = AsyncMock(return_value={}) + + rag = MagicMock() + rag.retrieve_and_format = MagicMock(return_value={"context": "kb", "num_sources": 1}) + agent.rag_service = rag + + result = await agent.invoke_agent("ask me", "s2", "tok2", []) + + assert result.get("output") == "async output" diff --git a/test_new_tools.py b/test_new_tools.py index 3ccae21..79aa277 100644 --- a/test_new_tools.py +++ b/test_new_tools.py @@ -46,19 +46,33 @@ async def test_book_appointment_tool(): @pytest.mark.asyncio async def test_get_my_vehicles_tool(): + # Case A: service returns camelCase fields with patch('services.agent_tools.client') as mock_client: mock_client.get_customer_vehicles = AsyncMock(return_value=[ {"id": "v1", "make": "Toyota", "model": "Camry", "year": 2020, "licensePlate": "ABC-123"} ]) - + token_context.set("user_token") - + result = await get_my_vehicles_tool() - + assert "Toyota Camry" in result assert "ABC-123" in result mock_client.get_customer_vehicles.assert_called_once_with("user_token") + # Case B: service returns snake_case fields (some services may use snake_case) + with patch('services.agent_tools.client') as mock_client2: + mock_client2.get_customer_vehicles = AsyncMock(return_value=[ + {"vehicle_id": "v2", "make": "Honda", "model": "Civic", "year": 2018, "license_plate": "XYZ-789"} + ]) + + token_context.set("user_token") + result2 = await get_my_vehicles_tool() + + assert "Honda Civic" in result2 + assert "XYZ-789" in result2 + mock_client2.get_customer_vehicles.assert_called_once_with("user_token") + @pytest.mark.asyncio async def test_concurrency_context(): """Verify that token_context works correctly in simulated concurrent calls."""