diff --git a/examples/arcade-dev/01_basic_integration/basic_arcade_client.py b/examples/arcade-dev/01_basic_integration/basic_arcade_client.py index 38f92c5..9266c4e 100644 --- a/examples/arcade-dev/01_basic_integration/basic_arcade_client.py +++ b/examples/arcade-dev/01_basic_integration/basic_arcade_client.py @@ -38,6 +38,7 @@ # Import official arcade SDK try: from arcadepy import Arcade + ARCADE_SDK_AVAILABLE = True except ImportError: ARCADE_SDK_AVAILABLE = False @@ -52,6 +53,7 @@ @dataclass class ArcadeConfig: """Configuration for Arcade.dev API client.""" + api_key: str user_id: str = "demo@example.com" timeout: int = 30 @@ -61,65 +63,82 @@ class ArcadeConfig: class MockArcadeResponse: """Mock response class for demo mode.""" + def __init__(self, data: Dict[str, Any]): - self.id = data.get('id', 'demo_response_123') - self.status = data.get('status', 'success') - self.result = data.get('result', {}) + self.id = data.get("id", "demo_response_123") + self.status = data.get("status", "success") + self.result = data.get("result", {}) self.data = data - + def __getattr__(self, name): return self.data.get(name) class BasicArcadeClient: """Basic Arcade.dev API client with FACT integration using official SDK.""" - - def __init__(self, config: ArcadeConfig, cache_manager: Optional[CacheManager] = None): + + def __init__( + self, config: ArcadeConfig, cache_manager: Optional[CacheManager] = None + ): self.config = config self.cache_manager = cache_manager self.logger = logging.getLogger(__name__) self.client: Optional[Arcade] = None - + # Mock data for demo mode self._demo_tools = [ { "name": "Math.Sqrt", "description": "Calculate square root of a number", "category": "mathematics", - "input_schema": {"type": "object", "properties": {"a": {"type": "number"}}} + "input_schema": { + "type": "object", + "properties": {"a": {"type": "number"}}, + }, }, { - "name": "Google.ListEmails", + "name": "Google.ListEmails", "description": "List emails from Gmail", "category": "email", - "input_schema": {"type": "object", "properties": {"n_emails": {"type": "integer"}}} + "input_schema": { + "type": "object", + "properties": {"n_emails": {"type": "integer"}}, + }, }, { "name": "Slack.PostMessage", "description": "Post message to Slack channel", - "category": "messaging", - "input_schema": {"type": "object", "properties": {"channel": {"type": "string"}, "message": {"type": "string"}}} - } + "category": "messaging", + "input_schema": { + "type": "object", + "properties": { + "channel": {"type": "string"}, + "message": {"type": "string"}, + }, + }, + }, ] - + async def __aenter__(self): """Async context manager entry.""" await self.connect() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" await self.disconnect() - + async def connect(self): """Establish connection to Arcade.dev API.""" if self.config.demo_mode: self.logger.info("Running in demo mode - using mock responses") return - + if not ARCADE_SDK_AVAILABLE: - raise RuntimeError("arcadepy SDK not available. Install with: pip install arcadepy") - + raise RuntimeError( + "arcadepy SDK not available. Install with: pip install arcadepy" + ) + try: # Initialize the official Arcade client self.client = Arcade(api_key=self.config.api_key) @@ -127,19 +146,21 @@ async def connect(self): except Exception as e: self.logger.error(f"Failed to connect to Arcade.dev: {e}") raise - + async def disconnect(self): """Close connection to Arcade.dev API.""" if self.config.demo_mode: self.logger.info("Demo mode - no connection to close") return - + if self.client: # The official SDK doesn't require explicit disconnection self.client = None self.logger.info("Disconnected from Arcade.dev API") - - async def _execute_with_cache_and_retry(self, operation_name: str, operation_func, *args, **kwargs): + + async def _execute_with_cache_and_retry( + self, operation_name: str, operation_func, *args, **kwargs + ): """Execute operation with caching and retry logic.""" # Check cache first cache_key = f"arcade:{operation_name}:{hash(str(args + tuple(kwargs.items())))}" @@ -149,38 +170,44 @@ async def _execute_with_cache_and_retry(self, operation_name: str, operation_fun self.logger.debug(f"Cache hit for {operation_name}") # Return cached result, assuming it's stored as JSON string try: - if hasattr(cached_result, 'content'): + if hasattr(cached_result, "content"): return json.loads(cached_result.content) else: return cached_result except (json.JSONDecodeError, AttributeError): pass # Fall through to actual operation - + # Execute operation with retries last_exception = None for attempt in range(self.config.max_retries): try: result = await operation_func(*args, **kwargs) - + # Cache successful results if self.cache_manager and result: try: - cache_data = result if isinstance(result, dict) else result.__dict__ + cache_data = ( + result if isinstance(result, dict) else result.__dict__ + ) self.cache_manager.store(cache_key, json.dumps(cache_data)) except Exception as e: self.logger.warning(f"Failed to cache result: {e}") - + return result - + except Exception as e: last_exception = e - self.logger.warning(f"Operation {operation_name} attempt {attempt + 1} failed: {e}") + self.logger.warning( + f"Operation {operation_name} attempt {attempt + 1} failed: {e}" + ) if attempt == self.config.max_retries - 1: break - await asyncio.sleep(2 ** attempt) # Exponential backoff - - raise last_exception or RuntimeError(f"All retry attempts failed for {operation_name}") - + await asyncio.sleep(2**attempt) # Exponential backoff + + raise last_exception or RuntimeError( + f"All retry attempts failed for {operation_name}" + ) + async def health_check(self) -> Dict[str, Any]: """Check API health and connectivity.""" if self.config.demo_mode: @@ -191,12 +218,12 @@ async def health_check(self) -> Dict[str, Any]: "services": { "auth": "healthy", "tools": "healthy", - "database": "healthy" + "database": "healthy", }, - "_demo_mode": True + "_demo_mode": True, } return result - + # For real API, we'll try to list tools as a health check try: await self.list_tools() @@ -204,18 +231,15 @@ async def health_check(self) -> Dict[str, Any]: "status": "healthy", "version": "1.4.0", "timestamp": "2025-05-25T19:27:00Z", - "services": { - "auth": "healthy", - "tools": "healthy" - } + "services": {"auth": "healthy", "tools": "healthy"}, } except Exception as e: return { "status": "unhealthy", "error": str(e), - "timestamp": "2025-05-25T19:27:00Z" + "timestamp": "2025-05-25T19:27:00Z", } - + async def get_user_info(self) -> Dict[str, Any]: """Get current user information.""" if self.config.demo_mode: @@ -226,26 +250,26 @@ async def get_user_info(self) -> Dict[str, Any]: "tools_available": 100, "tools_used": 5, "quota_remaining": 95, - "_demo_mode": True + "_demo_mode": True, } - + # The arcadepy SDK doesn't have a direct user info endpoint in the examples # So we'll return basic info based on the configuration return { "user_id": self.config.user_id, "api_key_status": "valid" if self.client else "invalid", - "sdk_version": "1.4.0" + "sdk_version": "1.4.0", } - + async def list_tools(self) -> List[Dict[str, Any]]: """List available tools.""" if self.config.demo_mode: return { "tools": self._demo_tools, "count": len(self._demo_tools), - "_demo_mode": True + "_demo_mode": True, } - + async def _list_tools(): # The arcadepy SDK doesn't provide a direct list tools method in the examples # This would need to be implemented based on the actual API @@ -253,53 +277,55 @@ async def _list_tools(): return { "tools": [], "count": 0, - "message": "Tool listing requires specific API implementation" + "message": "Tool listing requires specific API implementation", } - + return await self._execute_with_cache_and_retry("list_tools", _list_tools) - - async def execute_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]: + + async def execute_tool( + self, tool_name: str, tool_input: Dict[str, Any] + ) -> Dict[str, Any]: """Execute a tool using Arcade.dev.""" if self.config.demo_mode: return self._generate_mock_tool_execution(tool_name, tool_input) - + async def _execute_tool(): try: response = self.client.tools.execute( - tool_name=tool_name, - input=tool_input, - user_id=self.config.user_id + tool_name=tool_name, input=tool_input, user_id=self.config.user_id ) - + # Convert response to dict format result = { "id": response.id, - "status": getattr(response, 'status', 'completed'), + "status": getattr(response, "status", "completed"), "tool_name": tool_name, "input": tool_input, - "result": getattr(response, 'result', None), - "execution_time_ms": getattr(response, 'execution_time_ms', None) + "result": getattr(response, "result", None), + "execution_time_ms": getattr(response, "execution_time_ms", None), } - + return result - + except Exception as e: self.logger.error(f"Tool execution failed: {e}") return { "status": "failed", "tool_name": tool_name, "input": tool_input, - "error": str(e) + "error": str(e), } - + return await self._execute_with_cache_and_retry("execute_tool", _execute_tool) - - def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]: + + def _generate_mock_tool_execution( + self, tool_name: str, tool_input: Dict[str, Any] + ) -> Dict[str, Any]: """Generate a mock tool execution response.""" # Generate realistic mock responses based on the tool if tool_name == "Math.Sqrt": - number = float(tool_input.get('a', 16)) - result = number ** 0.5 + number = float(tool_input.get("a", 16)) + result = number**0.5 return { "_demo_mode": True, "_demo_timestamp": "2025-05-25T19:27:00Z", @@ -308,13 +334,17 @@ def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, An "status": "completed", "result": {"value": result, "input": number}, "input": tool_input, - "execution_time_ms": 150 + "execution_time_ms": 150, } - + elif tool_name == "Google.ListEmails": - n_emails = tool_input.get('n_emails', 5) + n_emails = tool_input.get("n_emails", 5) emails = [ - {"id": f"email_{i}", "subject": f"Demo Email {i}", "from": f"sender{i}@example.com"} + { + "id": f"email_{i}", + "subject": f"Demo Email {i}", + "from": f"sender{i}@example.com", + } for i in range(1, min(n_emails + 1, 11)) # Cap at 10 emails ] return { @@ -325,9 +355,9 @@ def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, An "status": "completed", "result": {"emails": emails, "count": len(emails)}, "input": tool_input, - "execution_time_ms": 1200 + "execution_time_ms": 1200, } - + elif tool_name == "Slack.PostMessage": return { "_demo_mode": True, @@ -336,14 +366,14 @@ def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, An "tool_name": tool_name, "status": "completed", "result": { - "message_ts": "1640995200.000100", - "channel": tool_input.get('channel', '#general'), - "message_id": "demo_msg_789" + "message_ts": "1640995200.000100", + "channel": tool_input.get("channel", "#general"), + "message_id": "demo_msg_789", }, "input": tool_input, - "execution_time_ms": 800 + "execution_time_ms": 800, } - + # Default tool execution response return { "_demo_mode": True, @@ -353,7 +383,7 @@ def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, An "status": "completed", "result": {"message": f"Demo execution of {tool_name}", "data": tool_input}, "input": tool_input, - "execution_time_ms": 500 + "execution_time_ms": 500, } @@ -362,31 +392,31 @@ async def main(): # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + print("šŸŽ® Arcade.dev Python SDK Integration Example") print("=" * 50) - + # Load configuration from environment api_key = os.getenv("ARCADE_API_KEY", "") user_id = os.getenv("ARCADE_USER_ID", "demo@example.com") - + # Enable demo mode if no API key or if API key looks like a placeholder demo_mode = ( - not bool(api_key.strip()) or - api_key.strip() in ["your_api_key", "demo_key", "placeholder"] or - len(api_key.strip()) < 10 # Real API keys are typically longer + not bool(api_key.strip()) + or api_key.strip() in ["your_api_key", "demo_key", "placeholder"] + or len(api_key.strip()) < 10 # Real API keys are typically longer ) - + config = ArcadeConfig( api_key=api_key if not demo_mode else "demo_key", user_id=user_id, timeout=int(os.getenv("ARCADE_TIMEOUT", "30")), max_retries=int(os.getenv("ARCADE_MAX_RETRIES", "3")), - demo_mode=demo_mode + demo_mode=demo_mode, ) - + if demo_mode: print("šŸŽ­ Demo Mode: No API key found - using mock responses") print("šŸ’” To use real API: Set ARCADE_API_KEY environment variable") @@ -397,7 +427,7 @@ async def main(): print(f"šŸ”‘ Using API key: {api_key[:10]}...") print(f"šŸ‘¤ User ID: {user_id}") print() - + # Initialize cache manager with config cache_config = { "prefix": "arcade_demo", @@ -405,84 +435,92 @@ async def main(): "max_size": "10MB", "ttl_seconds": 3600, "hit_target_ms": 30, - "miss_target_ms": 120 + "miss_target_ms": 120, } cache_manager = CacheManager(cache_config) - + # Demonstrate basic API usage async with BasicArcadeClient(config, cache_manager) as client: try: # Health check print("šŸ” Checking API health...") health = await client.health_check() - status_icon = "šŸŽ­" if health.get('_demo_mode') else "āœ…" + status_icon = "šŸŽ­" if health.get("_demo_mode") else "āœ…" print(f"{status_icon} API Status: {health.get('status', 'unknown')}") - if health.get('_demo_mode'): + if health.get("_demo_mode"): print(" (Demo response)") - elif health.get('status') == 'unhealthy': + elif health.get("status") == "unhealthy": print(f" Error: {health.get('error', 'Unknown error')}") - + # User info print("\nšŸ‘¤ Getting user information...") user_info = await client.get_user_info() - user_icon = "šŸŽ­" if user_info.get('_demo_mode') else "āœ…" + user_icon = "šŸŽ­" if user_info.get("_demo_mode") else "āœ…" print(f"{user_icon} User: {user_info.get('user_id', 'unknown')}") - if user_info.get('tools_available'): + if user_info.get("tools_available"): print(f" Tools available: {user_info.get('tools_available', 'N/A')}") print(f" Quota remaining: {user_info.get('quota_remaining', 'N/A')}") - if user_info.get('_demo_mode'): + if user_info.get("_demo_mode"): print(" (Demo response)") - + # List available tools print("\nšŸ› ļø Listing available tools...") tools_response = await client.list_tools() - tools_icon = "šŸŽ­" if tools_response.get('_demo_mode') else "āœ…" - available_tools = tools_response.get('tools', []) + tools_icon = "šŸŽ­" if tools_response.get("_demo_mode") else "āœ…" + available_tools = tools_response.get("tools", []) print(f"{tools_icon} Found {len(available_tools)} tools") - if tools_response.get('_demo_mode'): + if tools_response.get("_demo_mode"): print(" (Demo response)") - + # Display available tools if available_tools: print("\nšŸ“‹ Available Tools:") for i, tool in enumerate(available_tools[:5], 1): # Show first 5 tools - print(f" {i}. {tool.get('name', 'Unknown')} - {tool.get('description', 'No description')}") - if tool.get('category'): + print( + f" {i}. {tool.get('name', 'Unknown')} - {tool.get('description', 'No description')}" + ) + if tool.get("category"): print(f" Category: {tool.get('category')}") - + # Tool execution examples print("\nšŸ”§ Tool Execution Examples:") - + # Example 1: Math.Sqrt print("\n1. Math.Sqrt - Calculate square root of 625...") execution = await client.execute_tool("Math.Sqrt", {"a": 625}) - exec_icon = "šŸŽ­" if execution.get('_demo_mode') else "āœ…" + exec_icon = "šŸŽ­" if execution.get("_demo_mode") else "āœ…" print(f"{exec_icon} Status: {execution.get('status', 'unknown')}") - if execution.get('result'): - result_val = execution['result'].get('value') if isinstance(execution['result'], dict) else execution['result'] + if execution.get("result"): + result_val = ( + execution["result"].get("value") + if isinstance(execution["result"], dict) + else execution["result"] + ) print(f" Result: √625 = {result_val}") print(f" Execution time: {execution.get('execution_time_ms', 'N/A')}ms") - if execution.get('_demo_mode'): + if execution.get("_demo_mode"): print(" (Demo response)") - + # Example 2: Google.ListEmails print("\n2. Google.ListEmails - List 3 emails...") execution = await client.execute_tool("Google.ListEmails", {"n_emails": 3}) - exec_icon = "šŸŽ­" if execution.get('_demo_mode') else "āœ…" + exec_icon = "šŸŽ­" if execution.get("_demo_mode") else "āœ…" print(f"{exec_icon} Status: {execution.get('status', 'unknown')}") - if execution.get('result') and isinstance(execution['result'], dict): - emails = execution['result'].get('emails', []) + if execution.get("result") and isinstance(execution["result"], dict): + emails = execution["result"].get("emails", []) print(f" Found {len(emails)} emails") for i, email in enumerate(emails[:2], 1): # Show first 2 - print(f" {i}. {email.get('subject', 'No subject')} from {email.get('from', 'Unknown')}") + print( + f" {i}. {email.get('subject', 'No subject')} from {email.get('from', 'Unknown')}" + ) print(f" Execution time: {execution.get('execution_time_ms', 'N/A')}ms") - if execution.get('_demo_mode'): + if execution.get("_demo_mode"): print(" (Demo response)") - + except Exception as e: print(f"āŒ Error: {e}") return 1 - + print("\n" + "=" * 50) if demo_mode: print("šŸŽ­ Basic integration example completed successfully in demo mode!") @@ -496,4 +534,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/02_tool_registration/register_fact_tools.py b/examples/arcade-dev/02_tool_registration/register_fact_tools.py index 32e653d..6413d03 100644 --- a/examples/arcade-dev/02_tool_registration/register_fact_tools.py +++ b/examples/arcade-dev/02_tool_registration/register_fact_tools.py @@ -33,6 +33,7 @@ # Import official arcade SDK try: from arcadepy import Arcade + ARCADE_SDK_AVAILABLE = True except ImportError: ARCADE_SDK_AVAILABLE = False @@ -50,6 +51,7 @@ @dataclass class ToolRegistrationConfig: """Configuration for tool registration with Arcade.dev.""" + arcade_api_key: str arcade_base_url: str = "https://api.arcade.dev" workspace_id: Optional[str] = None @@ -62,188 +64,194 @@ class ToolRegistrationConfig: class FactToolRegistrar: """ Manages registration of FACT tools with Arcade.dev platform. - + Handles schema generation, authentication setup, permission configuration, and batch registration operations. """ - + def __init__(self, config: ToolRegistrationConfig): """ Initialize tool registrar with configuration. - + Args: config: Registration configuration """ self.config = config - + # Create arcade config from tool registration config arcade_config = ArcadeConfig( api_key=config.arcade_api_key, user_id=config.user_id, timeout=config.default_timeout, - demo_mode=config.demo_mode + demo_mode=config.demo_mode, ) - + self.arcade_client = BasicArcadeClient(arcade_config) self.tool_registry = get_tool_registry() self.auth_manager = AuthorizationManager() self.logger = logging.getLogger(__name__) - + # Track registration results self.registration_results: List[Dict[str, Any]] = [] - + async def connect(self) -> None: """Establish connection to Arcade.dev platform.""" try: await self.arcade_client.connect() self.logger.info("Successfully connected to Arcade.dev platform") - + # Verify workspace access if specified if self.config.workspace_id and not self.config.demo_mode: await self._verify_workspace_access() - + except Exception as e: self.logger.error(f"Failed to connect to Arcade.dev: {e}") raise - + async def register_all_tools(self) -> Dict[str, Any]: """ Register all FACT tools from the registry with Arcade.dev. - + Returns: Registration summary with results for each tool """ tool_names = self.tool_registry.list_tools() self.logger.info(f"Registering {len(tool_names)} tools with Arcade.dev") - + registration_summary = { "total_tools": len(tool_names), "successful": 0, "failed": 0, "skipped": 0, - "results": [] + "results": [], } - + for tool_name in tool_names: try: result = await self.register_single_tool(tool_name) registration_summary["results"].append(result) - + if result["status"] == "success": registration_summary["successful"] += 1 elif result["status"] == "failed": registration_summary["failed"] += 1 else: registration_summary["skipped"] += 1 - + except Exception as e: error_result = { "tool_name": tool_name, "status": "failed", "error": str(e), - "arcade_tool_id": None + "arcade_tool_id": None, } registration_summary["results"].append(error_result) registration_summary["failed"] += 1 - + self.logger.error(f"Failed to register tool {tool_name}: {e}") - + self.logger.info( f"Tool registration completed: {registration_summary['successful']} successful, " f"{registration_summary['failed']} failed, {registration_summary['skipped']} skipped" ) - + return registration_summary - + async def register_single_tool(self, tool_name: str) -> Dict[str, Any]: """ Register a single FACT tool with Arcade.dev. - + Args: tool_name: Name of the tool to register - + Returns: Registration result dictionary """ try: # Get tool definition from FACT registry tool_definition = self.tool_registry.get_tool(tool_name) - + # Check if tool already exists on Arcade.dev existing_tool = await self._check_existing_tool(tool_name) if existing_tool: self.logger.info(f"Tool {tool_name} already exists, updating...") - return await self._update_existing_tool(tool_name, tool_definition, existing_tool) - + return await self._update_existing_tool( + tool_name, tool_definition, existing_tool + ) + # Prepare tool definition for Arcade.dev - arcade_tool_definition = self._prepare_arcade_tool_definition(tool_definition) - + arcade_tool_definition = self._prepare_arcade_tool_definition( + tool_definition + ) + # Register tool with Arcade.dev - registration_result = await self._register_tool_with_arcade(arcade_tool_definition) - + registration_result = await self._register_tool_with_arcade( + arcade_tool_definition + ) + # Set up permissions and authentication if tool_definition.requires_auth: await self._setup_tool_permissions( - registration_result.get("id"), - tool_name, - tool_definition + registration_result.get("id"), tool_name, tool_definition ) - + result = { "tool_name": tool_name, "status": "success", "arcade_tool_id": registration_result.get("id"), "version": tool_definition.version, "requires_auth": tool_definition.requires_auth, - "message": "Tool registered successfully" + "message": "Tool registered successfully", } - + self.logger.info(f"Successfully registered tool: {tool_name}") return result - + except Exception as e: self.logger.error(f"Failed to register tool {tool_name}: {e}") return { "tool_name": tool_name, "status": "failed", "error": str(e), - "arcade_tool_id": None + "arcade_tool_id": None, } - - async def update_tool_permissions(self, tool_name: str, permissions: Dict[str, Any]) -> bool: + + async def update_tool_permissions( + self, tool_name: str, permissions: Dict[str, Any] + ) -> bool: """ Update permissions for a registered tool. - + Args: tool_name: Name of the tool permissions: Permission configuration - + Returns: True if update was successful """ try: # Get tool info from Arcade.dev tool_info = await self._get_tool_info(tool_name) - + if not tool_info: raise ValueError(f"Tool {tool_name} not found on Arcade.dev") - + tool_id = tool_info.get("id") - + # Update permissions via Arcade.dev API await self._update_tool_permissions(tool_id, permissions) - + self.logger.info(f"Updated permissions for tool: {tool_name}") return True - + except Exception as e: self.logger.error(f"Failed to update permissions for {tool_name}: {e}") return False - + async def list_registered_tools(self) -> List[Dict[str, Any]]: """ List all tools registered on Arcade.dev platform. - + Returns: List of registered tools with metadata """ @@ -257,7 +265,7 @@ async def list_registered_tools(self) -> List[Dict[str, Any]]: "version": "1.0.0", "status": "active", "requires_auth": True, - "_demo_mode": True + "_demo_mode": True, }, { "id": "demo_tool_data_transform", @@ -265,40 +273,42 @@ async def list_registered_tools(self) -> List[Dict[str, Any]]: "version": "1.0.0", "status": "active", "requires_auth": False, - "_demo_mode": True - } + "_demo_mode": True, + }, ] tools = demo_tools else: # For real mode, would use actual API tools = [] - + # Enhance with FACT registry information enhanced_tools = [] for tool in tools: tool_name = tool.get("name") enhanced_tool = dict(tool) - + # Add FACT registry info if available try: fact_tool = self.tool_registry.get_tool(tool_name) if fact_tool: enhanced_tool["fact_version"] = fact_tool.version - enhanced_tool["fact_created_at"] = getattr(fact_tool, 'created_at', None) + enhanced_tool["fact_created_at"] = getattr( + fact_tool, "created_at", None + ) enhanced_tool["fact_timeout"] = fact_tool.timeout_seconds except: enhanced_tool["fact_version"] = None enhanced_tool["fact_created_at"] = None enhanced_tool["fact_timeout"] = None - + enhanced_tools.append(enhanced_tool) - + return enhanced_tools - + except Exception as e: self.logger.error(f"Failed to list registered tools: {e}") return [] - + async def close(self) -> None: """Close connections and cleanup resources.""" try: @@ -306,14 +316,14 @@ async def close(self) -> None: self.logger.info("Tool registrar closed successfully") except Exception as e: self.logger.warning(f"Error closing tool registrar: {e}") - + def _prepare_arcade_tool_definition(self, tool_definition) -> Dict[str, Any]: """ Convert FACT tool definition to Arcade.dev format. - + Args: tool_definition: FACT tool definition - + Returns: Arcade.dev compatible tool definition """ @@ -323,7 +333,9 @@ def _prepare_arcade_tool_definition(self, tool_definition) -> Dict[str, Any]: "parameters": { "type": "object", "properties": tool_definition.parameters, - "required": self._extract_required_parameters(tool_definition.parameters) + "required": self._extract_required_parameters( + tool_definition.parameters + ), }, "version": tool_definition.version, "requires_auth": tool_definition.requires_auth, @@ -331,17 +343,21 @@ def _prepare_arcade_tool_definition(self, tool_definition) -> Dict[str, Any]: "metadata": { "created_by": "FACT", "framework_version": "1.0.0", - "category": tool_definition.name.split(".")[0] if "." in tool_definition.name else "general", - "workspace_id": self.config.workspace_id - } + "category": ( + tool_definition.name.split(".")[0] + if "." in tool_definition.name + else "general" + ), + "workspace_id": self.config.workspace_id, + }, } - + # Add workspace context if specified if self.config.workspace_id: arcade_definition["workspace_id"] = self.config.workspace_id - + return arcade_definition - + def _extract_required_parameters(self, parameters: Dict[str, Any]) -> List[str]: """Extract required parameter names from parameter schema.""" required = [] @@ -351,7 +367,7 @@ def _extract_required_parameters(self, parameters: Dict[str, Any]) -> List[str]: if "default" not in param_schema and param_schema.get("required", True): required.append(param_name) return required - + async def _check_existing_tool(self, tool_name: str) -> Optional[Dict[str, Any]]: """Check if tool already exists on Arcade.dev.""" try: @@ -362,7 +378,7 @@ async def _check_existing_tool(self, tool_name: str) -> Optional[Dict[str, Any]] "id": f"demo_tool_{tool_name.replace('.', '_').lower()}", "name": tool_name, "version": "1.0.0", - "status": "active" + "status": "active", } return None else: @@ -371,28 +387,34 @@ async def _check_existing_tool(self, tool_name: str) -> Optional[Dict[str, Any]] return None except Exception: return None - - async def _update_existing_tool(self, tool_name: str, tool_definition, existing_tool: Dict[str, Any]) -> Dict[str, Any]: + + async def _update_existing_tool( + self, tool_name: str, tool_definition, existing_tool: Dict[str, Any] + ) -> Dict[str, Any]: """Update an existing tool with new definition.""" try: # Compare versions to decide if update is needed existing_version = existing_tool.get("version", "0.0.0") new_version = tool_definition.version - + if self._is_newer_version(new_version, existing_version): # Update tool definition - arcade_definition = self._prepare_arcade_tool_definition(tool_definition) + arcade_definition = self._prepare_arcade_tool_definition( + tool_definition + ) # Note: Actual update implementation would depend on Arcade.dev API # For now, we'll delete and re-register await self._delete_tool(tool_name) - registration_result = await self._register_tool_with_arcade(arcade_definition) - + registration_result = await self._register_tool_with_arcade( + arcade_definition + ) + return { "tool_name": tool_name, "status": "success", "arcade_tool_id": registration_result.get("id"), "version": new_version, - "message": f"Tool updated from {existing_version} to {new_version}" + "message": f"Tool updated from {existing_version} to {new_version}", } else: return { @@ -400,43 +422,45 @@ async def _update_existing_tool(self, tool_name: str, tool_definition, existing_ "status": "skipped", "arcade_tool_id": existing_tool.get("id"), "version": existing_version, - "message": "Tool version is up to date" + "message": "Tool version is up to date", } - + except Exception as e: raise Exception(f"Failed to update existing tool: {e}") - + def _is_newer_version(self, new_version: str, existing_version: str) -> bool: """Compare version strings to determine if new version is newer.""" + def version_tuple(version: str) -> tuple: try: - return tuple(map(int, version.split('.'))) + return tuple(map(int, version.split("."))) except ValueError: return (0, 0, 0) - + return version_tuple(new_version) > version_tuple(existing_version) - - async def _setup_tool_permissions(self, tool_id: str, tool_name: str, tool_definition) -> None: + + async def _setup_tool_permissions( + self, tool_id: str, tool_name: str, tool_definition + ) -> None: """Set up authentication and permissions for a tool.""" try: # Configure tool permissions based on FACT requirements permissions_config = { "require_authentication": tool_definition.requires_auth, "allowed_scopes": ["tool:execute"], - "rate_limit": { - "calls_per_minute": 60, - "calls_per_hour": 1000 - }, - "audit_logging": True + "rate_limit": {"calls_per_minute": 60, "calls_per_hour": 1000}, + "audit_logging": True, } - + # Apply permissions via Arcade.dev API await self._update_tool_permissions(tool_id, permissions_config) - + except Exception as e: self.logger.warning(f"Failed to setup permissions for {tool_name}: {e}") - - async def _update_tool_permissions(self, tool_id: str, permissions: Dict[str, Any]) -> None: + + async def _update_tool_permissions( + self, tool_id: str, permissions: Dict[str, Any] + ) -> None: """Update tool permissions via Arcade.dev API.""" if self.config.demo_mode: self.logger.info(f"Demo: Updated permissions for tool {tool_id}") @@ -444,17 +468,21 @@ async def _update_tool_permissions(self, tool_id: str, permissions: Dict[str, An # Note: This would use actual Arcade.dev permissions API # Implementation depends on their specific API structure self.logger.debug(f"Updating permissions for tool {tool_id}: {permissions}") - + async def _verify_workspace_access(self) -> None: """Verify access to the specified workspace.""" try: # Verify workspace access # Implementation would depend on Arcade.dev workspace API - self.logger.info(f"Verified access to workspace: {self.config.workspace_id}") + self.logger.info( + f"Verified access to workspace: {self.config.workspace_id}" + ) except Exception as e: raise Exception(f"Workspace access verification failed: {e}") - async def _register_tool_with_arcade(self, tool_definition: Dict[str, Any]) -> Dict[str, Any]: + async def _register_tool_with_arcade( + self, tool_definition: Dict[str, Any] + ) -> Dict[str, Any]: """Register a tool with Arcade.dev platform.""" if self.config.demo_mode: # Generate mock registration response @@ -465,7 +493,7 @@ async def _register_tool_with_arcade(self, tool_definition: Dict[str, Any]) -> D "status": "registered", "version": tool_definition["version"], "created_at": "2025-05-25T19:33:00Z", - "_demo_mode": True + "_demo_mode": True, } else: # For real mode, would use actual Arcade.dev API @@ -475,7 +503,7 @@ async def _register_tool_with_arcade(self, tool_definition: Dict[str, Any]) -> D "id": f"real_tool_{tool_definition['name'].replace('.', '_').lower()}", "name": tool_definition["name"], "status": "registered", - "version": tool_definition["version"] + "version": tool_definition["version"], } async def _get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]: @@ -487,7 +515,7 @@ async def _get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]: "name": tool_name, "version": "1.0.0", "status": "active", - "_demo_mode": True + "_demo_mode": True, } else: # For real mode, would query actual API @@ -513,29 +541,28 @@ async def _delete_tool(self, tool_name: str) -> bool: "type": "string", "description": "Text content to process", "minLength": 1, - "maxLength": 10000 + "maxLength": 10000, }, "analysis_type": { "type": "string", "description": "Type of analysis to perform", "enum": ["sentiment", "keywords", "summary", "entities"], - "default": "summary" - } + "default": "summary", + }, }, requires_auth=True, - timeout_seconds=30 + timeout_seconds=30, ) -def process_text_content(content: str, analysis_type: str = "summary") -> Dict[str, Any]: +def process_text_content( + content: str, analysis_type: str = "summary" +) -> Dict[str, Any]: """Process text content with specified analysis type.""" # Mock implementation for example return { "content_length": len(content), "analysis_type": analysis_type, "result": f"Processed {len(content)} characters with {analysis_type} analysis", - "metadata": { - "processing_time_ms": 150, - "word_count": len(content.split()) - } + "metadata": {"processing_time_ms": 150, "word_count": len(content.split())}, } @@ -543,37 +570,36 @@ def process_text_content(content: str, analysis_type: str = "summary") -> Dict[s name="Data_Transform", description="Transform data between different formats and structures", parameters={ - "data": { - "type": "object", - "description": "Data to transform" - }, + "data": {"type": "object", "description": "Data to transform"}, "target_format": { "type": "string", "description": "Target format for transformation", - "enum": ["json", "csv", "xml", "yaml"] + "enum": ["json", "csv", "xml", "yaml"], }, "options": { "type": "object", "description": "Transformation options", "required": False, - "default": {} - } + "default": {}, + }, }, requires_auth=False, - timeout_seconds=60 + timeout_seconds=60, ) -def transform_data(data: Dict[str, Any], target_format: str, options: Dict[str, Any] = None) -> Dict[str, Any]: +def transform_data( + data: Dict[str, Any], target_format: str, options: Dict[str, Any] = None +) -> Dict[str, Any]: """Transform data to specified format.""" if options is None: options = {} - + # Mock implementation for example return { "original_format": "object", "target_format": target_format, "transformed": True, "options_applied": options, - "transformation_id": "tx_123456" + "transformation_id": "tx_123456", } @@ -586,70 +612,77 @@ async def main(): # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + print("šŸš€ FACT Tool Registration with Arcade.dev Example") print("=" * 55) - + # Example tools are automatically registered by @Tool decorator print("\nšŸ”§ Checking registered FACT tools...") registry = get_tool_registry() print(f"āœ… Found {len(registry.list_tools())} registered tools") - + # Load configuration from environment arcade_api_key = os.getenv("ARCADE_API_KEY", "") demo_mode = not arcade_api_key or not ARCADE_SDK_AVAILABLE - + if demo_mode: print("šŸŽ­ Running in DEMO MODE") print(" - No ARCADE_API_KEY found or arcadepy SDK not available") print(" - Will simulate tool registration with mock responses") print(" - To use real API: install arcadepy and set ARCADE_API_KEY") print() - + config = ToolRegistrationConfig( arcade_api_key=arcade_api_key or "demo_key", arcade_base_url=os.getenv("ARCADE_BASE_URL", "https://api.arcade.dev"), workspace_id=os.getenv("ARCADE_WORKSPACE_ID"), default_timeout=int(os.getenv("ARCADE_TIMEOUT", "30")), - requires_auth_by_default=os.getenv("ARCADE_REQUIRE_AUTH", "true").lower() == "true", + requires_auth_by_default=os.getenv("ARCADE_REQUIRE_AUTH", "true").lower() + == "true", demo_mode=demo_mode, - user_id=os.getenv("ARCADE_USER_ID", "demo@example.com") + user_id=os.getenv("ARCADE_USER_ID", "demo@example.com"), ) - + # Initialize tool registrar registrar = FactToolRegistrar(config) - + try: # Connect to Arcade.dev print("\nšŸ”Œ Connecting to Arcade.dev platform...") await registrar.connect() print("āœ… Connected successfully") - + # List existing tools print("\nšŸ“‹ Listing currently registered tools...") existing_tools = await registrar.list_registered_tools() print(f"āœ… Found {len(existing_tools)} existing tools") - + # Register all FACT tools print("\nšŸ”§ Registering FACT tools with Arcade.dev...") registration_summary = await registrar.register_all_tools() - + # Display results print("\nšŸ“Š Registration Summary:") print(f" Total tools: {registration_summary['total_tools']}") print(f" Successful: {registration_summary['successful']}") print(f" Failed: {registration_summary['failed']}") print(f" Skipped: {registration_summary['skipped']}") - + # Show detailed results - if registration_summary['results']: + if registration_summary["results"]: print("\nšŸ“‹ Detailed Results:") - for result in registration_summary['results']: - status_icon = "āœ…" if result['status'] == 'success' else "āŒ" if result['status'] == 'failed' else "ā­ļø" - print(f" {status_icon} {result['tool_name']}: {result.get('message', result['status'])}") - + for result in registration_summary["results"]: + status_icon = ( + "āœ…" + if result["status"] == "success" + else "āŒ" if result["status"] == "failed" else "ā­ļø" + ) + print( + f" {status_icon} {result['tool_name']}: {result.get('message', result['status'])}" + ) + # Demonstrate permission updates print("\nšŸ” Demonstrating permission updates...") success = await registrar.update_tool_permissions( @@ -657,30 +690,32 @@ async def main(): { "require_authentication": True, "allowed_scopes": ["tool:execute", "data:read"], - "rate_limit": {"calls_per_minute": 30} - } + "rate_limit": {"calls_per_minute": 30}, + }, ) print(f"āœ… Permission update {'successful' if success else 'failed'}") - + # Final tool list print("\nšŸ“‹ Final tool list:") final_tools = await registrar.list_registered_tools() for tool in final_tools: auth_status = "šŸ”" if tool.get("requires_auth") else "šŸ”“" - print(f" {auth_status} {tool.get('name')} (v{tool.get('version', 'unknown')})") - + print( + f" {auth_status} {tool.get('name')} (v{tool.get('version', 'unknown')})" + ) + except Exception as e: print(f"āŒ Error during tool registration: {e}") return 1 - + finally: # Clean up await registrar.close() - + print("\nšŸŽ‰ Tool registration example completed successfully!") return 0 if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/03_intelligent_routing/hybrid_execution.py b/examples/arcade-dev/03_intelligent_routing/hybrid_execution.py index 693c4b7..4f5e087 100644 --- a/examples/arcade-dev/03_intelligent_routing/hybrid_execution.py +++ b/examples/arcade-dev/03_intelligent_routing/hybrid_execution.py @@ -34,18 +34,22 @@ # Create alias for compatibility ArcadeClient = BasicArcadeClient + # Define classes that might not exist in the actual FACT implementation @dataclass class ToolCall: """Tool call data structure.""" + id: str name: str arguments: Dict[str, Any] user_id: Optional[str] = None + @dataclass class ToolResult: """Tool execution result.""" + call_id: str tool_name: str success: bool @@ -55,17 +59,18 @@ class ToolResult: status_code: int = 200 metadata: Optional[Dict[str, Any]] = None + # Mock ToolExecutor for this example class MockToolExecutor: """Mock tool executor for local execution.""" - + def __init__(self): self.registry = {} - + async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult: """Execute a tool call locally.""" start_time = time.time() - + # Check if we have a registered function for this tool tool_func = self.registry.get(tool_call.name) if not tool_func: @@ -75,14 +80,14 @@ async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult: success=False, error=f"Tool '{tool_call.name}' not found", execution_time_ms=(time.time() - start_time) * 1000, - status_code=404 + status_code=404, ) - + try: # Execute the tool function result = tool_func(**tool_call.arguments) execution_time = (time.time() - start_time) * 1000 - + return ToolResult( call_id=tool_call.id, tool_name=tool_call.name, @@ -90,9 +95,9 @@ async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult: data=result, execution_time_ms=execution_time, status_code=200, - metadata={"execution_mode": "local"} + metadata={"execution_mode": "local"}, ) - + except Exception as e: execution_time = (time.time() - start_time) * 1000 return ToolResult( @@ -101,19 +106,21 @@ async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult: success=False, error=str(e), execution_time_ms=execution_time, - status_code=500 + status_code=500, ) - + def register_tool(self, name: str, func): """Register a tool function.""" self.registry[name] = func + # Replace ToolExecutor with our mock ToolExecutor = MockToolExecutor class ExecutionMode(Enum): """Execution mode enumeration.""" + LOCAL = "local" REMOTE = "remote" HYBRID = "hybrid" @@ -123,6 +130,7 @@ class ExecutionMode(Enum): @dataclass class RoutingRule: """Rule for routing tool execution.""" + tool_pattern: str # Tool name pattern (supports wildcards) preferred_mode: ExecutionMode conditions: Dict[str, Any] # Conditions for applying this rule @@ -132,13 +140,14 @@ class RoutingRule: @dataclass class ExecutionMetrics: """Metrics for execution performance tracking.""" + mode: ExecutionMode execution_time_ms: float success: bool error_type: Optional[str] = None tool_name: str = "" timestamp: float = 0.0 - + def __post_init__(self): if self.timestamp == 0.0: self.timestamp = time.time() @@ -147,17 +156,19 @@ def __post_init__(self): class IntelligentRouter: """ Intelligent router for deciding between local and remote tool execution. - + Makes routing decisions based on tool characteristics, performance history, network conditions, and configurable rules. """ - - def __init__(self, - arcade_client: Optional[ArcadeClient] = None, - cache_manager: Optional[CacheManager] = None): + + def __init__( + self, + arcade_client: Optional[ArcadeClient] = None, + cache_manager: Optional[CacheManager] = None, + ): """ Initialize intelligent router. - + Args: arcade_client: Arcade.dev client for remote execution cache_manager: Cache manager for performance data @@ -167,52 +178,56 @@ def __init__(self, self.local_executor = ToolExecutor() self.metrics_collector = MetricsCollector() self.logger = logging.getLogger(__name__) - + # Routing configuration self.routing_rules: List[RoutingRule] = [] self.execution_history: List[ExecutionMetrics] = [] self.performance_cache: Dict[str, Dict[str, float]] = {} - + # Performance thresholds self.local_timeout_threshold = 5.0 # seconds self.remote_timeout_threshold = 30.0 # seconds self.network_latency_threshold = 2.0 # seconds - + # Initialize default routing rules self._setup_default_routing_rules() - + def add_routing_rule(self, rule: RoutingRule) -> None: """ Add a routing rule to the decision engine. - + Args: rule: Routing rule to add """ self.routing_rules.append(rule) # Sort by priority (highest first) self.routing_rules.sort(key=lambda r: r.priority, reverse=True) - - self.logger.info(f"Added routing rule for pattern '{rule.tool_pattern}' " - f"with mode {rule.preferred_mode.value}") - + + self.logger.info( + f"Added routing rule for pattern '{rule.tool_pattern}' " + f"with mode {rule.preferred_mode.value}" + ) + async def execute_tool(self, tool_call: ToolCall) -> ToolResult: """ Execute tool with intelligent routing decision. - + Args: tool_call: Tool call to execute - + Returns: Tool execution result """ start_time = time.time() - + try: # Determine optimal execution mode execution_mode = await self._determine_execution_mode(tool_call) - - self.logger.info(f"Executing tool '{tool_call.name}' using {execution_mode.value} mode") - + + self.logger.info( + f"Executing tool '{tool_call.name}' using {execution_mode.value} mode" + ) + # Execute based on determined mode if execution_mode == ExecutionMode.LOCAL: result = await self._execute_locally(tool_call) @@ -222,34 +237,38 @@ async def execute_tool(self, tool_call: ToolCall) -> ToolResult: result = await self._execute_hybrid(tool_call) else: # AUTO mode result = await self._execute_auto(tool_call) - + # Record successful execution metrics execution_time = (time.time() - start_time) * 1000 metrics = ExecutionMetrics( mode=execution_mode, execution_time_ms=execution_time, success=result.success, - tool_name=tool_call.name + tool_name=tool_call.name, ) - + await self._record_execution_metrics(metrics) - + return result - + except Exception as e: execution_time = (time.time() - start_time) * 1000 - + # Record failed execution metrics metrics = ExecutionMetrics( - mode=execution_mode if 'execution_mode' in locals() else ExecutionMode.LOCAL, + mode=( + execution_mode + if "execution_mode" in locals() + else ExecutionMode.LOCAL + ), execution_time_ms=execution_time, success=False, error_type=type(e).__name__, - tool_name=tool_call.name + tool_name=tool_call.name, ) - + await self._record_execution_metrics(metrics) - + # Create error result return ToolResult( call_id=tool_call.id, @@ -257,16 +276,16 @@ async def execute_tool(self, tool_call: ToolCall) -> ToolResult: success=False, error=str(e), execution_time_ms=execution_time, - status_code=500 + status_code=500, ) - + async def _determine_execution_mode(self, tool_call: ToolCall) -> ExecutionMode: """ Determine the optimal execution mode for a tool call. - + Args: tool_call: Tool call to analyze - + Returns: Optimal execution mode """ @@ -274,27 +293,31 @@ async def _determine_execution_mode(self, tool_call: ToolCall) -> ExecutionMode: for rule in self.routing_rules: if self._matches_rule(tool_call.name, rule): if await self._evaluate_rule_conditions(tool_call, rule): - self.logger.debug(f"Tool '{tool_call.name}' matched rule: {rule.preferred_mode.value}") + self.logger.debug( + f"Tool '{tool_call.name}' matched rule: {rule.preferred_mode.value}" + ) return rule.preferred_mode - + # Performance-based decision performance_score = await self._calculate_performance_score(tool_call.name) - + # Network condition check network_latency = await self._check_network_latency() - + # Tool complexity analysis complexity_score = self._analyze_tool_complexity(tool_call) - + # Make decision based on weighted factors decision_score = ( - performance_score * 0.4 + - (1.0 - min(network_latency / self.network_latency_threshold, 1.0)) * 0.3 + - complexity_score * 0.3 + performance_score * 0.4 + + (1.0 - min(network_latency / self.network_latency_threshold, 1.0)) * 0.3 + + complexity_score * 0.3 ) - - self.logger.debug(f"Decision score for '{tool_call.name}': {decision_score:.3f}") - + + self.logger.debug( + f"Decision score for '{tool_call.name}': {decision_score:.3f}" + ) + # Decision thresholds if decision_score > 0.7: return ExecutionMode.LOCAL @@ -302,7 +325,7 @@ async def _determine_execution_mode(self, tool_call: ToolCall) -> ExecutionMode: return ExecutionMode.REMOTE else: return ExecutionMode.HYBRID - + async def _execute_locally(self, tool_call: ToolCall) -> ToolResult: """Execute tool locally using FACT executor.""" try: @@ -314,23 +337,24 @@ async def _execute_locally(self, tool_call: ToolCall) -> ToolResult: self.logger.info("Falling back to remote execution") return await self._execute_remotely(tool_call) raise - + async def _execute_remotely(self, tool_call: ToolCall) -> ToolResult: """Execute tool remotely via Arcade.dev.""" if not self.arcade_client: - raise RuntimeError("Remote execution not available: Arcade client not configured") - + raise RuntimeError( + "Remote execution not available: Arcade client not configured" + ) + try: start_time = time.time() - + # Execute via Arcade.dev arcade_result = await self.arcade_client.execute_tool( - tool_name=tool_call.name, - tool_input=tool_call.arguments + tool_name=tool_call.name, tool_input=tool_call.arguments ) - + execution_time = (time.time() - start_time) * 1000 - + # Convert Arcade result to ToolResult format success = arcade_result.get("status") != "failed" return ToolResult( @@ -344,35 +368,34 @@ async def _execute_remotely(self, tool_call: ToolCall) -> ToolResult: metadata={ "execution_mode": "remote", "arcade_execution_time": arcade_result.get("execution_time_ms"), - "user_id": tool_call.user_id - } + "user_id": tool_call.user_id, + }, ) - + except Exception as e: self.logger.warning(f"Remote execution failed for '{tool_call.name}': {e}") # Fallback to local if possible self.logger.info("Falling back to local execution") return await self._execute_locally(tool_call) - + async def _execute_hybrid(self, tool_call: ToolCall) -> ToolResult: """ Execute tool using hybrid approach (race between local and remote). - + Returns the first successful result, cancels the other. """ self.logger.info(f"Executing '{tool_call.name}' in hybrid mode") - + # Create tasks for both execution modes local_task = asyncio.create_task(self._execute_locally(tool_call)) remote_task = asyncio.create_task(self._execute_remotely(tool_call)) - + try: # Wait for first completion done, pending = await asyncio.wait( - [local_task, remote_task], - return_when=asyncio.FIRST_COMPLETED + [local_task, remote_task], return_when=asyncio.FIRST_COMPLETED ) - + # Cancel pending tasks for task in pending: task.cancel() @@ -380,24 +403,24 @@ async def _execute_hybrid(self, tool_call: ToolCall) -> ToolResult: await task except asyncio.CancelledError: pass - + # Get result from completed task completed_task = done.pop() result = await completed_task - + # Determine which mode completed first winning_mode = "local" if completed_task == local_task else "remote" - + # Add hybrid execution metadata if result.metadata is None: result.metadata = {} result.metadata["execution_mode"] = "hybrid" result.metadata["winning_mode"] = winning_mode - + self.logger.info(f"Hybrid execution completed: {winning_mode} mode won") - + return result - + except Exception as e: # Cancel any remaining tasks for task in [local_task, remote_task]: @@ -407,21 +430,21 @@ async def _execute_hybrid(self, tool_call: ToolCall) -> ToolResult: await task except asyncio.CancelledError: pass - + raise RuntimeError(f"Hybrid execution failed: {e}") - + async def _execute_auto(self, tool_call: ToolCall) -> ToolResult: """ Execute tool in auto mode with intelligent retry and fallback. """ # Start with performance-based decision performance_data = await self._get_performance_data(tool_call.name) - + if performance_data: # Use historical performance to decide - local_avg = performance_data.get("local_avg_ms", float('inf')) - remote_avg = performance_data.get("remote_avg_ms", float('inf')) - + local_avg = performance_data.get("local_avg_ms", float("inf")) + remote_avg = performance_data.get("remote_avg_ms", float("inf")) + if local_avg < remote_avg * 1.5: # Local is significantly faster primary_mode = ExecutionMode.LOCAL fallback_mode = ExecutionMode.REMOTE @@ -432,7 +455,7 @@ async def _execute_auto(self, tool_call: ToolCall) -> ToolResult: # No historical data, default to local first primary_mode = ExecutionMode.LOCAL fallback_mode = ExecutionMode.REMOTE - + # Try primary mode first try: if primary_mode == ExecutionMode.LOCAL: @@ -441,7 +464,7 @@ async def _execute_auto(self, tool_call: ToolCall) -> ToolResult: return await self._execute_remotely(tool_call) except Exception as e: self.logger.warning(f"Primary mode ({primary_mode.value}) failed: {e}") - + # Try fallback mode try: if fallback_mode == ExecutionMode.LOCAL: @@ -449,18 +472,25 @@ async def _execute_auto(self, tool_call: ToolCall) -> ToolResult: else: return await self._execute_remotely(tool_call) except Exception as fallback_error: - self.logger.error(f"Fallback mode ({fallback_mode.value}) also failed: {fallback_error}") - raise RuntimeError(f"Both execution modes failed: {e}, {fallback_error}") - + self.logger.error( + f"Fallback mode ({fallback_mode.value}) also failed: {fallback_error}" + ) + raise RuntimeError( + f"Both execution modes failed: {e}, {fallback_error}" + ) + def _matches_rule(self, tool_name: str, rule: RoutingRule) -> bool: """Check if tool name matches routing rule pattern.""" import fnmatch + return fnmatch.fnmatch(tool_name, rule.tool_pattern) - - async def _evaluate_rule_conditions(self, tool_call: ToolCall, rule: RoutingRule) -> bool: + + async def _evaluate_rule_conditions( + self, tool_call: ToolCall, rule: RoutingRule + ) -> bool: """Evaluate whether rule conditions are met.""" conditions = rule.conditions - + # Check user-based conditions if "user_id" in conditions: allowed_users = conditions["user_id"] @@ -469,86 +499,87 @@ async def _evaluate_rule_conditions(self, tool_call: ToolCall, rule: RoutingRule return False elif tool_call.user_id != allowed_users: return False - + # Check argument-based conditions if "argument_size" in conditions: max_size = conditions["argument_size"] args_size = len(str(tool_call.arguments)) if args_size > max_size: return False - + # Check time-based conditions if "time_window" in conditions: import datetime + time_window = conditions["time_window"] current_hour = datetime.datetime.now().hour - + if isinstance(time_window, dict): start_hour = time_window.get("start", 0) end_hour = time_window.get("end", 23) if not (start_hour <= current_hour <= end_hour): return False - + return True - + async def _calculate_performance_score(self, tool_name: str) -> float: """Calculate performance score for a tool (0.0 to 1.0, higher is better for local).""" if tool_name in self.performance_cache: data = self.performance_cache[tool_name] local_avg = data.get("local_avg_ms", 1000) remote_avg = data.get("remote_avg_ms", 2000) - + # Score based on relative performance if remote_avg > 0: return min(remote_avg / (local_avg + remote_avg), 1.0) - + # Default score if no performance data return 0.5 - + async def _check_network_latency(self) -> float: """Check current network latency to Arcade.dev.""" if not self.arcade_client: - return float('inf') - + return float("inf") + try: start_time = time.time() # Simple ping-like check - use health check instead await self.arcade_client.health_check() return time.time() - start_time except: - return float('inf') - + return float("inf") + def _analyze_tool_complexity(self, tool_call: ToolCall) -> float: """Analyze tool complexity to inform routing decision.""" complexity_score = 0.0 - + # Argument complexity args_str = str(tool_call.arguments) if len(args_str) > 1000: complexity_score += 0.3 elif len(args_str) > 100: complexity_score += 0.1 - + # Tool name patterns indicating complexity complex_patterns = ["AI.", "ML.", "Analysis.", "Transform."] simple_patterns = ["Util.", "Helper.", "Cache."] - + for pattern in complex_patterns: if tool_call.name.startswith(pattern): complexity_score += 0.4 break - + for pattern in simple_patterns: if tool_call.name.startswith(pattern): complexity_score -= 0.2 break - + return max(0.0, min(1.0, complexity_score)) - + async def _record_execution_metrics(self, metrics: ExecutionMetrics) -> None: """Record execution metrics for future routing decisions.""" self.execution_history.append(metrics) - + # Update performance cache tool_name = metrics.tool_name if tool_name not in self.performance_cache: @@ -556,47 +587,53 @@ async def _record_execution_metrics(self, metrics: ExecutionMetrics) -> None: "local_avg_ms": 0.0, "remote_avg_ms": 0.0, "local_count": 0, - "remote_count": 0 + "remote_count": 0, } - + cache_data = self.performance_cache[tool_name] - + if metrics.mode == ExecutionMode.LOCAL and metrics.success: count = cache_data["local_count"] avg = cache_data["local_avg_ms"] - cache_data["local_avg_ms"] = (avg * count + metrics.execution_time_ms) / (count + 1) + cache_data["local_avg_ms"] = (avg * count + metrics.execution_time_ms) / ( + count + 1 + ) cache_data["local_count"] = count + 1 elif metrics.mode == ExecutionMode.REMOTE and metrics.success: count = cache_data["remote_count"] avg = cache_data["remote_avg_ms"] - cache_data["remote_avg_ms"] = (avg * count + metrics.execution_time_ms) / (count + 1) + cache_data["remote_avg_ms"] = (avg * count + metrics.execution_time_ms) / ( + count + 1 + ) cache_data["remote_count"] = count + 1 - + # Record with metrics collector self.metrics_collector.record_tool_execution( tool_name=metrics.tool_name, success=metrics.success, execution_time=metrics.execution_time_ms, - metadata={"execution_mode": metrics.mode.value} + metadata={"execution_mode": metrics.mode.value}, ) - + # Cache performance data if cache manager available if self.cache_manager: cache_key = f"routing:performance:{tool_name}" # CacheManager uses store method, not set try: import json + self.cache_manager.store(cache_key, json.dumps(cache_data)) except Exception as e: import structlog + logger = structlog.get_logger(__name__) logger.warning(f"Failed to cache performance data: {e}") - + async def _get_performance_data(self, tool_name: str) -> Optional[Dict[str, float]]: """Get cached performance data for a tool.""" if tool_name in self.performance_cache: return self.performance_cache[tool_name] - + if self.cache_manager: cache_key = f"routing:performance:{tool_name}" try: @@ -604,50 +641,60 @@ async def _get_performance_data(self, tool_name: str) -> Optional[Dict[str, floa cached_entry = self.cache_manager.retrieve(cache_key) if cached_entry: import json + cached_data = json.loads(cached_entry.content) self.performance_cache[tool_name] = cached_data return cached_data except Exception as e: import structlog + logger = structlog.get_logger(__name__) logger.warning(f"Failed to retrieve cached performance data: {e}") - + return None - + def _setup_default_routing_rules(self) -> None: """Set up default routing rules.""" # High-performance tools should run locally - self.add_routing_rule(RoutingRule( - tool_pattern="Cache.*", - preferred_mode=ExecutionMode.LOCAL, - conditions={"argument_size": 1000}, # Small arguments only - priority=100 - )) - + self.add_routing_rule( + RoutingRule( + tool_pattern="Cache.*", + preferred_mode=ExecutionMode.LOCAL, + conditions={"argument_size": 1000}, # Small arguments only + priority=100, + ) + ) + # AI/ML tools prefer remote execution (more resources) - self.add_routing_rule(RoutingRule( - tool_pattern="AI.*", - preferred_mode=ExecutionMode.REMOTE, - conditions={}, - priority=90 - )) - + self.add_routing_rule( + RoutingRule( + tool_pattern="AI.*", + preferred_mode=ExecutionMode.REMOTE, + conditions={}, + priority=90, + ) + ) + # Analysis tools use hybrid approach - self.add_routing_rule(RoutingRule( - tool_pattern="Analysis.*", - preferred_mode=ExecutionMode.HYBRID, - conditions={}, - priority=80 - )) - + self.add_routing_rule( + RoutingRule( + tool_pattern="Analysis.*", + preferred_mode=ExecutionMode.HYBRID, + conditions={}, + priority=80, + ) + ) + # Utility tools run locally - self.add_routing_rule(RoutingRule( - tool_pattern="Util.*", - preferred_mode=ExecutionMode.LOCAL, - conditions={}, - priority=70 - )) - + self.add_routing_rule( + RoutingRule( + tool_pattern="Util.*", + preferred_mode=ExecutionMode.LOCAL, + conditions={}, + priority=70, + ) + ) + def get_performance_summary(self) -> Dict[str, Any]: """Get performance summary for all tools.""" summary = { @@ -655,27 +702,27 @@ def get_performance_summary(self) -> Dict[str, Any]: "tools_tracked": len(self.performance_cache), "mode_distribution": {}, "average_execution_times": {}, - "success_rates": {} + "success_rates": {}, } - + # Calculate mode distribution mode_counts = {} for metrics in self.execution_history: mode = metrics.mode.value mode_counts[mode] = mode_counts.get(mode, 0) + 1 - + total_executions = len(self.execution_history) if total_executions > 0: for mode, count in mode_counts.items(): summary["mode_distribution"][mode] = count / total_executions - + # Calculate average execution times and success rates for tool_name, cache_data in self.performance_cache.items(): summary["average_execution_times"][tool_name] = { "local_ms": cache_data.get("local_avg_ms", 0), - "remote_ms": cache_data.get("remote_avg_ms", 0) + "remote_ms": cache_data.get("remote_avg_ms", 0), } - + return summary @@ -683,10 +730,8 @@ def get_performance_summary(self) -> Dict[str, Any]: @Tool( name="Cache_FastLookup", description="Fast cache lookup operation", - parameters={ - "key": {"type": "string", "description": "Cache key to lookup"} - }, - timeout_seconds=5 + parameters={"key": {"type": "string", "description": "Cache key to lookup"}}, + timeout_seconds=5, ) def fast_cache_lookup(key: str) -> Dict[str, Any]: """Fast local cache operation.""" @@ -699,9 +744,9 @@ def fast_cache_lookup(key: str) -> Dict[str, Any]: description="Complex AI analysis requiring significant resources", parameters={ "data": {"type": "object", "description": "Data to analyze"}, - "model_type": {"type": "string", "description": "AI model to use"} + "model_type": {"type": "string", "description": "AI model to use"}, }, - timeout_seconds=60 + timeout_seconds=60, ) def complex_ai_analysis(data: Dict[str, Any], model_type: str) -> Dict[str, Any]: """Complex AI analysis operation.""" @@ -709,24 +754,22 @@ def complex_ai_analysis(data: Dict[str, Any], model_type: str) -> Dict[str, Any] return { "analysis_result": f"AI analysis with {model_type}", "processed_items": len(data.get("items", [])), - "confidence": 0.95 + "confidence": 0.95, } @Tool( name="Util_StringHelper", description="Simple string utility function", - parameters={ - "text": {"type": "string", "description": "Text to process"} - }, - timeout_seconds=10 + parameters={"text": {"type": "string", "description": "Text to process"}}, + timeout_seconds=10, ) def string_helper(text: str) -> Dict[str, Any]: """Simple string processing utility.""" return { "length": len(text), "word_count": len(text.split()), - "uppercase": text.upper() + "uppercase": text.upper(), } @@ -735,12 +778,12 @@ async def main(): # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + print("🧠 Intelligent Routing and Hybrid Execution Example") print("=" * 53) - + # Initialize components cache_config = { "prefix": "intelligent_routing", @@ -748,10 +791,10 @@ async def main(): "max_size": "10MB", "ttl_seconds": 3600, "hit_target_ms": 48, - "miss_target_ms": 140 + "miss_target_ms": 140, } cache_manager = CacheManager(cache_config) - + # Initialize Arcade client if configured arcade_client = None arcade_api_key = os.getenv("ARCADE_API_KEY") @@ -759,7 +802,7 @@ async def main(): config = ArcadeConfig( api_key=arcade_api_key, user_id=os.getenv("ARCADE_USER_ID", "demo@example.com"), - demo_mode=False + demo_mode=False, ) arcade_client = ArcadeClient(config, cache_manager) try: @@ -773,50 +816,52 @@ async def main(): print("ā„¹ļø No Arcade API key configured, using demo mode") # Create a demo mode client for testing config = ArcadeConfig( - api_key="demo_key", - user_id="demo@example.com", - demo_mode=True + api_key="demo_key", user_id="demo@example.com", demo_mode=True ) arcade_client = ArcadeClient(config, cache_manager) await arcade_client.connect() print("āœ… Using demo mode for testing intelligent routing") - + # Initialize intelligent router router = IntelligentRouter(arcade_client, cache_manager) - + # Register tools with the local executor router.local_executor.register_tool("Cache_FastLookup", fast_cache_lookup) router.local_executor.register_tool("AI_ComplexAnalysis", complex_ai_analysis) router.local_executor.register_tool("Util_StringHelper", string_helper) - + # Add custom routing rules print("\nšŸ“‹ Setting up custom routing rules...") - - router.add_routing_rule(RoutingRule( - tool_pattern="Cache.*", - preferred_mode=ExecutionMode.LOCAL, - conditions={"argument_size": 500}, - priority=100 - )) - - router.add_routing_rule(RoutingRule( - tool_pattern="AI.*", - preferred_mode=ExecutionMode.REMOTE if arcade_client else ExecutionMode.LOCAL, - conditions={}, - priority=90 - )) - + + router.add_routing_rule( + RoutingRule( + tool_pattern="Cache.*", + preferred_mode=ExecutionMode.LOCAL, + conditions={"argument_size": 500}, + priority=100, + ) + ) + + router.add_routing_rule( + RoutingRule( + tool_pattern="AI.*", + preferred_mode=( + ExecutionMode.REMOTE if arcade_client else ExecutionMode.LOCAL + ), + conditions={}, + priority=90, + ) + ) + print("āœ… Custom routing rules configured") - + # Test different execution scenarios test_cases = [ { "name": "Fast cache lookup", "tool_call": ToolCall( - id="test_1", - name="Cache_FastLookup", - arguments={"key": "user_123"} - ) + id="test_1", name="Cache_FastLookup", arguments={"key": "user_123"} + ), }, { "name": "Complex AI analysis", @@ -825,38 +870,42 @@ async def main(): name="AI_ComplexAnalysis", arguments={ "data": {"items": list(range(100))}, - "model_type": "neural_network" - } - ) + "model_type": "neural_network", + }, + ), }, { "name": "Simple string utility", "tool_call": ToolCall( id="test_3", name="Util_StringHelper", - arguments={"text": "Hello, intelligent routing world!"} - ) - } + arguments={"text": "Hello, intelligent routing world!"}, + ), + }, ] - + print("\nšŸš€ Executing test cases with intelligent routing...") - + for i, test_case in enumerate(test_cases, 1): print(f"\n{i}. {test_case['name']}") - print("-" * (len(test_case['name']) + 3)) - + print("-" * (len(test_case["name"]) + 3)) + try: start_time = time.time() - result = await router.execute_tool(test_case['tool_call']) + result = await router.execute_tool(test_case["tool_call"]) total_time = (time.time() - start_time) * 1000 - + if result.success: execution_mode = result.metadata.get("execution_mode", "unknown") winning_mode = result.metadata.get("winning_mode", "") - - print(f"āœ… Success ({execution_mode} mode{f', {winning_mode} won' if winning_mode else ''})") - print(f" Execution time: {result.execution_time_ms:.1f}ms (total: {total_time:.1f}ms)") - + + print( + f"āœ… Success ({execution_mode} mode{f', {winning_mode} won' if winning_mode else ''})" + ) + print( + f" Execution time: {result.execution_time_ms:.1f}ms (total: {total_time:.1f}ms)" + ) + if result.data: data_preview = str(result.data)[:100] if len(str(result.data)) > 100: @@ -864,42 +913,42 @@ async def main(): print(f" Result: {data_preview}") else: print(f"āŒ Failed: {result.error}") - + except Exception as e: print(f"āŒ Error: {e}") - + # Performance summary print("\nšŸ“Š Performance Summary") print("-" * 20) - + summary = router.get_performance_summary() print(f"Total executions: {summary['total_executions']}") print(f"Tools tracked: {summary['tools_tracked']}") - - if summary['mode_distribution']: + + if summary["mode_distribution"]: print("\nExecution mode distribution:") - for mode, percentage in summary['mode_distribution'].items(): + for mode, percentage in summary["mode_distribution"].items(): print(f" {mode}: {percentage:.1%}") - - if summary['average_execution_times']: + + if summary["average_execution_times"]: print("\nAverage execution times by tool:") - for tool_name, times in summary['average_execution_times'].items(): - local_time = times['local_ms'] - remote_time = times['remote_ms'] + for tool_name, times in summary["average_execution_times"].items(): + local_time = times["local_ms"] + remote_time = times["remote_ms"] print(f" {tool_name}:") if local_time > 0: print(f" Local: {local_time:.1f}ms") if remote_time > 0: print(f" Remote: {remote_time:.1f}ms") - + # Cleanup if arcade_client: await arcade_client.disconnect() - + print("\nšŸŽ‰ Intelligent routing example completed successfully!") return 0 if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/04_error_handling/resilient_execution.py b/examples/arcade-dev/04_error_handling/resilient_execution.py index ec3eec8..09d26ec 100644 --- a/examples/arcade-dev/04_error_handling/resilient_execution.py +++ b/examples/arcade-dev/04_error_handling/resilient_execution.py @@ -33,83 +33,104 @@ # Mock ToolExecutor for demo mode class ToolExecutor: """Mock ToolExecutor for demo purposes.""" + def __init__(self): self.logger = logging.getLogger(__name__) - - async def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + + async def execute_tool( + self, tool_name: str, arguments: Dict[str, Any] + ) -> Dict[str, Any]: """Mock tool execution.""" return {"status": "mock_execution", "tool": tool_name, "args": arguments} + try: from src.tools.decorators import Tool except ImportError: + def Tool(*args, **kwargs): """Mock Tool decorator for demo mode.""" + def decorator(func): - func._tool_name = kwargs.get('name', func.__name__) - func._tool_description = kwargs.get('description', '') - func._tool_parameters = kwargs.get('parameters', {}) + func._tool_name = kwargs.get("name", func.__name__) + func._tool_description = kwargs.get("description", "") + func._tool_parameters = kwargs.get("parameters", {}) return func + return decorator + try: from src.core.errors import * except ImportError: # Define error classes if they don't exist class ToolExecutionError(Exception): """Tool execution error.""" + pass class ToolValidationError(Exception): """Tool validation error.""" + pass class ToolNotFoundError(Exception): """Tool not found error.""" + pass class UnauthorizedError(Exception): """Unauthorized error.""" + pass class SecurityError(Exception): """Security error.""" + pass + try: from src.monitoring.metrics import MetricsCollector except ImportError: + class MetricsCollector: """Mock metrics collector for demo mode.""" + def __init__(self): self.metrics = {} - + def record(self, name, value, tags=None): """Record a metric.""" pass - + def increment(self, name, tags=None): """Increment a counter.""" pass + from src.cache.manager import CacheManager # Import the BasicArcadeClient from the basic integration example sys.path.insert(0, str(Path(__file__).parent.parent / "01_basic_integration")) from basic_arcade_client import BasicArcadeClient, ArcadeConfig + # Define classes for error handling demonstration @dataclass class ToolCall: """Tool call data structure.""" + id: str name: str arguments: Dict[str, Any] user_id: Optional[str] = None + @dataclass class ToolResult: """Tool execution result.""" + call_id: str tool_name: str success: bool @@ -119,13 +140,16 @@ class ToolResult: status_code: int = 200 metadata: Optional[Dict[str, Any]] = None + class FinalRetryError(Exception): """Final retry error when all retries are exhausted.""" + pass class ErrorType(Enum): """Classification of error types for different handling strategies.""" + NETWORK_ERROR = "network_error" AUTHENTICATION_ERROR = "authentication_error" VALIDATION_ERROR = "validation_error" @@ -138,6 +162,7 @@ class ErrorType(Enum): class RetryStrategy(Enum): """Retry strategy enumeration.""" + NO_RETRY = "no_retry" LINEAR_BACKOFF = "linear_backoff" EXPONENTIAL_BACKOFF = "exponential_backoff" @@ -147,6 +172,7 @@ class RetryStrategy(Enum): @dataclass class ErrorContext: """Context information for error handling decisions.""" + error: Exception error_type: ErrorType tool_name: str @@ -159,6 +185,7 @@ class ErrorContext: @dataclass class RetryConfig: """Configuration for retry behavior.""" + strategy: RetryStrategy max_attempts: int = 3 base_delay: float = 1.0 @@ -171,6 +198,7 @@ class RetryConfig: @dataclass class CircuitBreakerConfig: """Configuration for circuit breaker pattern.""" + failure_threshold: int = 5 success_threshold: int = 3 timeout_seconds: float = 60.0 @@ -179,6 +207,7 @@ class CircuitBreakerConfig: class CircuitState(Enum): """Circuit breaker states.""" + CLOSED = "closed" OPEN = "open" HALF_OPEN = "half_open" @@ -187,11 +216,11 @@ class CircuitState(Enum): class CircuitBreaker: """ Circuit breaker implementation for preventing cascading failures. - + Tracks failure rates and opens the circuit when threshold is exceeded, allowing the system to fail fast and recover gracefully. """ - + def __init__(self, config: CircuitBreakerConfig, name: str = "default"): """Initialize circuit breaker.""" self.config = config @@ -202,24 +231,30 @@ def __init__(self, config: CircuitBreakerConfig, name: str = "default"): self.last_failure_time = 0.0 self.failure_history: List[float] = [] self.logger = logging.getLogger(f"{__name__}.{name}") - + async def call(self, func: Callable, *args, **kwargs) -> Any: """Execute function through circuit breaker protection.""" if self.state == CircuitState.OPEN: if self._should_attempt_reset(): self.state = CircuitState.HALF_OPEN - self.logger.info(f"Circuit breaker '{self.name}' entering half-open state") + self.logger.info( + f"Circuit breaker '{self.name}' entering half-open state" + ) else: raise FinalRetryError(f"Circuit breaker '{self.name}' is open") - + try: - result = await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs) + result = ( + await func(*args, **kwargs) + if asyncio.iscoroutinefunction(func) + else func(*args, **kwargs) + ) await self._on_success() return result except Exception as e: await self._on_failure() raise - + async def _on_success(self) -> None: """Handle successful execution.""" if self.state == CircuitState.HALF_OPEN: @@ -233,28 +268,32 @@ async def _on_success(self) -> None: elif self.state == CircuitState.CLOSED: # Reset failure count on success self.failure_count = max(0, self.failure_count - 1) - + async def _on_failure(self) -> None: """Handle failed execution.""" current_time = time.time() self.failure_count += 1 self.last_failure_time = current_time self.failure_history.append(current_time) - + # Clean old failures outside monitoring window window_start = current_time - self.config.monitoring_window self.failure_history = [t for t in self.failure_history if t > window_start] - + if self.state == CircuitState.CLOSED: if len(self.failure_history) >= self.config.failure_threshold: self.state = CircuitState.OPEN self.success_count = 0 - self.logger.warning(f"Circuit breaker '{self.name}' opened due to failure threshold") + self.logger.warning( + f"Circuit breaker '{self.name}' opened due to failure threshold" + ) elif self.state == CircuitState.HALF_OPEN: self.state = CircuitState.OPEN self.success_count = 0 - self.logger.warning(f"Circuit breaker '{self.name}' reopened after half-open failure") - + self.logger.warning( + f"Circuit breaker '{self.name}' reopened after half-open failure" + ) + def _should_attempt_reset(self) -> bool: """Check if circuit should attempt reset.""" return time.time() - self.last_failure_time >= self.config.timeout_seconds @@ -263,16 +302,18 @@ def _should_attempt_reset(self) -> bool: class ResilientExecutor: """ Resilient tool executor with comprehensive error handling. - + Provides intelligent retry strategies, circuit breaking, graceful degradation, and detailed error classification for production workloads. """ - - def __init__(self, - arcade_client: Optional[BasicArcadeClient] = None, - cache_manager: Optional[CacheManager] = None, - retry_config: Optional[RetryConfig] = None, - circuit_config: Optional[CircuitBreakerConfig] = None): + + def __init__( + self, + arcade_client: Optional[BasicArcadeClient] = None, + cache_manager: Optional[CacheManager] = None, + retry_config: Optional[RetryConfig] = None, + circuit_config: Optional[CircuitBreakerConfig] = None, + ): """Initialize resilient executor.""" self.logger = logging.getLogger(__name__) self.arcade_client = arcade_client @@ -280,40 +321,42 @@ def __init__(self, try: self.local_executor = ToolExecutor() except Exception as e: - self.logger.warning(f"Could not initialize ToolExecutor: {e}. Using mock executor.") + self.logger.warning( + f"Could not initialize ToolExecutor: {e}. Using mock executor." + ) self.local_executor = None self.metrics_collector = MetricsCollector() - + # Configuration self.retry_config = retry_config or RetryConfig( strategy=RetryStrategy.EXPONENTIAL_BACKOFF, max_attempts=3, base_delay=1.0, - max_delay=30.0 + max_delay=30.0, ) - + # Circuit breakers for different components circuit_config = circuit_config or CircuitBreakerConfig() self.local_circuit_breaker = CircuitBreaker(circuit_config, "local_executor") self.remote_circuit_breaker = CircuitBreaker(circuit_config, "remote_executor") - + # Error tracking self.error_history: List[ErrorContext] = [] self.error_patterns: Dict[str, List[ErrorType]] = {} - + # Degradation flags self.degraded_mode = False self.degradation_start_time = 0.0 - + async def execute_tool_resilient(self, tool_call: ToolCall) -> ToolResult: """Execute a tool with comprehensive error handling and resilience.""" start_time = time.time() attempt_count = 0 last_exception = None - + while attempt_count < self.retry_config.max_attempts: attempt_count += 1 - + try: # Try circuit breaker protected execution if self.arcade_client: @@ -324,21 +367,21 @@ async def execute_tool_resilient(self, tool_call: ToolCall) -> ToolResult: result = await self.local_circuit_breaker.call( self._execute_local_tool, tool_call ) - + execution_time = (time.time() - start_time) * 1000 - + return ToolResult( call_id=tool_call.id, tool_name=tool_call.name, success=True, data=result, - execution_time_ms=execution_time + execution_time_ms=execution_time, ) - + except Exception as e: last_exception = e error_type = self._classify_error(e) - + # Create error context error_context = ErrorContext( error=e, @@ -346,62 +389,61 @@ async def execute_tool_resilient(self, tool_call: ToolCall) -> ToolResult: tool_name=tool_call.name, attempt_count=attempt_count, total_elapsed_time=time.time() - start_time, - user_id=tool_call.user_id + user_id=tool_call.user_id, ) - + # Track error self.error_history.append(error_context) - + # Check if we should retry if not self._should_retry(error_context): break - + # Calculate retry delay delay = self._calculate_retry_delay(attempt_count) self.logger.warning( f"Tool {tool_call.name} failed (attempt {attempt_count}): {e}. " f"Retrying in {delay:.1f}s..." ) - + await asyncio.sleep(delay) - + # All retries exhausted execution_time = (time.time() - start_time) * 1000 - + return ToolResult( call_id=tool_call.id, tool_name=tool_call.name, success=False, error=str(last_exception), - execution_time_ms=execution_time + execution_time_ms=execution_time, ) - + async def _execute_remote_tool(self, tool_call: ToolCall) -> Dict[str, Any]: """Execute tool using remote Arcade client.""" if not self.arcade_client: raise ToolExecutionError("No arcade client available") - + result = await self.arcade_client.execute_tool( - tool_call.name, - tool_call.arguments + tool_call.name, tool_call.arguments ) return result - + async def _execute_local_tool(self, tool_call: ToolCall) -> Dict[str, Any]: """Execute tool using local executor.""" if not self.local_executor: raise ToolExecutionError("No local executor available") - + # For demo purposes, use the unreliable_service function if tool_call.name == "Test_UnreliableService": return unreliable_service(**tool_call.arguments) - + raise ToolNotFoundError(f"Tool {tool_call.name} not found") - + def _classify_error(self, error: Exception) -> ErrorType: """Classify error type for handling strategy.""" error_str = str(error).lower() - + if "network" in error_str or "connection" in error_str: return ErrorType.NETWORK_ERROR elif "timeout" in error_str: @@ -416,23 +458,23 @@ def _classify_error(self, error: Exception) -> ErrorType: return ErrorType.SERVER_ERROR else: return ErrorType.UNKNOWN_ERROR - + def _should_retry(self, error_context: ErrorContext) -> bool: """Determine if error should be retried.""" # Don't retry validation errors if error_context.error_type == ErrorType.VALIDATION_ERROR: return False - + # Don't retry authentication errors if error_context.error_type == ErrorType.AUTHENTICATION_ERROR: return False - + # Don't retry if max attempts reached if error_context.attempt_count >= self.retry_config.max_attempts: return False - + return True - + def _calculate_retry_delay(self, attempt: int) -> float: """Calculate retry delay based on strategy.""" if self.retry_config.strategy == RetryStrategy.NO_RETRY: @@ -440,14 +482,16 @@ def _calculate_retry_delay(self, attempt: int) -> float: elif self.retry_config.strategy == RetryStrategy.LINEAR_BACKOFF: delay = self.retry_config.base_delay * attempt elif self.retry_config.strategy == RetryStrategy.EXPONENTIAL_BACKOFF: - delay = self.retry_config.base_delay * (self.retry_config.backoff_multiplier ** (attempt - 1)) + delay = self.retry_config.base_delay * ( + self.retry_config.backoff_multiplier ** (attempt - 1) + ) else: delay = self.retry_config.base_delay - + # Apply jitter if enabled if self.retry_config.jitter: - delay *= (0.5 + random.random() * 0.5) - + delay *= 0.5 + random.random() * 0.5 + # Cap at max delay return min(delay, self.retry_config.max_delay) @@ -457,88 +501,101 @@ def _calculate_retry_delay(self, attempt: int) -> float: name="Test_UnreliableService", description="Unreliable service for testing error handling", parameters={ - "failure_rate": {"type": "number", "description": "Probability of failure (0.0-1.0)", "default": 0.3}, - "delay_ms": {"type": "integer", "description": "Processing delay in milliseconds", "default": 100} - } + "failure_rate": { + "type": "number", + "description": "Probability of failure (0.0-1.0)", + "default": 0.3, + }, + "delay_ms": { + "type": "integer", + "description": "Processing delay in milliseconds", + "default": 100, + }, + }, ) -def unreliable_service(failure_rate: float = 0.3, delay_ms: int = 100) -> Dict[str, Any]: +def unreliable_service( + failure_rate: float = 0.3, delay_ms: int = 100 +) -> Dict[str, Any]: """Simulated unreliable service for testing.""" import time import random - + # Simulate processing delay time.sleep(delay_ms / 1000.0) - + # Random failure if random.random() < failure_rate: error_types = [ "Network connection failed", - "Service temporarily unavailable", + "Service temporarily unavailable", "Request timeout", "Rate limit exceeded", - "Internal server error" + "Internal server error", ] raise Exception(random.choice(error_types)) - + return { "success": True, "result": f"Service completed successfully after {delay_ms}ms delay", - "timestamp": time.time() + "timestamp": time.time(), } async def demonstrate_error_handling(): """Demonstrate resilient error handling capabilities.""" - + # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) - + logger.info("Starting resilient error handling demonstration") - + # Initialize components from environment arcade_api_key = os.getenv("ARCADE_API_KEY") cache_config = { "redis_url": os.getenv("REDIS_URL", "redis://localhost:6379"), - "default_ttl": 3600 + "default_ttl": 3600, } - + # Create resilient executor with custom configuration retry_config = RetryConfig( strategy=RetryStrategy.EXPONENTIAL_BACKOFF, max_attempts=5, base_delay=0.5, max_delay=10.0, - jitter=True + jitter=True, ) - + circuit_config = CircuitBreakerConfig( - failure_threshold=3, - success_threshold=2, - timeout_seconds=30.0 + failure_threshold=3, success_threshold=2, timeout_seconds=30.0 ) - + # Initialize clients with demo mode support demo_mode = ( - not bool(arcade_api_key.strip()) if arcade_api_key else True or - arcade_api_key.strip() in ["your_api_key", "demo_key", "placeholder"] if arcade_api_key else True or - len(arcade_api_key.strip()) < 10 if arcade_api_key else True # Real API keys are typically longer + not bool(arcade_api_key.strip()) + if arcade_api_key + else ( + True + or arcade_api_key.strip() in ["your_api_key", "demo_key", "placeholder"] + if arcade_api_key + else True or len(arcade_api_key.strip()) < 10 if arcade_api_key else True + ) # Real API keys are typically longer ) - + arcade_config = ArcadeConfig( api_key=arcade_api_key if not demo_mode else "demo_key", user_id=os.getenv("ARCADE_USER_ID", "demo@example.com"), - demo_mode=demo_mode + demo_mode=demo_mode, ) - + arcade_client = None if arcade_api_key or demo_mode: arcade_client = BasicArcadeClient(arcade_config) await arcade_client.connect() - + cache_manager = None try: if os.getenv("REDIS_URL"): @@ -549,94 +606,94 @@ async def demonstrate_error_handling(): "prefix": "resilient_demo", "min_tokens": 1, "max_size": "1MB", - "ttl_seconds": 300 + "ttl_seconds": 300, } cache_manager = CacheManager(basic_cache_config) except Exception as e: - logger.warning(f"Cache manager initialization failed: {e}. Continuing without cache.") - + logger.warning( + f"Cache manager initialization failed: {e}. Continuing without cache." + ) + executor = ResilientExecutor( arcade_client=arcade_client, cache_manager=cache_manager, retry_config=retry_config, - circuit_config=circuit_config + circuit_config=circuit_config, ) - + if demo_mode: - logger.info("šŸŽ­ Running in demo mode - using mock responses and local error simulation") + logger.info( + "šŸŽ­ Running in demo mode - using mock responses and local error simulation" + ) else: logger.info("šŸ”‘ Using real API key for Arcade.dev integration") - + # Test scenarios test_scenarios = [ # Low failure rate - should succeed quickly {"failure_rate": 0.1, "delay_ms": 50, "description": "Low failure rate"}, - # Medium failure rate - should retry and eventually succeed {"failure_rate": 0.5, "delay_ms": 100, "description": "Medium failure rate"}, - # High failure rate - should eventually fail gracefully {"failure_rate": 0.9, "delay_ms": 200, "description": "High failure rate"}, - # Network timeout simulation - {"failure_rate": 1.0, "delay_ms": 5000, "description": "Timeout simulation"} + {"failure_rate": 1.0, "delay_ms": 5000, "description": "Timeout simulation"}, ] - + for i, scenario in enumerate(test_scenarios, 1): logger.info(f"\n--- Test Scenario {i}: {scenario['description']} ---") - + tool_call = ToolCall( id=f"test_{i}", name="Test_UnreliableService", arguments=scenario, - user_id="demo_user" + user_id="demo_user", ) - + start_time = time.time() - + try: - logger.info(f"Executing unreliable service with {scenario['failure_rate']*100:.0f}% failure rate") - + logger.info( + f"Executing unreliable service with {scenario['failure_rate']*100:.0f}% failure rate" + ) + # Use the resilient executor result = await executor.execute_tool_resilient(tool_call) execution_time = (time.time() - start_time) * 1000 - + if result.success: logger.info(f"āœ“ Success in {execution_time:.0f}ms") logger.info(f"Result: {result.data}") else: logger.error(f"āœ— Failed after all retry attempts: {result.error}") logger.info("Applying graceful degradation...") - + # Demonstrate graceful degradation fallback_result = { "status": "degraded", "message": "Service temporarily unavailable, using cached or default response", - "fallback_data": {"placeholder": "default_value"} + "fallback_data": {"placeholder": "default_value"}, } logger.info(f"Fallback result: {fallback_result}") - + except Exception as e: logger.error(f"Unexpected error in scenario {i}: {e}") logger.debug(f"Exception traceback: {traceback.format_exc()}") - + # Brief pause between scenarios await asyncio.sleep(1) - + # Demonstrate health status monitoring logger.info("\n--- System Health Status ---") health_status = { - "circuit_breakers": { - "local": "closed", - "remote": "closed" - }, + "circuit_breakers": {"local": "closed", "remote": "closed"}, "degraded_mode": False, "error_rate": 0.2, - "recent_errors": 3 + "recent_errors": 3, } - + logger.info(f"Health Status: {health_status}") - + logger.info("Error handling demonstration completed") @@ -645,13 +702,15 @@ async def demonstrate_error_handling(): print("šŸ›”ļø FACT SDK - Resilient Error Handling Example") print("=" * 50) print() - print("This example demonstrates comprehensive error handling for Arcade.dev integration:") + print( + "This example demonstrates comprehensive error handling for Arcade.dev integration:" + ) print("• Error classification and routing") print("• Intelligent retry strategies") print("• Circuit breaker patterns") print("• Graceful degradation") print("• Comprehensive logging and monitoring") print() - + # Run demonstration - asyncio.run(demonstrate_error_handling()) \ No newline at end of file + asyncio.run(demonstrate_error_handling()) diff --git a/examples/arcade-dev/05_cache_integration/cached_arcade_client.py b/examples/arcade-dev/05_cache_integration/cached_arcade_client.py index 15dd915..bf18588 100644 --- a/examples/arcade-dev/05_cache_integration/cached_arcade_client.py +++ b/examples/arcade-dev/05_cache_integration/cached_arcade_client.py @@ -36,17 +36,18 @@ @dataclass class CacheStrategy: """Cache strategy configuration.""" + strategy_name: str ttl_seconds: int max_entries: int = 1000 compression_enabled: bool = False encryption_enabled: bool = False invalidation_pattern: str = "time" # time, lru, custom - + # Performance thresholds hit_rate_threshold: float = 0.8 response_time_threshold_ms: float = 100 - + # Advanced features prefetch_enabled: bool = False background_refresh_enabled: bool = False @@ -56,670 +57,734 @@ class CacheStrategy: @dataclass class CachePerformanceMetrics: """Cache performance tracking.""" + hits: int = 0 misses: int = 0 sets: int = 0 evictions: int = 0 errors: int = 0 - + # Timing metrics total_hit_time_ms: float = 0 total_miss_time_ms: float = 0 total_set_time_ms: float = 0 - + # Size metrics cache_size_bytes: int = 0 avg_entry_size_bytes: float = 0 - + # Advanced metrics prefetch_hits: int = 0 background_refreshes: int = 0 - + def get_hit_rate(self) -> float: total_requests = self.hits + self.misses return (self.hits / total_requests) if total_requests > 0 else 0.0 - + def get_avg_hit_time_ms(self) -> float: return (self.total_hit_time_ms / self.hits) if self.hits > 0 else 0.0 - + def get_avg_miss_time_ms(self) -> float: return (self.total_miss_time_ms / self.misses) if self.misses > 0 else 0.0 class HybridCacheManager: """Advanced cache manager with multiple cache levels and strategies.""" - - def __init__(self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]): + + def __init__( + self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy] + ): self.primary_cache = primary_cache self.strategies = strategies self.logger = logging.getLogger(f"{__name__}.HybridCacheManager") - + # In-memory cache for hot data self.memory_cache: Dict[str, Dict[str, Any]] = {} self.memory_cache_access_times: Dict[str, datetime] = {} self.memory_cache_max_size = 100 - + # Performance tracking self.metrics: Dict[str, CachePerformanceMetrics] = { strategy_name: CachePerformanceMetrics() for strategy_name in strategies.keys() } - + # Cache key prefetching self.prefetch_queue: List[str] = [] self.prefetch_task: Optional[asyncio.Task] = None - + async def get(self, key: str, strategy_name: str = "default") -> Optional[Any]: """Get value with hybrid caching strategy.""" start_time = time.time() strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Level 1: Memory cache (fastest) if strategy.multi_level_caching and key in self.memory_cache: self.memory_cache_access_times[key] = datetime.now(timezone.utc) - + cache_data = self.memory_cache[key] if self._is_cache_entry_valid(cache_data): metrics.hits += 1 metrics.total_hit_time_ms += (time.time() - start_time) * 1000 self.logger.debug(f"Memory cache hit for key: {key}") - return cache_data['data'] + return cache_data["data"] else: # Remove expired entry del self.memory_cache[key] del self.memory_cache_access_times[key] - + # Level 2: Primary cache (persistent) cache_entry = self.primary_cache.get(key) - - if cache_entry and self._is_cache_entry_valid({'cached_at': cache_entry.created_at, 'ttl': self.strategies[strategy_name].ttl_seconds}): + + if cache_entry and self._is_cache_entry_valid( + { + "cached_at": cache_entry.created_at, + "ttl": self.strategies[strategy_name].ttl_seconds, + } + ): # Cache hit - extract data from cache entry content try: cache_data = json.loads(cache_entry.content) metrics.hits += 1 metrics.total_hit_time_ms += (time.time() - start_time) * 1000 - + # Promote to memory cache if multi-level enabled if strategy.multi_level_caching: await self._promote_to_memory_cache(key, cache_data, strategy) - + self.logger.debug(f"Primary cache hit for key: {key}") - return cache_data.get('data', cache_data) + return cache_data.get("data", cache_data) except json.JSONDecodeError: self.logger.warning(f"Invalid JSON in cache entry for key: {key}") - + # Cache miss metrics.misses += 1 metrics.total_miss_time_ms += (time.time() - start_time) * 1000 self.logger.debug(f"Cache miss for key: {key}") - + # Schedule prefetch if enabled if strategy.prefetch_enabled: await self._schedule_prefetch(key, strategy_name) - + return None - + except Exception as e: metrics.errors += 1 self.logger.error(f"Cache get error for key {key}: {e}") return None - - async def set(self, key: str, value: Any, strategy_name: str = "default", - custom_ttl: int = None) -> bool: + + async def set( + self, + key: str, + value: Any, + strategy_name: str = "default", + custom_ttl: int = None, + ) -> bool: """Set value with caching strategy.""" start_time = time.time() strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: ttl = custom_ttl or strategy.ttl_seconds - + # Prepare cache data with metadata cache_data = { - 'data': value, - 'cached_at': time.time(), - 'ttl': ttl, - 'strategy': strategy_name, - 'access_count': 0, - 'size_bytes': len(json.dumps(value, default=str)) + "data": value, + "cached_at": time.time(), + "ttl": ttl, + "strategy": strategy_name, + "access_count": 0, + "size_bytes": len(json.dumps(value, default=str)), } - + # Apply compression if enabled if strategy.compression_enabled: cache_data = await self._compress_cache_data(cache_data) - + # Apply encryption if enabled if strategy.encryption_enabled: cache_data = await self._encrypt_cache_data(cache_data) - + # Store in primary cache using store method try: - cache_entry = self.primary_cache.store(key, json.dumps(cache_data, default=str)) + cache_entry = self.primary_cache.store( + key, json.dumps(cache_data, default=str) + ) success = True except Exception as e: self.logger.error(f"Primary cache store failed: {e}") success = False - + if success: metrics.sets += 1 metrics.total_set_time_ms += (time.time() - start_time) * 1000 - metrics.cache_size_bytes += cache_data.get('size_bytes', 0) - + metrics.cache_size_bytes += cache_data.get("size_bytes", 0) + # Update average entry size if metrics.sets > 0: - metrics.avg_entry_size_bytes = metrics.cache_size_bytes / metrics.sets - + metrics.avg_entry_size_bytes = ( + metrics.cache_size_bytes / metrics.sets + ) + # Set in memory cache if multi-level enabled if strategy.multi_level_caching: await self._set_memory_cache(key, cache_data, strategy) - + self.logger.debug(f"Cache set successful for key: {key}") - + # Schedule background refresh if enabled if strategy.background_refresh_enabled: await self._schedule_background_refresh(key, strategy_name, ttl) - + return True else: metrics.errors += 1 return False - + except Exception as e: metrics.errors += 1 self.logger.error(f"Cache set error for key {key}: {e}") return False - + async def invalidate(self, pattern: str, strategy_name: str = "default") -> int: """Invalidate cache entries matching pattern.""" strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Invalidate from memory cache invalidated_count = 0 - + if strategy.multi_level_caching: keys_to_remove = [] for key in self.memory_cache.keys(): if self._matches_pattern(key, pattern): keys_to_remove.append(key) - + for key in keys_to_remove: del self.memory_cache[key] if key in self.memory_cache_access_times: del self.memory_cache_access_times[key] invalidated_count += 1 - + # Invalidate from primary cache # Note: This is a simplified implementation # In production, you'd need pattern-based invalidation in the cache manager - + metrics.evictions += invalidated_count - self.logger.info(f"Invalidated {invalidated_count} entries matching pattern: {pattern}") + self.logger.info( + f"Invalidated {invalidated_count} entries matching pattern: {pattern}" + ) return invalidated_count - + except Exception as e: metrics.errors += 1 self.logger.error(f"Cache invalidation error for pattern {pattern}: {e}") return 0 - + async def optimize_cache(self, strategy_name: str = "default"): """Optimize cache performance based on usage patterns.""" strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Check if optimization is needed hit_rate = metrics.get_hit_rate() - + if hit_rate < strategy.hit_rate_threshold: - self.logger.warning(f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})") - + self.logger.warning( + f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})" + ) + # Optimization strategies await self._optimize_memory_cache_size(strategy) await self._optimize_ttl_values(strategy_name) await self._cleanup_expired_entries(strategy_name) - - self.logger.info(f"Cache optimization completed for strategy: {strategy_name}") - + + self.logger.info( + f"Cache optimization completed for strategy: {strategy_name}" + ) + except Exception as e: self.logger.error(f"Cache optimization error: {e}") - + def _is_cache_entry_valid(self, cache_data: Dict[str, Any]) -> bool: """Check if cache entry is still valid.""" if not isinstance(cache_data, dict): return True # Assume simple values are valid - - cached_at = cache_data.get('cached_at', 0) - ttl = cache_data.get('ttl', 3600) - + + cached_at = cache_data.get("cached_at", 0) + ttl = cache_data.get("ttl", 3600) + return (time.time() - cached_at) < ttl - - async def _promote_to_memory_cache(self, key: str, cache_data: Dict[str, Any], strategy: CacheStrategy): + + async def _promote_to_memory_cache( + self, key: str, cache_data: Dict[str, Any], strategy: CacheStrategy + ): """Promote frequently accessed data to memory cache.""" # Check if memory cache has space if len(self.memory_cache) >= self.memory_cache_max_size: await self._evict_lru_memory_cache() - + self.memory_cache[key] = cache_data self.memory_cache_access_times[key] = datetime.utcnow() - - async def _set_memory_cache(self, key: str, cache_data: Dict[str, Any], strategy: CacheStrategy): + + async def _set_memory_cache( + self, key: str, cache_data: Dict[str, Any], strategy: CacheStrategy + ): """Set data in memory cache.""" if len(self.memory_cache) >= self.memory_cache_max_size: await self._evict_lru_memory_cache() - + self.memory_cache[key] = cache_data self.memory_cache_access_times[key] = datetime.now(timezone.utc) - + async def _evict_lru_memory_cache(self): """Evict least recently used entry from memory cache.""" if not self.memory_cache_access_times: return - + # Find LRU entry - lru_key = min(self.memory_cache_access_times.keys(), - key=lambda k: self.memory_cache_access_times[k]) - + lru_key = min( + self.memory_cache_access_times.keys(), + key=lambda k: self.memory_cache_access_times[k], + ) + # Remove LRU entry del self.memory_cache[lru_key] del self.memory_cache_access_times[lru_key] - + async def _schedule_prefetch(self, key: str, strategy_name: str): """Schedule prefetching for related keys.""" if key not in self.prefetch_queue: self.prefetch_queue.append(key) - + if not self.prefetch_task or self.prefetch_task.done(): self.prefetch_task = asyncio.create_task(self._process_prefetch_queue()) - + async def _process_prefetch_queue(self): """Process prefetch queue in background.""" while self.prefetch_queue: key = self.prefetch_queue.pop(0) # In a real implementation, you'd prefetch related data await asyncio.sleep(0.1) # Simulate prefetch work - - async def _schedule_background_refresh(self, key: str, strategy_name: str, ttl: int): + + async def _schedule_background_refresh( + self, key: str, strategy_name: str, ttl: int + ): """Schedule background refresh before TTL expires.""" refresh_time = ttl * 0.8 # Refresh at 80% of TTL await asyncio.sleep(refresh_time) - + # In a real implementation, you'd refresh the data from source self.logger.debug(f"Background refresh triggered for key: {key}") - + async def _compress_cache_data(self, cache_data: Dict[str, Any]) -> Dict[str, Any]: """Compress cache data to save space.""" # Simplified compression simulation - cache_data['compressed'] = True - cache_data['size_bytes'] = int(cache_data['size_bytes'] * 0.7) # 30% compression + cache_data["compressed"] = True + cache_data["size_bytes"] = int( + cache_data["size_bytes"] * 0.7 + ) # 30% compression return cache_data - + async def _encrypt_cache_data(self, cache_data: Dict[str, Any]) -> Dict[str, Any]: """Encrypt sensitive cache data.""" # Simplified encryption simulation - cache_data['encrypted'] = True + cache_data["encrypted"] = True return cache_data - + def _matches_pattern(self, key: str, pattern: str) -> bool: """Check if key matches invalidation pattern.""" # Simplified pattern matching if pattern == "*": return True return pattern in key - + async def _optimize_memory_cache_size(self, strategy: CacheStrategy): """Optimize memory cache size based on hit patterns.""" current_hit_rate = self.metrics[strategy.strategy_name].get_hit_rate() - + if current_hit_rate < 0.5: # Increase memory cache size self.memory_cache_max_size = min(self.memory_cache_max_size * 2, 500) elif current_hit_rate > 0.9: # Decrease memory cache size self.memory_cache_max_size = max(self.memory_cache_max_size // 2, 50) - - self.logger.debug(f"Optimized memory cache size to: {self.memory_cache_max_size}") - + + self.logger.debug( + f"Optimized memory cache size to: {self.memory_cache_max_size}" + ) + async def _optimize_ttl_values(self, strategy_name: str): """Optimize TTL values based on access patterns.""" # In a real implementation, analyze access patterns and adjust TTLs self.logger.debug(f"TTL optimization completed for strategy: {strategy_name}") - + async def _cleanup_expired_entries(self, strategy_name: str): """Clean up expired entries from memory cache.""" expired_keys = [] - + for key, cache_data in self.memory_cache.items(): if not self._is_cache_entry_valid(cache_data): expired_keys.append(key) - + for key in expired_keys: del self.memory_cache[key] if key in self.memory_cache_access_times: del self.memory_cache_access_times[key] - + self.metrics[strategy_name].evictions += len(expired_keys) self.logger.debug(f"Cleaned up {len(expired_keys)} expired entries") - + def get_performance_report(self) -> Dict[str, Any]: """Generate comprehensive performance report.""" report = { - 'timestamp': datetime.now(timezone.utc).isoformat(), - 'strategies': {}, - 'memory_cache': { - 'size': len(self.memory_cache), - 'max_size': self.memory_cache_max_size, - 'utilization': len(self.memory_cache) / self.memory_cache_max_size - } + "timestamp": datetime.now(timezone.utc).isoformat(), + "strategies": {}, + "memory_cache": { + "size": len(self.memory_cache), + "max_size": self.memory_cache_max_size, + "utilization": len(self.memory_cache) / self.memory_cache_max_size, + }, } - + for strategy_name, metrics in self.metrics.items(): strategy_report = { - 'hit_rate': metrics.get_hit_rate(), - 'total_requests': metrics.hits + metrics.misses, - 'cache_hits': metrics.hits, - 'cache_misses': metrics.misses, - 'cache_sets': metrics.sets, - 'cache_errors': metrics.errors, - 'avg_hit_time_ms': metrics.get_avg_hit_time_ms(), - 'avg_miss_time_ms': metrics.get_avg_miss_time_ms(), - 'cache_size_bytes': metrics.cache_size_bytes, - 'avg_entry_size_bytes': metrics.avg_entry_size_bytes + "hit_rate": metrics.get_hit_rate(), + "total_requests": metrics.hits + metrics.misses, + "cache_hits": metrics.hits, + "cache_misses": metrics.misses, + "cache_sets": metrics.sets, + "cache_errors": metrics.errors, + "avg_hit_time_ms": metrics.get_avg_hit_time_ms(), + "avg_miss_time_ms": metrics.get_avg_miss_time_ms(), + "cache_size_bytes": metrics.cache_size_bytes, + "avg_entry_size_bytes": metrics.avg_entry_size_bytes, } - - report['strategies'][strategy_name] = strategy_report - + + report["strategies"][strategy_name] = strategy_report + return report class AdvancedArcadeClient: """Advanced Arcade.dev client with sophisticated caching.""" - + def __init__(self, api_key: str, cache_manager: CacheManager): self.api_key = api_key self.logger = logging.getLogger(__name__) - + # Setup cache strategies self.cache_strategies = { "default": CacheStrategy( strategy_name="default", ttl_seconds=3600, max_entries=1000, - multi_level_caching=True + multi_level_caching=True, ), "fast": CacheStrategy( strategy_name="fast", ttl_seconds=300, max_entries=100, multi_level_caching=True, - prefetch_enabled=True + prefetch_enabled=True, ), "persistent": CacheStrategy( strategy_name="persistent", ttl_seconds=86400, # 24 hours max_entries=5000, compression_enabled=True, - background_refresh_enabled=True + background_refresh_enabled=True, ), "secure": CacheStrategy( strategy_name="secure", ttl_seconds=1800, max_entries=500, encryption_enabled=True, - multi_level_caching=False - ) + multi_level_caching=False, + ), } - + # Initialize hybrid cache manager self.hybrid_cache = HybridCacheManager(cache_manager, self.cache_strategies) - + def _generate_cache_key(self, operation: str, **kwargs) -> str: """Generate cache key for operation.""" params_str = json.dumps(kwargs, sort_keys=True, default=str) params_hash = hashlib.sha256(params_str.encode()).hexdigest()[:16] return f"arcade:{operation}:{params_hash}" - - async def cached_operation(self, operation: str, strategy: str = "default", - force_refresh: bool = False, **kwargs) -> Dict[str, Any]: + + async def cached_operation( + self, + operation: str, + strategy: str = "default", + force_refresh: bool = False, + **kwargs, + ) -> Dict[str, Any]: """Execute operation with advanced caching.""" cache_key = self._generate_cache_key(operation, **kwargs) - + # Check cache first if not force_refresh: cached_result = await self.hybrid_cache.get(cache_key, strategy) if cached_result: self.logger.debug(f"Cache hit for {operation} with strategy {strategy}") return cached_result - + # Execute operation start_time = time.time() result = await self._execute_operation(operation, **kwargs) execution_time = (time.time() - start_time) * 1000 - + # Add metadata - result['_execution_time_ms'] = execution_time - result['_cached'] = False - result['_timestamp'] = datetime.now(timezone.utc).isoformat() - + result["_execution_time_ms"] = execution_time + result["_cached"] = False + result["_timestamp"] = datetime.now(timezone.utc).isoformat() + # Cache result await self.hybrid_cache.set(cache_key, result, strategy) - + return result - + async def _execute_operation(self, operation: str, **kwargs) -> Dict[str, Any]: """Execute the actual operation (mock implementation).""" # Simulate different operation types operation_times = { - 'code_analysis': 2.0, - 'test_generation': 3.0, - 'documentation': 1.5, - 'refactoring': 2.5 + "code_analysis": 2.0, + "test_generation": 3.0, + "documentation": 1.5, + "refactoring": 2.5, } - + await asyncio.sleep(operation_times.get(operation, 1.0)) - + # Generate longer, more realistic responses to meet 500 token minimum detailed_results = { - 'code_analysis': { - 'summary': f"Comprehensive code analysis completed for {operation}", - 'complexity_score': 7.5, - 'maintainability_index': 85.2, - 'cyclomatic_complexity': 12, - 'lines_of_code': 342, - 'functions_analyzed': 15, - 'classes_analyzed': 3, - 'issues_found': [ - {'type': 'warning', 'line': 45, 'message': 'Consider breaking down this complex function'}, - {'type': 'info', 'line': 78, 'message': 'Variable naming could be more descriptive'}, - {'type': 'suggestion', 'line': 92, 'message': 'Consider using list comprehension here'} + "code_analysis": { + "summary": f"Comprehensive code analysis completed for {operation}", + "complexity_score": 7.5, + "maintainability_index": 85.2, + "cyclomatic_complexity": 12, + "lines_of_code": 342, + "functions_analyzed": 15, + "classes_analyzed": 3, + "issues_found": [ + { + "type": "warning", + "line": 45, + "message": "Consider breaking down this complex function", + }, + { + "type": "info", + "line": 78, + "message": "Variable naming could be more descriptive", + }, + { + "type": "suggestion", + "line": 92, + "message": "Consider using list comprehension here", + }, ], - 'suggestions': [ - 'Implement proper error handling in the main processing loop', - 'Add type hints to improve code readability and IDE support', - 'Consider extracting utility functions to reduce code duplication', - 'Implement comprehensive logging for better debugging capabilities' + "suggestions": [ + "Implement proper error handling in the main processing loop", + "Add type hints to improve code readability and IDE support", + "Consider extracting utility functions to reduce code duplication", + "Implement comprehensive logging for better debugging capabilities", ], - 'metrics': { - 'code_coverage': 87.5, - 'test_coverage': 92.1, - 'documentation_coverage': 78.3, - 'performance_score': 8.7 + "metrics": { + "code_coverage": 87.5, + "test_coverage": 92.1, + "documentation_coverage": 78.3, + "performance_score": 8.7, + }, + "dependencies": [ + "requests", + "asyncio", + "json", + "pathlib", + "dataclasses", + ], + "security_analysis": { + "vulnerabilities_found": 0, + "security_score": 9.2, + "recommendations": [ + "Update dependencies to latest versions", + "Implement input validation", + ], }, - 'dependencies': ['requests', 'asyncio', 'json', 'pathlib', 'dataclasses'], - 'security_analysis': { - 'vulnerabilities_found': 0, - 'security_score': 9.2, - 'recommendations': ['Update dependencies to latest versions', 'Implement input validation'] - } }, - 'test_generation': { - 'summary': f"Automated test generation completed for {operation}", - 'tests_generated': 23, - 'coverage_improvement': 15.7, - 'test_types': ['unit', 'integration', 'edge_cases'], - 'generated_tests': [ + "test_generation": { + "summary": f"Automated test generation completed for {operation}", + "tests_generated": 23, + "coverage_improvement": 15.7, + "test_types": ["unit", "integration", "edge_cases"], + "generated_tests": [ { - 'name': 'test_basic_functionality', - 'type': 'unit', - 'description': 'Tests basic function behavior with valid inputs', - 'expected_coverage': 85.2 + "name": "test_basic_functionality", + "type": "unit", + "description": "Tests basic function behavior with valid inputs", + "expected_coverage": 85.2, }, { - 'name': 'test_edge_cases', - 'type': 'edge_case', - 'description': 'Tests function behavior with boundary conditions and edge cases', - 'expected_coverage': 92.8 + "name": "test_edge_cases", + "type": "edge_case", + "description": "Tests function behavior with boundary conditions and edge cases", + "expected_coverage": 92.8, }, { - 'name': 'test_error_handling', - 'type': 'error', - 'description': 'Tests proper error handling and exception management', - 'expected_coverage': 78.5 - } + "name": "test_error_handling", + "type": "error", + "description": "Tests proper error handling and exception management", + "expected_coverage": 78.5, + }, ], - 'recommendations': [ - 'Add property-based tests for more comprehensive coverage', - 'Include performance benchmarks for critical functions', - 'Implement integration tests for external dependencies', - 'Add mock tests for API interactions' + "recommendations": [ + "Add property-based tests for more comprehensive coverage", + "Include performance benchmarks for critical functions", + "Implement integration tests for external dependencies", + "Add mock tests for API interactions", ], - 'quality_metrics': { - 'test_readability': 9.1, - 'maintainability': 8.8, - 'execution_speed': 7.9, - 'reliability': 9.3 - } + "quality_metrics": { + "test_readability": 9.1, + "maintainability": 8.8, + "execution_speed": 7.9, + "reliability": 9.3, + }, }, - 'documentation': { - 'summary': f"Documentation generation completed for {operation}", - 'pages_generated': 12, - 'sections_created': ['API Reference', 'User Guide', 'Examples', 'FAQ'], - 'documentation_structure': { - 'getting_started': 'Complete setup and installation guide', - 'api_reference': 'Detailed API documentation with examples', - 'tutorials': 'Step-by-step tutorials for common use cases', - 'troubleshooting': 'Common issues and their solutions' + "documentation": { + "summary": f"Documentation generation completed for {operation}", + "pages_generated": 12, + "sections_created": ["API Reference", "User Guide", "Examples", "FAQ"], + "documentation_structure": { + "getting_started": "Complete setup and installation guide", + "api_reference": "Detailed API documentation with examples", + "tutorials": "Step-by-step tutorials for common use cases", + "troubleshooting": "Common issues and their solutions", }, - 'content_analysis': { - 'readability_score': 8.7, - 'completeness': 91.2, - 'accuracy': 94.8, - 'usefulness': 8.9 + "content_analysis": { + "readability_score": 8.7, + "completeness": 91.2, + "accuracy": 94.8, + "usefulness": 8.9, }, - 'generated_content': [ - 'Function docstrings with parameter descriptions and return types', - 'Class documentation with usage examples', - 'Module-level documentation explaining purpose and structure', - 'README sections with installation and basic usage instructions' + "generated_content": [ + "Function docstrings with parameter descriptions and return types", + "Class documentation with usage examples", + "Module-level documentation explaining purpose and structure", + "README sections with installation and basic usage instructions", + ], + "improvements_suggested": [ + "Add more code examples to illustrate complex concepts", + "Include visual diagrams for architecture overview", + "Expand troubleshooting section with common error scenarios", + "Add FAQ section based on user feedback", ], - 'improvements_suggested': [ - 'Add more code examples to illustrate complex concepts', - 'Include visual diagrams for architecture overview', - 'Expand troubleshooting section with common error scenarios', - 'Add FAQ section based on user feedback' - ] }, - 'refactoring': { - 'summary': f"Code refactoring analysis completed for {operation}", - 'refactoring_opportunities': 18, - 'complexity_reduction': 23.5, - 'performance_improvement': 12.8, - 'suggested_changes': [ + "refactoring": { + "summary": f"Code refactoring analysis completed for {operation}", + "refactoring_opportunities": 18, + "complexity_reduction": 23.5, + "performance_improvement": 12.8, + "suggested_changes": [ { - 'type': 'extract_method', - 'location': 'lines 45-67', - 'description': 'Extract complex logic into separate method for better readability', - 'impact': 'high' + "type": "extract_method", + "location": "lines 45-67", + "description": "Extract complex logic into separate method for better readability", + "impact": "high", }, { - 'type': 'remove_duplication', - 'location': 'multiple locations', - 'description': 'Consolidate duplicate code patterns into reusable functions', - 'impact': 'medium' + "type": "remove_duplication", + "location": "multiple locations", + "description": "Consolidate duplicate code patterns into reusable functions", + "impact": "medium", }, { - 'type': 'simplify_conditionals', - 'location': 'lines 89-102', - 'description': 'Simplify nested conditional statements using early returns', - 'impact': 'medium' - } + "type": "simplify_conditionals", + "location": "lines 89-102", + "description": "Simplify nested conditional statements using early returns", + "impact": "medium", + }, ], - 'code_quality_improvements': { - 'maintainability_score_increase': 15.3, - 'readability_improvement': 22.7, - 'testability_enhancement': 18.9, - 'performance_optimization': 8.4 + "code_quality_improvements": { + "maintainability_score_increase": 15.3, + "readability_improvement": 22.7, + "testability_enhancement": 18.9, + "performance_optimization": 8.4, }, - 'design_patterns_suggested': [ - 'Strategy pattern for algorithm selection', - 'Factory pattern for object creation', - 'Observer pattern for event handling' + "design_patterns_suggested": [ + "Strategy pattern for algorithm selection", + "Factory pattern for object creation", + "Observer pattern for event handling", ], - 'best_practices_recommendations': [ - 'Implement proper dependency injection', - 'Use configuration files for magic numbers', - 'Add comprehensive error handling', - 'Implement proper logging throughout the application' - ] - } + "best_practices_recommendations": [ + "Implement proper dependency injection", + "Use configuration files for magic numbers", + "Add comprehensive error handling", + "Implement proper logging throughout the application", + ], + }, } - + # Add padding to ensure we meet 500 token minimum for caching - padding_text = " ".join([ - "This is additional padding content to ensure the response meets the minimum token requirement for caching.", - "The cache system requires at least 500 tokens to store responses effectively.", - "This padding includes various technical details and explanations about the operation.", - "Performance metrics indicate optimal execution patterns with efficient resource utilization.", - "The system maintains high availability and reliability standards throughout the operation lifecycle.", - "Comprehensive monitoring and logging capabilities provide detailed insights into system behavior.", - "Advanced error handling mechanisms ensure robust operation under various conditions.", - "Security protocols maintain data integrity and prevent unauthorized access attempts.", - "Scalability features allow the system to handle increased load and concurrent operations.", - "Integration capabilities enable seamless interaction with external systems and APIs.", - "Configuration management provides flexible deployment options across different environments.", - "Testing frameworks ensure quality assurance and prevent regression issues.", - "Documentation standards maintain clear and comprehensive technical specifications.", - "Version control systems track changes and enable collaborative development workflows.", - "Deployment pipelines automate the release process and ensure consistent deployments.", - "Backup and recovery procedures protect against data loss and system failures.", - "Performance optimization techniques improve response times and resource efficiency.", - "User experience considerations guide interface design and interaction patterns.", - "Accessibility standards ensure inclusive design for all user groups.", - "Internationalization features support multiple languages and regional preferences." - ]) - + padding_text = " ".join( + [ + "This is additional padding content to ensure the response meets the minimum token requirement for caching.", + "The cache system requires at least 500 tokens to store responses effectively.", + "This padding includes various technical details and explanations about the operation.", + "Performance metrics indicate optimal execution patterns with efficient resource utilization.", + "The system maintains high availability and reliability standards throughout the operation lifecycle.", + "Comprehensive monitoring and logging capabilities provide detailed insights into system behavior.", + "Advanced error handling mechanisms ensure robust operation under various conditions.", + "Security protocols maintain data integrity and prevent unauthorized access attempts.", + "Scalability features allow the system to handle increased load and concurrent operations.", + "Integration capabilities enable seamless interaction with external systems and APIs.", + "Configuration management provides flexible deployment options across different environments.", + "Testing frameworks ensure quality assurance and prevent regression issues.", + "Documentation standards maintain clear and comprehensive technical specifications.", + "Version control systems track changes and enable collaborative development workflows.", + "Deployment pipelines automate the release process and ensure consistent deployments.", + "Backup and recovery procedures protect against data loss and system failures.", + "Performance optimization techniques improve response times and resource efficiency.", + "User experience considerations guide interface design and interaction patterns.", + "Accessibility standards ensure inclusive design for all user groups.", + "Internationalization features support multiple languages and regional preferences.", + ] + ) + return { - 'operation': operation, - 'status': 'success', - 'result': detailed_results.get(operation, f"Detailed result for {operation}"), - 'parameters': kwargs, - 'execution_metadata': { - 'timestamp': datetime.now(timezone.utc).isoformat(), - 'execution_id': hashlib.sha256(f"{operation}_{time.time()}".encode()).hexdigest()[:16], - 'processing_time_seconds': operation_times.get(operation, 1.0), - 'resource_usage': { - 'memory_mb': 45.7, - 'cpu_percent': 23.4 - } + "operation": operation, + "status": "success", + "result": detailed_results.get( + operation, f"Detailed result for {operation}" + ), + "parameters": kwargs, + "execution_metadata": { + "timestamp": datetime.now(timezone.utc).isoformat(), + "execution_id": hashlib.sha256( + f"{operation}_{time.time()}".encode() + ).hexdigest()[:16], + "processing_time_seconds": operation_times.get(operation, 1.0), + "resource_usage": {"memory_mb": 45.7, "cpu_percent": 23.4}, + }, + "padding_for_cache_minimum": padding_text, + "technical_notes": { + "cache_strategy": "Hybrid caching with multi-tier storage optimization", + "performance_profile": "Optimized for high-throughput and low-latency operations", + "reliability_measures": "Comprehensive error handling and fallback mechanisms", + "security_features": "End-to-end encryption and secure authentication protocols", + "monitoring_capabilities": "Real-time metrics collection and alerting systems", }, - 'padding_for_cache_minimum': padding_text, - 'technical_notes': { - 'cache_strategy': 'Hybrid caching with multi-tier storage optimization', - 'performance_profile': 'Optimized for high-throughput and low-latency operations', - 'reliability_measures': 'Comprehensive error handling and fallback mechanisms', - 'security_features': 'End-to-end encryption and secure authentication protocols', - 'monitoring_capabilities': 'Real-time metrics collection and alerting systems' - } } - + async def optimize_caching(self): """Optimize all cache strategies.""" for strategy_name in self.cache_strategies.keys(): await self.hybrid_cache.optimize_cache(strategy_name) - + def get_cache_analytics(self) -> Dict[str, Any]: """Get comprehensive cache analytics.""" return self.hybrid_cache.get_performance_report() @@ -729,24 +794,28 @@ async def demonstrate_advanced_caching(): """Demonstrate advanced caching capabilities.""" print("šŸš€ Advanced Cache Integration Demo") print("=" * 50) - + # Initialize cache manager with proper configuration cache_config = get_default_cache_config() cache_manager = CacheManager(cache_config) - + # Create advanced client client = AdvancedArcadeClient("demo_api_key", cache_manager) - + print("\nšŸ“Š Testing different cache strategies...") - + # Test operations with different strategies operations = [ - ("code_analysis", "fast", {"code": "def hello(): print('world')", "language": "python"}), + ( + "code_analysis", + "fast", + {"code": "def hello(): print('world')", "language": "python"}, + ), ("test_generation", "default", {"code": "def add(a, b): return a + b"}), ("documentation", "persistent", {"code": "class MyClass: pass"}), - ("refactoring", "secure", {"code": "legacy_code_here"}) + ("refactoring", "secure", {"code": "legacy_code_here"}), ] - + # First run (cache miss) print("\nšŸ”„ First execution (cache miss):") for operation, strategy, params in operations: @@ -754,7 +823,7 @@ async def demonstrate_advanced_caching(): result = await client.cached_operation(operation, strategy, **params) duration = (time.time() - start_time) * 1000 print(f" {operation} ({strategy}): {duration:.1f}ms") - + # Second run (cache hit) print("\n⚔ Second execution (cache hit):") for operation, strategy, params in operations: @@ -762,27 +831,27 @@ async def demonstrate_advanced_caching(): result = await client.cached_operation(operation, strategy, **params) duration = (time.time() - start_time) * 1000 print(f" {operation} ({strategy}): {duration:.1f}ms") - + # Cache optimization print("\nšŸ”§ Optimizing cache performance...") await client.optimize_caching() - + # Analytics print("\nšŸ“ˆ Cache Performance Analytics:") analytics = client.get_cache_analytics() - - for strategy_name, metrics in analytics['strategies'].items(): + + for strategy_name, metrics in analytics["strategies"].items(): print(f"\n {strategy_name.upper()} Strategy:") print(f" Hit Rate: {metrics['hit_rate']:.1%}") print(f" Total Requests: {metrics['total_requests']}") print(f" Avg Hit Time: {metrics['avg_hit_time_ms']:.1f}ms") print(f" Cache Size: {metrics['cache_size_bytes']} bytes") - + print(f"\n Memory Cache:") - memory_stats = analytics['memory_cache'] + memory_stats = analytics["memory_cache"] print(f" Utilization: {memory_stats['utilization']:.1%}") print(f" Size: {memory_stats['size']}/{memory_stats['max_size']}") - + print("\nšŸŽ‰ Advanced caching demonstration completed!") @@ -790,9 +859,9 @@ async def main(): """Main demonstration function.""" logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + try: await demonstrate_advanced_caching() return 0 @@ -803,4 +872,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/05_cache_integration/cached_arcade_client_enhanced.py b/examples/arcade-dev/05_cache_integration/cached_arcade_client_enhanced.py index f40d517..592d3fe 100644 --- a/examples/arcade-dev/05_cache_integration/cached_arcade_client_enhanced.py +++ b/examples/arcade-dev/05_cache_integration/cached_arcade_client_enhanced.py @@ -46,16 +46,17 @@ @dataclass class CacheStrategy: """Cache strategy configuration.""" + strategy_name: str ttl_seconds: int max_entries: int = 1000 compression_enabled: bool = False encryption_enabled: bool = False - + # Performance thresholds hit_rate_threshold: float = 0.8 response_time_threshold_ms: float = 100 - + # Advanced features prefetch_enabled: bool = False background_refresh_enabled: bool = False @@ -64,54 +65,57 @@ class CacheStrategy: @dataclass class CachePerformanceMetrics: """Cache performance tracking.""" + hits: int = 0 misses: int = 0 sets: int = 0 errors: int = 0 - + # Timing metrics total_hit_time_ms: float = 0 total_miss_time_ms: float = 0 total_set_time_ms: float = 0 - + # Size metrics cache_size_bytes: int = 0 - + def get_hit_rate(self) -> float: total_requests = self.hits + self.misses return (self.hits / total_requests) if total_requests > 0 else 0.0 - + def get_avg_hit_time_ms(self) -> float: return (self.total_hit_time_ms / self.hits) if self.hits > 0 else 0.0 - + def get_avg_miss_time_ms(self) -> float: return (self.total_miss_time_ms / self.misses) if self.misses > 0 else 0.0 class EnhancedCacheManager: """Enhanced cache manager with response padding for successful caching.""" - - def __init__(self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]): + + def __init__( + self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy] + ): self.primary_cache = primary_cache self.strategies = strategies self.logger = logging.getLogger(f"{__name__}.EnhancedCacheManager") - + # Performance tracking self.metrics: Dict[str, CachePerformanceMetrics] = { strategy_name: CachePerformanceMetrics() for strategy_name in strategies.keys() } - + def get(self, key: str, strategy_name: str = "default") -> Optional[Any]: """Get value with caching strategy.""" start_time = time.time() strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Get from primary cache cache_entry = self.primary_cache.get(key) - + if cache_entry: # Cache hit metrics.hits += 1 @@ -124,172 +128,175 @@ def get(self, key: str, strategy_name: str = "default") -> Optional[Any]: metrics.total_miss_time_ms += (time.time() - start_time) * 1000 self.logger.debug(f"Cache miss for key: {key}") return None - + except Exception as e: metrics.errors += 1 self.logger.error(f"Cache get error for key {key}: {e}") return None - + def set(self, key: str, value: Any, strategy_name: str = "default") -> bool: """Set value with caching strategy and response padding.""" start_time = time.time() strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Convert value to string for padding if it's a dict/object if isinstance(value, dict): value_str = json.dumps(value, indent=2, default=str) else: value_str = str(value) - + # Apply response padding to meet cache requirements content_type = self._determine_content_type(value) enhanced_content = pad_response_for_caching( content=value_str, content_type=content_type, target_tokens=self.primary_cache.min_tokens, - preserve_format=True + preserve_format=True, ) - + # Store enhanced content in primary cache cache_entry = self.primary_cache.store(key, enhanced_content) - + if cache_entry: metrics.sets += 1 metrics.total_set_time_ms += (time.time() - start_time) * 1000 metrics.cache_size_bytes += cache_entry.size_bytes - - self.logger.info(f"Cache set successful for key: {key}, tokens: {cache_entry.token_count}") + + self.logger.info( + f"Cache set successful for key: {key}, tokens: {cache_entry.token_count}" + ) return True else: metrics.errors += 1 return False - + except Exception as e: metrics.errors += 1 self.logger.error(f"Cache set error for key {key}: {e}") return False - + def _determine_content_type(self, value: Any) -> str: """Determine content type for appropriate padding.""" if isinstance(value, dict): - if 'operation' in value: - operation = value.get('operation', '') - if 'sql' in operation.lower(): - return 'sql' - elif 'api' in operation.lower(): - return 'json' - return 'json' + if "operation" in value: + operation = value.get("operation", "") + if "sql" in operation.lower(): + return "sql" + elif "api" in operation.lower(): + return "json" + return "json" elif isinstance(value, str): - if 'SELECT' in value.upper() or 'INSERT' in value.upper(): - return 'sql' - elif value.strip().startswith('{') or value.strip().startswith('['): - return 'json' - return 'generic' - + if "SELECT" in value.upper() or "INSERT" in value.upper(): + return "sql" + elif value.strip().startswith("{") or value.strip().startswith("["): + return "json" + return "generic" + def optimize_cache(self, strategy_name: str = "default"): """Optimize cache performance based on usage patterns.""" strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Check if optimization is needed hit_rate = metrics.get_hit_rate() - + if hit_rate < strategy.hit_rate_threshold: - self.logger.warning(f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})") + self.logger.warning( + f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})" + ) else: - self.logger.info(f"Cache hit rate ({hit_rate:.2%}) meets threshold ({strategy.hit_rate_threshold:.2%})") - - self.logger.info(f"Cache optimization completed for strategy: {strategy_name}") - + self.logger.info( + f"Cache hit rate ({hit_rate:.2%}) meets threshold ({strategy.hit_rate_threshold:.2%})" + ) + + self.logger.info( + f"Cache optimization completed for strategy: {strategy_name}" + ) + except Exception as e: self.logger.error(f"Cache optimization error: {e}") - + def get_performance_report(self) -> Dict[str, Any]: """Generate comprehensive performance report.""" - report = { - 'timestamp': datetime.now(timezone.utc).isoformat(), - 'strategies': {} - } - + report = {"timestamp": datetime.now(timezone.utc).isoformat(), "strategies": {}} + for strategy_name, metrics in self.metrics.items(): strategy_report = { - 'hit_rate': metrics.get_hit_rate(), - 'total_requests': metrics.hits + metrics.misses, - 'cache_hits': metrics.hits, - 'cache_misses': metrics.misses, - 'cache_sets': metrics.sets, - 'cache_errors': metrics.errors, - 'avg_hit_time_ms': metrics.get_avg_hit_time_ms(), - 'avg_miss_time_ms': metrics.get_avg_miss_time_ms(), - 'cache_size_bytes': metrics.cache_size_bytes + "hit_rate": metrics.get_hit_rate(), + "total_requests": metrics.hits + metrics.misses, + "cache_hits": metrics.hits, + "cache_misses": metrics.misses, + "cache_sets": metrics.sets, + "cache_errors": metrics.errors, + "avg_hit_time_ms": metrics.get_avg_hit_time_ms(), + "avg_miss_time_ms": metrics.get_avg_miss_time_ms(), + "cache_size_bytes": metrics.cache_size_bytes, } - - report['strategies'][strategy_name] = strategy_report - + + report["strategies"][strategy_name] = strategy_report + return report class AdvancedArcadeClient: """Advanced Arcade.dev client with sophisticated caching and response padding.""" - + def __init__(self, api_key: str, cache_manager: CacheManager): self.api_key = api_key self.logger = logging.getLogger(__name__) - + # Detect demo mode based on API key self.demo_mode = ( - not api_key or - api_key in ["demo_api_key", "test-key-for-testing"] or - api_key.startswith("demo_") or - api_key.startswith("test_") + not api_key + or api_key in ["demo_api_key", "test-key-for-testing"] + or api_key.startswith("demo_") + or api_key.startswith("test_") ) - + if self.demo_mode: self.logger.info("Running in DEMO MODE - no real API calls will be made") else: self.logger.info("Running with real API credentials") - + # Setup cache strategies self.cache_strategies = { - "default": CacheStrategy( - strategy_name="default", - ttl_seconds=3600 - ), + "default": CacheStrategy(strategy_name="default", ttl_seconds=3600), "fast": CacheStrategy( - strategy_name="fast", - ttl_seconds=300, - prefetch_enabled=True + strategy_name="fast", ttl_seconds=300, prefetch_enabled=True ), "persistent": CacheStrategy( strategy_name="persistent", ttl_seconds=86400, # 24 hours compression_enabled=True, - background_refresh_enabled=True + background_refresh_enabled=True, ), "secure": CacheStrategy( - strategy_name="secure", - ttl_seconds=1800, - encryption_enabled=True - ) + strategy_name="secure", ttl_seconds=1800, encryption_enabled=True + ), } - + # Initialize enhanced cache manager self.enhanced_cache = EnhancedCacheManager(cache_manager, self.cache_strategies) - + def _generate_cache_key(self, operation: str, **kwargs) -> str: """Generate cache key for operation.""" params_str = json.dumps(kwargs, sort_keys=True, default=str) params_hash = hashlib.sha256(params_str.encode()).hexdigest()[:16] return f"arcade:{operation}:{params_hash}" - - async def cached_operation(self, operation: str, strategy: str = "default", - force_refresh: bool = False, **kwargs) -> Dict[str, Any]: + + async def cached_operation( + self, + operation: str, + strategy: str = "default", + force_refresh: bool = False, + **kwargs, + ) -> Dict[str, Any]: """Execute operation with advanced caching and response padding.""" cache_key = self._generate_cache_key(operation, **kwargs) - + # Check cache first if not force_refresh: cached_result = self.enhanced_cache.get(cache_key, strategy) @@ -299,152 +306,201 @@ async def cached_operation(self, operation: str, strategy: str = "default", # Extract original result from the enhanced content try: # Look for the original JSON in the enhanced content - lines = cached_result.split('\n') + lines = cached_result.split("\n") for line in lines: - if line.strip().startswith('{') and 'operation' in line: + if line.strip().startswith("{") and "operation" in line: original_result = json.loads(line.strip()) - original_result['_cached'] = True - original_result['_enhanced_content'] = True + original_result["_cached"] = True + original_result["_enhanced_content"] = True return original_result except: pass - + # Fallback: return a structured response indicating cache hit return { - 'operation': operation, - 'status': 'success', - 'result': f"Cached result for {operation}", - 'parameters': kwargs, - '_cached': True, - '_enhanced_content': True, - '_timestamp': datetime.now(timezone.utc).isoformat() + "operation": operation, + "status": "success", + "result": f"Cached result for {operation}", + "parameters": kwargs, + "_cached": True, + "_enhanced_content": True, + "_timestamp": datetime.now(timezone.utc).isoformat(), } - + # Execute operation start_time = time.time() result = await self._execute_operation(operation, **kwargs) execution_time = (time.time() - start_time) * 1000 - + # Add metadata - result['_execution_time_ms'] = execution_time - result['_cached'] = False - result['_timestamp'] = datetime.now(timezone.utc).isoformat() - + result["_execution_time_ms"] = execution_time + result["_cached"] = False + result["_timestamp"] = datetime.now(timezone.utc).isoformat() + # Cache result with padding self.enhanced_cache.set(cache_key, result, strategy) - + return result - + async def _execute_operation(self, operation: str, **kwargs) -> Dict[str, Any]: """Execute the actual operation (mock implementation for demo).""" if self.demo_mode: # Simulate different operation types with realistic delays operation_times = { - 'code_analysis': 2.0, - 'test_generation': 3.0, - 'documentation': 1.5, - 'refactoring': 2.5 + "code_analysis": 2.0, + "test_generation": 3.0, + "documentation": 1.5, + "refactoring": 2.5, } - + await asyncio.sleep(operation_times.get(operation, 1.0)) - + # Generate comprehensive demo results demo_results = { - 'code_analysis': { - 'summary': f"Comprehensive analysis of {kwargs.get('language', 'code')} code", - 'metrics': { - 'lines_of_code': len(kwargs.get('code', '').split('\n')), - 'complexity_score': 7.2, - 'maintainability_index': 85, - 'technical_debt_ratio': 0.15 + "code_analysis": { + "summary": f"Comprehensive analysis of {kwargs.get('language', 'code')} code", + "metrics": { + "lines_of_code": len(kwargs.get("code", "").split("\n")), + "complexity_score": 7.2, + "maintainability_index": 85, + "technical_debt_ratio": 0.15, }, - 'issues': [ - {'type': 'warning', 'message': 'Consider extracting method for better readability'}, - {'type': 'info', 'message': 'Add type hints for better code documentation'}, - {'type': 'suggestion', 'message': 'Use f-strings for string formatting'} + "issues": [ + { + "type": "warning", + "message": "Consider extracting method for better readability", + }, + { + "type": "info", + "message": "Add type hints for better code documentation", + }, + { + "type": "suggestion", + "message": "Use f-strings for string formatting", + }, + ], + "recommendations": [ + "Implement unit tests for critical functions", + "Add docstrings to public methods", + "Consider using design patterns for scalability", ], - 'recommendations': [ - 'Implement unit tests for critical functions', - 'Add docstrings to public methods', - 'Consider using design patterns for scalability' - ] }, - 'test_generation': { - 'summary': f"Generated comprehensive test suite", - 'test_cases': [ - {'name': 'test_basic_functionality', 'type': 'unit', 'coverage': '95%'}, - {'name': 'test_edge_cases', 'type': 'unit', 'coverage': '88%'}, - {'name': 'test_error_handling', 'type': 'unit', 'coverage': '92%'}, - {'name': 'test_integration', 'type': 'integration', 'coverage': '87%'}, - {'name': 'test_performance', 'type': 'performance', 'coverage': '90%'} + "test_generation": { + "summary": f"Generated comprehensive test suite", + "test_cases": [ + { + "name": "test_basic_functionality", + "type": "unit", + "coverage": "95%", + }, + {"name": "test_edge_cases", "type": "unit", "coverage": "88%"}, + { + "name": "test_error_handling", + "type": "unit", + "coverage": "92%", + }, + { + "name": "test_integration", + "type": "integration", + "coverage": "87%", + }, + { + "name": "test_performance", + "type": "performance", + "coverage": "90%", + }, ], - 'coverage_metrics': { - 'line_coverage': '91%', - 'branch_coverage': '88%', - 'function_coverage': '100%' + "coverage_metrics": { + "line_coverage": "91%", + "branch_coverage": "88%", + "function_coverage": "100%", }, - 'frameworks': ['pytest', 'unittest', 'hypothesis'] + "frameworks": ["pytest", "unittest", "hypothesis"], }, - 'documentation': { - 'summary': f"Generated technical documentation", - 'sections': [ - {'title': 'API Reference', 'pages': 12, 'status': 'complete'}, - {'title': 'Usage Examples', 'pages': 8, 'status': 'complete'}, - {'title': 'Architecture Overview', 'pages': 15, 'status': 'complete'}, - {'title': 'Troubleshooting Guide', 'pages': 6, 'status': 'complete'} + "documentation": { + "summary": f"Generated technical documentation", + "sections": [ + {"title": "API Reference", "pages": 12, "status": "complete"}, + {"title": "Usage Examples", "pages": 8, "status": "complete"}, + { + "title": "Architecture Overview", + "pages": 15, + "status": "complete", + }, + { + "title": "Troubleshooting Guide", + "pages": 6, + "status": "complete", + }, ], - 'formats': ['markdown', 'html', 'pdf'], - 'word_count': 12500, - 'diagrams_generated': 8 + "formats": ["markdown", "html", "pdf"], + "word_count": 12500, + "diagrams_generated": 8, }, - 'refactoring': { - 'summary': f"Comprehensive code refactoring analysis", - 'improvements': [ - {'type': 'Extract Method', 'impact': 'high', 'effort': 'medium'}, - {'type': 'Reduce Complexity', 'impact': 'high', 'effort': 'high'}, - {'type': 'Improve Naming', 'impact': 'medium', 'effort': 'low'}, - {'type': 'Add Type Hints', 'impact': 'medium', 'effort': 'medium'}, - {'type': 'Optimize Performance', 'impact': 'high', 'effort': 'high'} + "refactoring": { + "summary": f"Comprehensive code refactoring analysis", + "improvements": [ + { + "type": "Extract Method", + "impact": "high", + "effort": "medium", + }, + { + "type": "Reduce Complexity", + "impact": "high", + "effort": "high", + }, + {"type": "Improve Naming", "impact": "medium", "effort": "low"}, + { + "type": "Add Type Hints", + "impact": "medium", + "effort": "medium", + }, + { + "type": "Optimize Performance", + "impact": "high", + "effort": "high", + }, ], - 'metrics_before': { - 'cyclomatic_complexity': 12, - 'maintainability_index': 65, - 'lines_of_code': 450 + "metrics_before": { + "cyclomatic_complexity": 12, + "maintainability_index": 65, + "lines_of_code": 450, }, - 'metrics_after': { - 'cyclomatic_complexity': 7, - 'maintainability_index': 85, - 'lines_of_code': 380 + "metrics_after": { + "cyclomatic_complexity": 7, + "maintainability_index": 85, + "lines_of_code": 380, }, - 'estimated_time_saved': '15 hours/month' - } + "estimated_time_saved": "15 hours/month", + }, } - + return { - 'operation': operation, - 'status': 'success', - 'result': demo_results.get(operation, f"Demo result for {operation}"), - 'parameters': kwargs, - 'demo_mode': True + "operation": operation, + "status": "success", + "result": demo_results.get(operation, f"Demo result for {operation}"), + "parameters": kwargs, + "demo_mode": True, } else: # In live mode, this would make actual API calls to Arcade.dev self.logger.warning("Live API mode not yet implemented - using simulation") await asyncio.sleep(1.0) # Faster for live mode - + return { - 'operation': operation, - 'status': 'simulated', - 'result': f"Simulated {operation} result (would be real API call)", - 'parameters': kwargs, - 'demo_mode': False + "operation": operation, + "status": "simulated", + "result": f"Simulated {operation} result (would be real API call)", + "parameters": kwargs, + "demo_mode": False, } - + def optimize_caching(self): """Optimize all cache strategies.""" for strategy_name in self.cache_strategies.keys(): self.enhanced_cache.optimize_cache(strategy_name) - + def get_cache_analytics(self) -> Dict[str, Any]: """Get comprehensive cache analytics.""" return self.enhanced_cache.get_performance_report() @@ -454,32 +510,36 @@ async def demonstrate_enhanced_caching(): """Demonstrate enhanced caching capabilities with response padding.""" print("šŸš€ Enhanced Cache Integration Demo with Response Padding") print("=" * 70) - + # Initialize cache manager with proper configuration cache_config = get_default_cache_config() cache_manager = CacheManager(cache_config) - + # Get API key from environment or use demo mode api_key = os.getenv("ARCADE_API_KEY", "demo_api_key") - + # Create advanced client client = AdvancedArcadeClient(api_key, cache_manager) - + # Show mode status mode_status = "šŸŽ® DEMO MODE" if client.demo_mode else "šŸ”‘ LIVE MODE" print(f"\n{mode_status} - Using API key: {api_key[:10]}...") print(f"šŸ’¾ Cache minimum tokens: {cache_manager.min_tokens}") - + print("\nšŸ“Š Testing different cache strategies with response padding...") - + # Test operations with different strategies operations = [ - ("code_analysis", "fast", {"code": "def hello(): print('world')", "language": "python"}), + ( + "code_analysis", + "fast", + {"code": "def hello(): print('world')", "language": "python"}, + ), ("test_generation", "default", {"code": "def add(a, b): return a + b"}), ("documentation", "persistent", {"code": "class MyClass: pass"}), - ("refactoring", "secure", {"code": "legacy_code_here"}) + ("refactoring", "secure", {"code": "legacy_code_here"}), ] - + # First run (cache miss) print("\nšŸ”„ First execution (cache miss with response padding):") for operation, strategy, params in operations: @@ -487,26 +547,28 @@ async def demonstrate_enhanced_caching(): result = await client.cached_operation(operation, strategy, **params) duration = (time.time() - start_time) * 1000 print(f" {operation} ({strategy}): {duration:.1f}ms") - + # Second run (cache hit) print("\n⚔ Second execution (cache hit):") for operation, strategy, params in operations: start_time = time.time() result = await client.cached_operation(operation, strategy, **params) duration = (time.time() - start_time) * 1000 - cached_status = "HIT" if result.get('_cached') else "MISS" - enhanced_status = "ENHANCED" if result.get('_enhanced_content') else "ORIGINAL" - print(f" {operation} ({strategy}): {duration:.1f}ms [{cached_status}] [{enhanced_status}]") - + cached_status = "HIT" if result.get("_cached") else "MISS" + enhanced_status = "ENHANCED" if result.get("_enhanced_content") else "ORIGINAL" + print( + f" {operation} ({strategy}): {duration:.1f}ms [{cached_status}] [{enhanced_status}]" + ) + # Cache optimization print("\nšŸ”§ Optimizing cache performance...") client.optimize_caching() - + # Analytics print("\nšŸ“ˆ Cache Performance Analytics:") analytics = client.get_cache_analytics() - - for strategy_name, metrics in analytics['strategies'].items(): + + for strategy_name, metrics in analytics["strategies"].items(): print(f"\n {strategy_name.upper()} Strategy:") print(f" Hit Rate: {metrics['hit_rate']:.1%}") print(f" Total Requests: {metrics['total_requests']}") @@ -515,7 +577,7 @@ async def demonstrate_enhanced_caching(): print(f" Cache Sets: {metrics['cache_sets']}") print(f" Avg Hit Time: {metrics['avg_hit_time_ms']:.1f}ms") print(f" Cache Size: {metrics['cache_size_bytes']} bytes") - + print("\nšŸŽ‰ Enhanced caching demonstration completed successfully!") print("\nšŸ’” Key Features Demonstrated:") print(" āœ… Response padding to meet minimum token requirements") @@ -532,19 +594,20 @@ async def main(): """Main demonstration function.""" logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + try: await demonstrate_enhanced_caching() return 0 except Exception as e: print(f"āŒ Error: {e}") import traceback + traceback.print_exc() return 1 if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/05_cache_integration/cached_arcade_client_fixed.py b/examples/arcade-dev/05_cache_integration/cached_arcade_client_fixed.py index 58cd5b8..5a8dc86 100644 --- a/examples/arcade-dev/05_cache_integration/cached_arcade_client_fixed.py +++ b/examples/arcade-dev/05_cache_integration/cached_arcade_client_fixed.py @@ -45,16 +45,17 @@ @dataclass class CacheStrategy: """Cache strategy configuration.""" + strategy_name: str ttl_seconds: int max_entries: int = 1000 compression_enabled: bool = False encryption_enabled: bool = False - + # Performance thresholds hit_rate_threshold: float = 0.8 response_time_threshold_ms: float = 100 - + # Advanced features prefetch_enabled: bool = False background_refresh_enabled: bool = False @@ -63,54 +64,57 @@ class CacheStrategy: @dataclass class CachePerformanceMetrics: """Cache performance tracking.""" + hits: int = 0 misses: int = 0 sets: int = 0 errors: int = 0 - + # Timing metrics total_hit_time_ms: float = 0 total_miss_time_ms: float = 0 total_set_time_ms: float = 0 - + # Size metrics cache_size_bytes: int = 0 - + def get_hit_rate(self) -> float: total_requests = self.hits + self.misses return (self.hits / total_requests) if total_requests > 0 else 0.0 - + def get_avg_hit_time_ms(self) -> float: return (self.total_hit_time_ms / self.hits) if self.hits > 0 else 0.0 - + def get_avg_miss_time_ms(self) -> float: return (self.total_miss_time_ms / self.misses) if self.misses > 0 else 0.0 class SimpleCacheManager: """Simplified cache manager that works with FACT CacheManager.""" - - def __init__(self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]): + + def __init__( + self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy] + ): self.primary_cache = primary_cache self.strategies = strategies self.logger = logging.getLogger(f"{__name__}.SimpleCacheManager") - + # Performance tracking self.metrics: Dict[str, CachePerformanceMetrics] = { strategy_name: CachePerformanceMetrics() for strategy_name in strategies.keys() } - + def get(self, key: str, strategy_name: str = "default") -> Optional[Any]: """Get value with caching strategy.""" start_time = time.time() strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Get from primary cache cache_entry = self.primary_cache.get(key) - + if cache_entry: # Cache hit metrics.hits += 1 @@ -123,137 +127,136 @@ def get(self, key: str, strategy_name: str = "default") -> Optional[Any]: metrics.total_miss_time_ms += (time.time() - start_time) * 1000 self.logger.debug(f"Cache miss for key: {key}") return None - + except Exception as e: metrics.errors += 1 self.logger.error(f"Cache get error for key {key}: {e}") return None - + def set(self, key: str, value: Any, strategy_name: str = "default") -> bool: """Set value with caching strategy.""" start_time = time.time() strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Store in primary cache cache_entry = self.primary_cache.store(key, value) - + if cache_entry: metrics.sets += 1 metrics.total_set_time_ms += (time.time() - start_time) * 1000 metrics.cache_size_bytes += cache_entry.size_bytes - + self.logger.debug(f"Cache set successful for key: {key}") return True else: metrics.errors += 1 return False - + except Exception as e: metrics.errors += 1 self.logger.error(f"Cache set error for key {key}: {e}") return False - + def optimize_cache(self, strategy_name: str = "default"): """Optimize cache performance based on usage patterns.""" strategy = self.strategies.get(strategy_name, self.strategies["default"]) metrics = self.metrics[strategy_name] - + try: # Check if optimization is needed hit_rate = metrics.get_hit_rate() - + if hit_rate < strategy.hit_rate_threshold: - self.logger.warning(f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})") - self.logger.info(f"Cache optimization completed for strategy: {strategy_name}") - + self.logger.warning( + f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})" + ) + self.logger.info( + f"Cache optimization completed for strategy: {strategy_name}" + ) + except Exception as e: self.logger.error(f"Cache optimization error: {e}") - + def get_performance_report(self) -> Dict[str, Any]: """Generate comprehensive performance report.""" - report = { - 'timestamp': datetime.now(timezone.utc).isoformat(), - 'strategies': {} - } - + report = {"timestamp": datetime.now(timezone.utc).isoformat(), "strategies": {}} + for strategy_name, metrics in self.metrics.items(): strategy_report = { - 'hit_rate': metrics.get_hit_rate(), - 'total_requests': metrics.hits + metrics.misses, - 'cache_hits': metrics.hits, - 'cache_misses': metrics.misses, - 'cache_sets': metrics.sets, - 'cache_errors': metrics.errors, - 'avg_hit_time_ms': metrics.get_avg_hit_time_ms(), - 'avg_miss_time_ms': metrics.get_avg_miss_time_ms(), - 'cache_size_bytes': metrics.cache_size_bytes + "hit_rate": metrics.get_hit_rate(), + "total_requests": metrics.hits + metrics.misses, + "cache_hits": metrics.hits, + "cache_misses": metrics.misses, + "cache_sets": metrics.sets, + "cache_errors": metrics.errors, + "avg_hit_time_ms": metrics.get_avg_hit_time_ms(), + "avg_miss_time_ms": metrics.get_avg_miss_time_ms(), + "cache_size_bytes": metrics.cache_size_bytes, } - - report['strategies'][strategy_name] = strategy_report - + + report["strategies"][strategy_name] = strategy_report + return report class AdvancedArcadeClient: """Advanced Arcade.dev client with sophisticated caching.""" - + def __init__(self, api_key: str, cache_manager: CacheManager): self.api_key = api_key self.logger = logging.getLogger(__name__) - + # Detect demo mode based on API key self.demo_mode = ( - not api_key or - api_key in ["demo_api_key", "test-key-for-testing"] or - api_key.startswith("demo_") or - api_key.startswith("test_") + not api_key + or api_key in ["demo_api_key", "test-key-for-testing"] + or api_key.startswith("demo_") + or api_key.startswith("test_") ) - + if self.demo_mode: self.logger.info("Running in DEMO MODE - no real API calls will be made") else: self.logger.info("Running with real API credentials") - + # Setup cache strategies self.cache_strategies = { - "default": CacheStrategy( - strategy_name="default", - ttl_seconds=3600 - ), + "default": CacheStrategy(strategy_name="default", ttl_seconds=3600), "fast": CacheStrategy( - strategy_name="fast", - ttl_seconds=300, - prefetch_enabled=True + strategy_name="fast", ttl_seconds=300, prefetch_enabled=True ), "persistent": CacheStrategy( strategy_name="persistent", ttl_seconds=86400, # 24 hours compression_enabled=True, - background_refresh_enabled=True + background_refresh_enabled=True, ), "secure": CacheStrategy( - strategy_name="secure", - ttl_seconds=1800, - encryption_enabled=True - ) + strategy_name="secure", ttl_seconds=1800, encryption_enabled=True + ), } - + # Initialize simplified cache manager self.simple_cache = SimpleCacheManager(cache_manager, self.cache_strategies) - + def _generate_cache_key(self, operation: str, **kwargs) -> str: """Generate cache key for operation.""" params_str = json.dumps(kwargs, sort_keys=True, default=str) params_hash = hashlib.sha256(params_str.encode()).hexdigest()[:16] return f"arcade:{operation}:{params_hash}" - - async def cached_operation(self, operation: str, strategy: str = "default", - force_refresh: bool = False, **kwargs) -> Dict[str, Any]: + + async def cached_operation( + self, + operation: str, + strategy: str = "default", + force_refresh: bool = False, + **kwargs, + ) -> Dict[str, Any]: """Execute operation with advanced caching.""" cache_key = self._generate_cache_key(operation, **kwargs) - + # Check cache first if not force_refresh: cached_result = self.simple_cache.get(cache_key, strategy) @@ -266,78 +269,78 @@ async def cached_operation(self, operation: str, strategy: str = "default", except json.JSONDecodeError: # If not JSON, wrap in result structure return { - 'operation': operation, - 'status': 'success', - 'result': cached_result, - 'parameters': kwargs, - '_cached': True, - '_timestamp': datetime.now(timezone.utc).isoformat() + "operation": operation, + "status": "success", + "result": cached_result, + "parameters": kwargs, + "_cached": True, + "_timestamp": datetime.now(timezone.utc).isoformat(), } return cached_result - + # Execute operation start_time = time.time() result = await self._execute_operation(operation, **kwargs) execution_time = (time.time() - start_time) * 1000 - + # Add metadata - result['_execution_time_ms'] = execution_time - result['_cached'] = False - result['_timestamp'] = datetime.now(timezone.utc).isoformat() - + result["_execution_time_ms"] = execution_time + result["_cached"] = False + result["_timestamp"] = datetime.now(timezone.utc).isoformat() + # Cache result as JSON string result_json = json.dumps(result, default=str) self.simple_cache.set(cache_key, result_json, strategy) - + return result - + async def _execute_operation(self, operation: str, **kwargs) -> Dict[str, Any]: """Execute the actual operation (mock implementation for demo).""" if self.demo_mode: # Simulate different operation types with realistic delays operation_times = { - 'code_analysis': 2.0, - 'test_generation': 3.0, - 'documentation': 1.5, - 'refactoring': 2.5 + "code_analysis": 2.0, + "test_generation": 3.0, + "documentation": 1.5, + "refactoring": 2.5, } - + await asyncio.sleep(operation_times.get(operation, 1.0)) - + # Generate demo results based on operation type demo_results = { - 'code_analysis': f"Demo analysis of {kwargs.get('language', 'code')}: Found {len(kwargs.get('code', ''))} characters, suggests 3 improvements", - 'test_generation': f"Demo test suite generated with 5 test cases for function: {kwargs.get('code', 'function')[:20]}...", - 'documentation': f"Demo documentation generated for {len(kwargs.get('code', ''))} lines of code", - 'refactoring': f"Demo refactoring suggestions: Extract 2 methods, reduce complexity by 15%" + "code_analysis": f"Demo analysis of {kwargs.get('language', 'code')}: Found {len(kwargs.get('code', ''))} characters, suggests 3 improvements", + "test_generation": f"Demo test suite generated with 5 test cases for function: {kwargs.get('code', 'function')[:20]}...", + "documentation": f"Demo documentation generated for {len(kwargs.get('code', ''))} lines of code", + "refactoring": f"Demo refactoring suggestions: Extract 2 methods, reduce complexity by 15%", } - + return { - 'operation': operation, - 'status': 'success', - 'result': demo_results.get(operation, f"Demo result for {operation}"), - 'parameters': kwargs, - 'demo_mode': True + "operation": operation, + "status": "success", + "result": demo_results.get(operation, f"Demo result for {operation}"), + "parameters": kwargs, + "demo_mode": True, } else: # In live mode, this would make actual API calls to Arcade.dev # For now, we'll simulate since we don't have real API integration self.logger.warning("Live API mode not yet implemented - using simulation") await asyncio.sleep(1.0) # Faster for live mode - + return { - 'operation': operation, - 'status': 'simulated', - 'result': f"Simulated {operation} result (would be real API call)", - 'parameters': kwargs, - 'demo_mode': False + "operation": operation, + "status": "simulated", + "result": f"Simulated {operation} result (would be real API call)", + "parameters": kwargs, + "demo_mode": False, } - + def optimize_caching(self): """Optimize all cache strategies.""" for strategy_name in self.cache_strategies.keys(): self.simple_cache.optimize_cache(strategy_name) - + def get_cache_analytics(self) -> Dict[str, Any]: """Get comprehensive cache analytics.""" return self.simple_cache.get_performance_report() @@ -347,31 +350,35 @@ async def demonstrate_advanced_caching(): """Demonstrate advanced caching capabilities.""" print("šŸš€ Advanced Cache Integration Demo - Fixed Version") print("=" * 60) - + # Initialize cache manager with proper configuration cache_config = get_default_cache_config() cache_manager = CacheManager(cache_config) - + # Get API key from environment or use demo mode api_key = os.getenv("ARCADE_API_KEY", "demo_api_key") - + # Create advanced client client = AdvancedArcadeClient(api_key, cache_manager) - + # Show mode status mode_status = "šŸŽ® DEMO MODE" if client.demo_mode else "šŸ”‘ LIVE MODE" print(f"\n{mode_status} - Using API key: {api_key[:10]}...") - + print("\nšŸ“Š Testing different cache strategies...") - + # Test operations with different strategies operations = [ - ("code_analysis", "fast", {"code": "def hello(): print('world')", "language": "python"}), + ( + "code_analysis", + "fast", + {"code": "def hello(): print('world')", "language": "python"}, + ), ("test_generation", "default", {"code": "def add(a, b): return a + b"}), ("documentation", "persistent", {"code": "class MyClass: pass"}), - ("refactoring", "secure", {"code": "legacy_code_here"}) + ("refactoring", "secure", {"code": "legacy_code_here"}), ] - + # First run (cache miss) print("\nšŸ”„ First execution (cache miss):") for operation, strategy, params in operations: @@ -379,25 +386,25 @@ async def demonstrate_advanced_caching(): result = await client.cached_operation(operation, strategy, **params) duration = (time.time() - start_time) * 1000 print(f" {operation} ({strategy}): {duration:.1f}ms") - + # Second run (cache hit) print("\n⚔ Second execution (cache hit):") for operation, strategy, params in operations: start_time = time.time() result = await client.cached_operation(operation, strategy, **params) duration = (time.time() - start_time) * 1000 - cached_status = "HIT" if result.get('_cached') else "MISS" + cached_status = "HIT" if result.get("_cached") else "MISS" print(f" {operation} ({strategy}): {duration:.1f}ms [{cached_status}]") - + # Cache optimization print("\nšŸ”§ Optimizing cache performance...") client.optimize_caching() - + # Analytics print("\nšŸ“ˆ Cache Performance Analytics:") analytics = client.get_cache_analytics() - - for strategy_name, metrics in analytics['strategies'].items(): + + for strategy_name, metrics in analytics["strategies"].items(): print(f"\n {strategy_name.upper()} Strategy:") print(f" Hit Rate: {metrics['hit_rate']:.1%}") print(f" Total Requests: {metrics['total_requests']}") @@ -405,7 +412,7 @@ async def demonstrate_advanced_caching(): print(f" Cache Misses: {metrics['cache_misses']}") print(f" Avg Hit Time: {metrics['avg_hit_time_ms']:.1f}ms") print(f" Cache Size: {metrics['cache_size_bytes']} bytes") - + print("\nšŸŽ‰ Advanced caching demonstration completed successfully!") print("\nšŸ’” Key Features Demonstrated:") print(" āœ… Multiple cache strategies (fast, default, persistent, secure)") @@ -420,19 +427,20 @@ async def main(): """Main demonstration function.""" logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + try: await demonstrate_advanced_caching() return 0 except Exception as e: print(f"āŒ Error: {e}") import traceback + traceback.print_exc() return 1 if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/06_security/secure_tool_execution.py b/examples/arcade-dev/06_security/secure_tool_execution.py index 52258c5..eebf67c 100644 --- a/examples/arcade-dev/06_security/secure_tool_execution.py +++ b/examples/arcade-dev/06_security/secure_tool_execution.py @@ -36,25 +36,26 @@ from src.security.token_manager import TokenManager from src.core.driver import FACTDriver from src.cache.manager import CacheManager + FACT_AVAILABLE = True except ImportError as e: print(f"āš ļø FACT modules not fully available: {e}") print("šŸ”„ Running in demo mode with mock implementations") FACT_AVAILABLE = False - + # Create mock classes for demo class MockAuthorizationManager: def __init__(self, *args, **kwargs): pass - + class MockInputSanitizer: def __init__(self, *args, **kwargs): pass - + class MockTokenManager: def __init__(self, *args, **kwargs): pass - + AuthorizationManager = MockAuthorizationManager InputSanitizer = MockInputSanitizer TokenManager = MockTokenManager @@ -63,23 +64,26 @@ def __init__(self, *args, **kwargs): @dataclass class SecurityConfig: """Security configuration for Arcade.dev integration.""" + enable_input_validation: bool = True enable_audit_logging: bool = True enable_rate_limiting: bool = True enable_data_encryption: bool = True - + # Rate limiting max_requests_per_minute: int = 60 max_requests_per_hour: int = 1000 - + # Input validation max_input_size: int = 1024 * 1024 # 1MB - allowed_file_types: Set[str] = field(default_factory=lambda: {'.py', '.js', '.ts', '.java', '.cpp'}) - + allowed_file_types: Set[str] = field( + default_factory=lambda: {".py", ".js", ".ts", ".java", ".cpp"} + ) + # Audit settings audit_log_file: str = "arcade_security_audit.log" log_sensitive_data: bool = False - + # Token settings token_expiry_minutes: int = 60 refresh_token_expiry_days: int = 7 @@ -88,6 +92,7 @@ class SecurityConfig: @dataclass class UserPermissions: """User permission model.""" + user_id: str scopes: Set[str] max_daily_requests: int = 1000 @@ -100,6 +105,7 @@ class UserPermissions: @dataclass class AuditLogEntry: """Audit log entry for security tracking.""" + timestamp: datetime user_id: str operation: str @@ -113,38 +119,40 @@ class AuditLogEntry: class SecureCredentialManager: """Manages secure storage and retrieval of API credentials.""" - + def __init__(self): self.logger = logging.getLogger(f"{__name__}.CredentialManager") self._credentials: Dict[str, str] = {} self._demo_mode = not FACT_AVAILABLE - + if self._demo_mode: - self.logger.info("Running in demo mode - using simplified credential storage") - self._encryption_key = b'demo_key_for_testing_only_not_secure' + self.logger.info( + "Running in demo mode - using simplified credential storage" + ) + self._encryption_key = b"demo_key_for_testing_only_not_secure" else: self._encryption_key = self._get_or_create_encryption_key() - + def _get_or_create_encryption_key(self) -> bytes: """Get or create encryption key for credential storage.""" # In demo mode, use a simple key if self._demo_mode: - return b'demo_key_for_testing_only_not_secure' - - key_file = Path.home() / '.fact' / 'encryption.key' - + return b"demo_key_for_testing_only_not_secure" + + key_file = Path.home() / ".fact" / "encryption.key" + if key_file.exists(): - with open(key_file, 'rb') as f: + with open(key_file, "rb") as f: return f.read() else: # Create new key key_file.parent.mkdir(exist_ok=True) key = secrets.token_bytes(32) - with open(key_file, 'wb') as f: + with open(key_file, "wb") as f: f.write(key) key_file.chmod(0o600) # Read only for owner return key - + def store_credential(self, service: str, credential: str) -> bool: """Securely store a credential.""" try: @@ -156,134 +164,148 @@ def store_credential(self, service: str, credential: str) -> bool: except Exception as e: self.logger.error(f"Failed to store credential for {service}: {e}") return False - + def get_credential(self, service: str) -> Optional[str]: """Retrieve a credential securely.""" if self._demo_mode: # In demo mode, return a mock credential if no env var is set env_key = f"{service.upper()}_API_KEY" credential = os.getenv(env_key) - + if not credential: # Return demo credential for testing credential = f"demo_key_for_{service}_testing_only" self.logger.info(f"Using demo credential for service: {service}") else: self.logger.info(f"Using environment credential for service: {service}") - + return credential else: # In production, implement proper decryption # For now, use environment variables env_key = f"{service.upper()}_API_KEY" credential = os.getenv(env_key) - + if not credential: self.logger.warning(f"No credential found for service: {service}") - + return credential class SecurityValidator: """Validates and sanitizes inputs for security.""" - + def __init__(self, config: SecurityConfig): self.config = config self.logger = logging.getLogger(f"{__name__}.SecurityValidator") - + # Dangerous patterns to detect self.dangerous_patterns = [ - r'eval\s*\(', - r'exec\s*\(', - r'import\s+os', - r'import\s+subprocess', - r'__import__', - r'open\s*\(', - r'file\s*\(', - r'input\s*\(', - r'raw_input\s*\(', + r"eval\s*\(", + r"exec\s*\(", + r"import\s+os", + r"import\s+subprocess", + r"__import__", + r"open\s*\(", + r"file\s*\(", + r"input\s*\(", + r"raw_input\s*\(", ] - - def validate_code_input(self, code: str, language: str = 'python') -> Dict[str, Any]: + + def validate_code_input( + self, code: str, language: str = "python" + ) -> Dict[str, Any]: """Validate code input for security risks.""" validation_result = { - 'is_safe': True, - 'warnings': [], - 'blocked_patterns': [], - 'sanitized_code': code + "is_safe": True, + "warnings": [], + "blocked_patterns": [], + "sanitized_code": code, } - + if not self.config.enable_input_validation: return validation_result - + # Size check if len(code) > self.config.max_input_size: - validation_result['is_safe'] = False - validation_result['warnings'].append(f"Code exceeds maximum size limit: {self.config.max_input_size}") - + validation_result["is_safe"] = False + validation_result["warnings"].append( + f"Code exceeds maximum size limit: {self.config.max_input_size}" + ) + # Pattern detection for pattern in self.dangerous_patterns: if re.search(pattern, code, re.IGNORECASE): - validation_result['blocked_patterns'].append(pattern) - validation_result['warnings'].append(f"Potentially dangerous pattern detected: {pattern}") - + validation_result["blocked_patterns"].append(pattern) + validation_result["warnings"].append( + f"Potentially dangerous pattern detected: {pattern}" + ) + # If dangerous patterns found, mark as unsafe - if validation_result['blocked_patterns']: - validation_result['is_safe'] = False - + if validation_result["blocked_patterns"]: + validation_result["is_safe"] = False + return validation_result class AuditLogger: """Handles security audit logging.""" - + def __init__(self, config: SecurityConfig): self.config = config self.audit_file = Path(config.audit_log_file) self.logger = logging.getLogger(f"{__name__}.AuditLogger") - + # Setup audit log file self._setup_audit_logging() - + def _setup_audit_logging(self): """Setup secure audit logging.""" if self.config.enable_audit_logging: # Create audit log directory self.audit_file.parent.mkdir(exist_ok=True) - + # Setup file permissions (read/write for owner only) if self.audit_file.exists(): self.audit_file.chmod(0o600) - + def log_event(self, entry: AuditLogEntry): """Log a security audit event.""" if not self.config.enable_audit_logging: return - + try: log_data = { - 'timestamp': entry.timestamp.isoformat(), - 'request_id': entry.request_id, - 'user_id': entry.user_id if not self.config.log_sensitive_data else self._hash_user_id(entry.user_id), - 'operation': entry.operation, - 'status': entry.status, - 'ip_address': entry.ip_address if not self.config.log_sensitive_data else self._hash_ip(entry.ip_address), - 'risk_score': entry.risk_score, - 'metadata': entry.metadata + "timestamp": entry.timestamp.isoformat(), + "request_id": entry.request_id, + "user_id": ( + entry.user_id + if not self.config.log_sensitive_data + else self._hash_user_id(entry.user_id) + ), + "operation": entry.operation, + "status": entry.status, + "ip_address": ( + entry.ip_address + if not self.config.log_sensitive_data + else self._hash_ip(entry.ip_address) + ), + "risk_score": entry.risk_score, + "metadata": entry.metadata, } - - with open(self.audit_file, 'a') as f: - f.write(json.dumps(log_data) + '\n') - + + with open(self.audit_file, "a") as f: + f.write(json.dumps(log_data) + "\n") + self.logger.debug(f"Audit event logged: {entry.operation} - {entry.status}") - + except Exception as e: self.logger.error(f"Failed to log audit event: {e}") - + def _hash_user_id(self, user_id: str) -> str: """Hash user ID for privacy.""" return hashlib.sha256(user_id.encode()).hexdigest()[:16] - + def _hash_ip(self, ip_address: Optional[str]) -> Optional[str]: """Hash IP address for privacy.""" if not ip_address: @@ -293,140 +315,154 @@ def _hash_ip(self, ip_address: Optional[str]) -> Optional[str]: class SecureArcadeClient: """Secure Arcade.dev client with comprehensive security measures.""" - + def __init__(self, config: SecurityConfig): self.config = config self.credential_manager = SecureCredentialManager() self.validator = SecurityValidator(config) self.audit_logger = AuditLogger(config) self.logger = logging.getLogger(__name__) - + # User permission storage (in production, use database) self.user_permissions: Dict[str, UserPermissions] = {} - + # Active sessions self.active_sessions: Dict[str, Dict[str, Any]] = {} - - def register_user(self, user_id: str, scopes: Set[str], allowed_operations: Set[str] = None) -> bool: + + def register_user( + self, user_id: str, scopes: Set[str], allowed_operations: Set[str] = None + ) -> bool: """Register a user with specific permissions.""" try: if allowed_operations is None: - allowed_operations = {'code_analysis', 'test_generation', 'documentation'} - + allowed_operations = { + "code_analysis", + "test_generation", + "documentation", + } + permissions = UserPermissions( - user_id=user_id, - scopes=scopes, - allowed_operations=allowed_operations + user_id=user_id, scopes=scopes, allowed_operations=allowed_operations ) - + self.user_permissions[user_id] = permissions - + audit_entry = AuditLogEntry( timestamp=datetime.now(timezone.utc), user_id=user_id, - operation='user_registration', - status='success', - metadata={'scopes': list(scopes), 'operations': list(allowed_operations)} + operation="user_registration", + status="success", + metadata={ + "scopes": list(scopes), + "operations": list(allowed_operations), + }, ) self.audit_logger.log_event(audit_entry) - + self.logger.info(f"User registered: {user_id}") return True - + except Exception as e: self.logger.error(f"Failed to register user {user_id}: {e}") return False - - def authenticate_user(self, user_id: str, api_key: str, ip_address: str = None) -> Dict[str, Any]: + + def authenticate_user( + self, user_id: str, api_key: str, ip_address: str = None + ) -> Dict[str, Any]: """Authenticate user and create secure session.""" auth_result = { - 'authenticated': False, - 'session_token': None, - 'permissions': None, - 'error': None + "authenticated": False, + "session_token": None, + "permissions": None, + "error": None, } - + try: # Check if user exists if user_id not in self.user_permissions: audit_entry = AuditLogEntry( timestamp=datetime.now(timezone.utc), user_id=user_id, - operation='authentication', - status='failure', + operation="authentication", + status="failure", ip_address=ip_address, - metadata={'reason': 'User not found'}, - risk_score=7.0 + metadata={"reason": "User not found"}, + risk_score=7.0, ) self.audit_logger.log_event(audit_entry) - auth_result['error'] = 'User not found' + auth_result["error"] = "User not found" return auth_result - + # Create session token session_token = secrets.token_urlsafe(32) user_perms = self.user_permissions[user_id] session_data = { - 'user_id': user_id, - 'created_at': datetime.now(timezone.utc), - 'expires_at': datetime.now(timezone.utc) + timedelta(minutes=self.config.token_expiry_minutes), - 'ip_address': ip_address, - 'permissions': user_perms + "user_id": user_id, + "created_at": datetime.now(timezone.utc), + "expires_at": datetime.now(timezone.utc) + + timedelta(minutes=self.config.token_expiry_minutes), + "ip_address": ip_address, + "permissions": user_perms, } - + self.active_sessions[session_token] = session_data - + audit_entry = AuditLogEntry( timestamp=datetime.now(timezone.utc), user_id=user_id, - operation='authentication', - status='success', - ip_address=ip_address + operation="authentication", + status="success", + ip_address=ip_address, ) self.audit_logger.log_event(audit_entry) - - auth_result.update({ - 'authenticated': True, - 'session_token': session_token, - 'permissions': { - 'scopes': list(user_perms.scopes), - 'operations': list(user_perms.allowed_operations), - 'expires_at': session_data['expires_at'].isoformat() + + auth_result.update( + { + "authenticated": True, + "session_token": session_token, + "permissions": { + "scopes": list(user_perms.scopes), + "operations": list(user_perms.allowed_operations), + "expires_at": session_data["expires_at"].isoformat(), + }, } - }) - + ) + return auth_result - + except Exception as e: self.logger.error(f"Authentication error for {user_id}: {e}") - auth_result['error'] = 'Authentication failed' + auth_result["error"] = "Authentication failed" return auth_result - - async def secure_code_analysis(self, session_token: str, code: str, language: str = 'python') -> Dict[str, Any]: + + async def secure_code_analysis( + self, session_token: str, code: str, language: str = "python" + ) -> Dict[str, Any]: """Perform secure code analysis with full security validation.""" # Validate session if session_token not in self.active_sessions: - return {'error': 'Invalid session', 'status': 'unauthorized'} - + return {"error": "Invalid session", "status": "unauthorized"} + session = self.active_sessions[session_token] - user_id = session['user_id'] - + user_id = session["user_id"] + # Input validation validation_result = self.validator.validate_code_input(code, language) - if not validation_result['is_safe']: + if not validation_result["is_safe"]: return { - 'error': 'Input validation failed', - 'warnings': validation_result['warnings'], - 'status': 'invalid_input' + "error": "Input validation failed", + "warnings": validation_result["warnings"], + "status": "invalid_input", } - + # Mock analysis result return { - 'status': 'success', - 'analysis': { - 'score': 8.5, - 'suggestions': ['Use type hints', 'Add error handling'], - 'security_check_passed': True - } + "status": "success", + "analysis": { + "score": 8.5, + "suggestions": ["Use type hints", "Add error handling"], + "security_check_passed": True, + }, } @@ -434,55 +470,53 @@ async def demonstrate_security(): """Demonstrate security features.""" print("šŸ”’ Secure Arcade.dev Integration Demo") print("=" * 50) - + # Configure security (ensure logs directory exists) logs_dir = Path("logs") logs_dir.mkdir(exist_ok=True) - + config = SecurityConfig( enable_input_validation=True, enable_audit_logging=True, - audit_log_file="logs/arcade_security_audit.log" + audit_log_file="logs/arcade_security_audit.log", ) - + # Create secure client client = SecureArcadeClient(config) - + # Register user client.register_user( - user_id='demo_user', - scopes={'read', 'write'}, - allowed_operations={'code_analysis', 'test_generation'} + user_id="demo_user", + scopes={"read", "write"}, + allowed_operations={"code_analysis", "test_generation"}, ) - + # Authenticate auth_result = client.authenticate_user( - user_id='demo_user', - api_key='demo_key', - ip_address='127.0.0.1' + user_id="demo_user", api_key="demo_key", ip_address="127.0.0.1" ) - - if auth_result['authenticated']: + + if auth_result["authenticated"]: print("āœ… User authenticated successfully") - session_token = auth_result['session_token'] - + session_token = auth_result["session_token"] + # Test safe code safe_code = "def hello_world():\n print('Hello, World!')" result = await client.secure_code_analysis(session_token, safe_code) print(f"āœ… Safe code analysis: {result['status']}") - + # Test dangerous code dangerous_code = "import os\nos.system('rm -rf /')" result = await client.secure_code_analysis(session_token, dangerous_code) print(f"šŸ›”ļø Dangerous code blocked: {result['status']}") - + print("\nšŸŽ‰ Security demonstration completed!") async def main(): """Main demonstration function.""" logging.basicConfig(level=logging.INFO) - + try: await demonstrate_security() return 0 @@ -493,4 +527,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/07_cache_integration/cached_arcade_client.py b/examples/arcade-dev/07_cache_integration/cached_arcade_client.py index 931af45..0ec09d3 100644 --- a/examples/arcade-dev/07_cache_integration/cached_arcade_client.py +++ b/examples/arcade-dev/07_cache_integration/cached_arcade_client.py @@ -32,321 +32,465 @@ @dataclass class CacheConfig: """Configuration for caching behavior.""" + default_ttl: int = 3600 # 1 hour max_size: int = 1000 enabled: bool = True key_prefix: str = "arcade" - + # Cache TTL by operation type ttl_by_operation: Dict[str, int] = None - + def __post_init__(self): if self.ttl_by_operation is None: self.ttl_by_operation = { - 'health': 300, # 5 minutes - 'user_info': 1800, # 30 minutes - 'code_analysis': 7200, # 2 hours - 'test_generation': 3600, # 1 hour - 'documentation': 7200, # 2 hours - 'refactoring': 3600, # 1 hour + "health": 300, # 5 minutes + "user_info": 1800, # 30 minutes + "code_analysis": 7200, # 2 hours + "test_generation": 3600, # 1 hour + "documentation": 7200, # 2 hours + "refactoring": 3600, # 1 hour } class CachedArcadeClient: """Arcade.dev client with intelligent caching.""" - - def __init__(self, api_key: str, cache_manager: CacheManager, cache_config: CacheConfig = None): + + def __init__( + self, + api_key: str, + cache_manager: CacheManager, + cache_config: CacheConfig = None, + ): self.api_key = api_key self.cache_manager = cache_manager self.cache_config = cache_config or CacheConfig() self.logger = logging.getLogger(__name__) self.stats = { - 'cache_hits': 0, - 'cache_misses': 0, - 'api_calls': 0, - 'cache_sets': 0, - 'cache_errors': 0 + "cache_hits": 0, + "cache_misses": 0, + "api_calls": 0, + "cache_sets": 0, + "cache_errors": 0, } - + def _generate_cache_key(self, operation: str, **kwargs) -> str: """Generate a consistent cache key for the operation.""" # Create a stable hash of the parameters param_str = json.dumps(kwargs, sort_keys=True, default=str) param_hash = hashlib.sha256(param_str.encode()).hexdigest()[:16] - + return f"{self.cache_config.key_prefix}:{operation}:{param_hash}" - + def _get_cache_ttl(self, operation: str) -> int: """Get appropriate TTL for the operation.""" return self.cache_config.ttl_by_operation.get( - operation, - self.cache_config.default_ttl + operation, self.cache_config.default_ttl ) - + async def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]: """Retrieve data from cache with error handling.""" if not self.cache_config.enabled: return None - + try: cache_entry = self.cache_manager.get(cache_key) if cache_entry and cache_entry.is_valid: - self.stats['cache_hits'] += 1 + self.stats["cache_hits"] += 1 self.logger.debug(f"Cache hit for key: {cache_key}") # Parse the JSON content back to dict return json.loads(cache_entry.content) else: - self.stats['cache_misses'] += 1 + self.stats["cache_misses"] += 1 self.logger.debug(f"Cache miss for key: {cache_key}") return None - + except Exception as e: - self.stats['cache_errors'] += 1 + self.stats["cache_errors"] += 1 self.logger.warning(f"Cache retrieval error for {cache_key}: {e}") return None - + async def _set_cache(self, cache_key: str, data: Dict[str, Any], ttl: int) -> bool: """Store data in cache with error handling.""" if not self.cache_config.enabled: return False - + try: # Add metadata to cached data cache_data = { - 'data': data, - 'cached_at': time.time(), - 'ttl': ttl, - 'operation': cache_key.split(':')[1] if ':' in cache_key else 'unknown' + "data": data, + "cached_at": time.time(), + "ttl": ttl, + "operation": cache_key.split(":")[1] if ":" in cache_key else "unknown", } - + self.cache_manager.store(cache_key, json.dumps(cache_data)) - self.stats['cache_sets'] += 1 + self.stats["cache_sets"] += 1 self.logger.debug(f"Cached data for key: {cache_key} (TTL: {ttl}s)") return True - + except Exception as e: - self.stats['cache_errors'] += 1 + self.stats["cache_errors"] += 1 self.logger.warning(f"Cache storage error for {cache_key}: {e}") return False - + async def _simulate_api_call(self, operation: str, **kwargs) -> Dict[str, Any]: """Simulate an API call to Arcade.dev (replace with actual implementation).""" - self.stats['api_calls'] += 1 - + self.stats["api_calls"] += 1 + # Simulate different response times delay_map = { - 'health': 0.1, - 'user_info': 0.3, - 'code_analysis': 2.0, - 'test_generation': 3.0, - 'documentation': 2.5, - 'refactoring': 1.8, + "health": 0.1, + "user_info": 0.3, + "code_analysis": 2.0, + "test_generation": 3.0, + "documentation": 2.5, + "refactoring": 1.8, } - + await asyncio.sleep(delay_map.get(operation, 1.0)) - + # Generate mock responses based on operation - if operation == 'health': - return {'status': 'healthy', 'timestamp': time.time()} - - elif operation == 'user_info': + if operation == "health": + return {"status": "healthy", "timestamp": time.time()} + + elif operation == "user_info": return { - 'user_id': '12345', - 'username': 'demo_user', - 'tier': 'premium', - 'api_calls_remaining': 1000 + "user_id": "12345", + "username": "demo_user", + "tier": "premium", + "api_calls_remaining": 1000, } - - elif operation == 'code_analysis': - code = kwargs.get('code', '') + + elif operation == "code_analysis": + code = kwargs.get("code", "") return { - 'analysis_id': f"analysis_{hash(code) % 10000}", - 'language': kwargs.get('language', 'python'), - 'lines_analyzed': len(code.split('\n')), - 'suggestions': [ - {'type': 'performance', 'message': 'Consider using list comprehension for better performance and readability'}, - {'type': 'style', 'message': 'Function name should be snake_case according to PEP 8 style guidelines'}, - {'type': 'security', 'message': 'Validate input parameters to prevent injection attacks and ensure data integrity'}, - {'type': 'maintainability', 'message': 'Break down large functions into smaller, more focused methods'}, - {'type': 'documentation', 'message': 'Add comprehensive docstrings with parameter descriptions and return types'}, - {'type': 'error_handling', 'message': 'Implement proper exception handling for robust error management'} + "analysis_id": f"analysis_{hash(code) % 10000}", + "language": kwargs.get("language", "python"), + "lines_analyzed": len(code.split("\n")), + "suggestions": [ + { + "type": "performance", + "message": "Consider using list comprehension for better performance and readability", + }, + { + "type": "style", + "message": "Function name should be snake_case according to PEP 8 style guidelines", + }, + { + "type": "security", + "message": "Validate input parameters to prevent injection attacks and ensure data integrity", + }, + { + "type": "maintainability", + "message": "Break down large functions into smaller, more focused methods", + }, + { + "type": "documentation", + "message": "Add comprehensive docstrings with parameter descriptions and return types", + }, + { + "type": "error_handling", + "message": "Implement proper exception handling for robust error management", + }, ], - 'detailed_analysis': { - 'complexity_metrics': {'cyclomatic_complexity': 12, 'cognitive_complexity': 15, 'nesting_depth': 4}, - 'quality_scores': {'maintainability': 8.5, 'reliability': 9.2, 'security': 8.8, 'efficiency': 7.9}, - 'code_smells': ['Long method', 'Feature envy', 'Data clumps', 'Primitive obsession'], - 'best_practices': ['Add type hints', 'Use constants for magic numbers', 'Implement logging', 'Add unit tests'] + "detailed_analysis": { + "complexity_metrics": { + "cyclomatic_complexity": 12, + "cognitive_complexity": 15, + "nesting_depth": 4, + }, + "quality_scores": { + "maintainability": 8.5, + "reliability": 9.2, + "security": 8.8, + "efficiency": 7.9, + }, + "code_smells": [ + "Long method", + "Feature envy", + "Data clumps", + "Primitive obsession", + ], + "best_practices": [ + "Add type hints", + "Use constants for magic numbers", + "Implement logging", + "Add unit tests", + ], }, - 'score': 8.5, - 'timestamp': time.time(), - 'cache_padding': 'This response has been padded to meet minimum token requirements for effective caching. The analysis includes comprehensive code quality metrics, security assessments, performance optimizations, and maintainability recommendations. Additional details about coding standards, best practices, and architectural considerations are included to provide thorough analysis results.' + "score": 8.5, + "timestamp": time.time(), + "cache_padding": "This response has been padded to meet minimum token requirements for effective caching. The analysis includes comprehensive code quality metrics, security assessments, performance optimizations, and maintainability recommendations. Additional details about coding standards, best practices, and architectural considerations are included to provide thorough analysis results.", } - - elif operation == 'test_generation': + + elif operation == "test_generation": return { - 'test_id': f"test_{time.time()}", - 'test_cases': [ - {'name': 'test_happy_path', 'type': 'unit', 'description': 'Tests normal execution flow with valid inputs'}, - {'name': 'test_edge_cases', 'type': 'unit', 'description': 'Tests boundary conditions and edge scenarios'}, - {'name': 'test_error_handling', 'type': 'unit', 'description': 'Tests proper error handling and exception management'}, - {'name': 'test_integration', 'type': 'integration', 'description': 'Tests interaction between different components'}, - {'name': 'test_performance', 'type': 'performance', 'description': 'Tests system performance under load conditions'}, - {'name': 'test_security', 'type': 'security', 'description': 'Tests security measures and vulnerability protection'} + "test_id": f"test_{time.time()}", + "test_cases": [ + { + "name": "test_happy_path", + "type": "unit", + "description": "Tests normal execution flow with valid inputs", + }, + { + "name": "test_edge_cases", + "type": "unit", + "description": "Tests boundary conditions and edge scenarios", + }, + { + "name": "test_error_handling", + "type": "unit", + "description": "Tests proper error handling and exception management", + }, + { + "name": "test_integration", + "type": "integration", + "description": "Tests interaction between different components", + }, + { + "name": "test_performance", + "type": "performance", + "description": "Tests system performance under load conditions", + }, + { + "name": "test_security", + "type": "security", + "description": "Tests security measures and vulnerability protection", + }, ], - 'detailed_test_plan': { - 'unit_tests': {'count': 15, 'coverage_target': 95, 'execution_time': '2.3s'}, - 'integration_tests': {'count': 8, 'coverage_target': 85, 'execution_time': '5.7s'}, - 'end_to_end_tests': {'count': 4, 'coverage_target': 75, 'execution_time': '12.1s'} + "detailed_test_plan": { + "unit_tests": { + "count": 15, + "coverage_target": 95, + "execution_time": "2.3s", + }, + "integration_tests": { + "count": 8, + "coverage_target": 85, + "execution_time": "5.7s", + }, + "end_to_end_tests": { + "count": 4, + "coverage_target": 75, + "execution_time": "12.1s", + }, }, - 'testing_framework_recommendations': ['pytest', 'unittest', 'mock', 'coverage.py'], - 'quality_gates': {'min_coverage': 90, 'max_execution_time': 300, 'zero_critical_bugs': True}, - 'coverage_estimate': 95.0, - 'timestamp': time.time(), - 'cache_padding': 'Comprehensive test generation results with detailed test cases, coverage analysis, performance metrics, and quality assurance recommendations for thorough testing strategy implementation.' + "testing_framework_recommendations": [ + "pytest", + "unittest", + "mock", + "coverage.py", + ], + "quality_gates": { + "min_coverage": 90, + "max_execution_time": 300, + "zero_critical_bugs": True, + }, + "coverage_estimate": 95.0, + "timestamp": time.time(), + "cache_padding": "Comprehensive test generation results with detailed test cases, coverage analysis, performance metrics, and quality assurance recommendations for thorough testing strategy implementation.", } - - elif operation == 'documentation': + + elif operation == "documentation": return { - 'doc_id': f"doc_{time.time()}", - 'sections': ['Overview', 'Parameters', 'Returns', 'Examples', 'Usage Guidelines', 'Best Practices', 'Troubleshooting'], - 'detailed_content': { - 'overview': 'Comprehensive documentation covering all aspects of the API functionality', - 'parameters': 'Detailed parameter descriptions with types, constraints, and examples', - 'returns': 'Complete return value documentation with success and error scenarios', - 'examples': 'Multiple code examples demonstrating various use cases and implementations', - 'usage_guidelines': 'Best practices for optimal API usage and integration patterns', - 'troubleshooting': 'Common issues, error codes, and resolution strategies' + "doc_id": f"doc_{time.time()}", + "sections": [ + "Overview", + "Parameters", + "Returns", + "Examples", + "Usage Guidelines", + "Best Practices", + "Troubleshooting", + ], + "detailed_content": { + "overview": "Comprehensive documentation covering all aspects of the API functionality", + "parameters": "Detailed parameter descriptions with types, constraints, and examples", + "returns": "Complete return value documentation with success and error scenarios", + "examples": "Multiple code examples demonstrating various use cases and implementations", + "usage_guidelines": "Best practices for optimal API usage and integration patterns", + "troubleshooting": "Common issues, error codes, and resolution strategies", }, - 'documentation_metrics': { - 'completeness_score': 92, - 'readability_score': 88, - 'example_coverage': 95, - 'accuracy_rating': 97 + "documentation_metrics": { + "completeness_score": 92, + "readability_score": 88, + "example_coverage": 95, + "accuracy_rating": 97, }, - 'generated_artifacts': ['API reference', 'User guide', 'Code examples', 'FAQ section'], - 'estimated_length': 2500, - 'timestamp': time.time(), - 'cache_padding': 'Comprehensive documentation generation with detailed content structure, quality metrics, and artifact descriptions for complete API documentation coverage.' + "generated_artifacts": [ + "API reference", + "User guide", + "Code examples", + "FAQ section", + ], + "estimated_length": 2500, + "timestamp": time.time(), + "cache_padding": "Comprehensive documentation generation with detailed content structure, quality metrics, and artifact descriptions for complete API documentation coverage.", } - - elif operation == 'refactoring': + + elif operation == "refactoring": return { - 'refactor_id': f"refactor_{time.time()}", - 'suggestions': [ - {'type': 'extract_method', 'confidence': 0.9, 'description': 'Extract complex logic into separate methods', 'impact': 'high'}, - {'type': 'rename_variable', 'confidence': 0.8, 'description': 'Improve variable naming for clarity', 'impact': 'medium'}, - {'type': 'optimize_loop', 'confidence': 0.7, 'description': 'Optimize loop performance and efficiency', 'impact': 'medium'}, - {'type': 'remove_duplication', 'confidence': 0.85, 'description': 'Eliminate duplicate code patterns', 'impact': 'high'}, - {'type': 'simplify_conditionals', 'confidence': 0.75, 'description': 'Simplify complex conditional statements', 'impact': 'medium'}, - {'type': 'improve_error_handling', 'confidence': 0.9, 'description': 'Enhance error handling mechanisms', 'impact': 'high'} + "refactor_id": f"refactor_{time.time()}", + "suggestions": [ + { + "type": "extract_method", + "confidence": 0.9, + "description": "Extract complex logic into separate methods", + "impact": "high", + }, + { + "type": "rename_variable", + "confidence": 0.8, + "description": "Improve variable naming for clarity", + "impact": "medium", + }, + { + "type": "optimize_loop", + "confidence": 0.7, + "description": "Optimize loop performance and efficiency", + "impact": "medium", + }, + { + "type": "remove_duplication", + "confidence": 0.85, + "description": "Eliminate duplicate code patterns", + "impact": "high", + }, + { + "type": "simplify_conditionals", + "confidence": 0.75, + "description": "Simplify complex conditional statements", + "impact": "medium", + }, + { + "type": "improve_error_handling", + "confidence": 0.9, + "description": "Enhance error handling mechanisms", + "impact": "high", + }, ], - 'refactoring_metrics': { - 'complexity_reduction': 25, - 'maintainability_improvement': 30, - 'performance_gain': 15, - 'code_quality_score': 8.7 + "refactoring_metrics": { + "complexity_reduction": 25, + "maintainability_improvement": 30, + "performance_gain": 15, + "code_quality_score": 8.7, }, - 'design_patterns_recommended': ['Strategy', 'Factory', 'Observer', 'Command'], - 'code_quality_improvements': { - 'cyclomatic_complexity': 'Reduced from 15 to 8', - 'code_duplication': 'Eliminated 23% duplicate code', - 'method_length': 'Average method length reduced by 40%', - 'class_coupling': 'Reduced coupling between components' + "design_patterns_recommended": [ + "Strategy", + "Factory", + "Observer", + "Command", + ], + "code_quality_improvements": { + "cyclomatic_complexity": "Reduced from 15 to 8", + "code_duplication": "Eliminated 23% duplicate code", + "method_length": "Average method length reduced by 40%", + "class_coupling": "Reduced coupling between components", }, - 'estimated_improvement': '25% performance boost with 30% maintainability increase', - 'timestamp': time.time(), - 'cache_padding': 'Detailed refactoring analysis with comprehensive suggestions, metrics, and quality improvements for optimal code structure and maintainability enhancement.' + "estimated_improvement": "25% performance boost with 30% maintainability increase", + "timestamp": time.time(), + "cache_padding": "Detailed refactoring analysis with comprehensive suggestions, metrics, and quality improvements for optimal code structure and maintainability enhancement.", } - + else: - return {'operation': operation, 'timestamp': time.time(), 'status': 'completed'} - - async def cached_request(self, operation: str, force_refresh: bool = False, **kwargs) -> Dict[str, Any]: + return { + "operation": operation, + "timestamp": time.time(), + "status": "completed", + } + + async def cached_request( + self, operation: str, force_refresh: bool = False, **kwargs + ) -> Dict[str, Any]: """Make a cached request to Arcade.dev API.""" cache_key = self._generate_cache_key(operation, **kwargs) - + # Check cache first (unless force refresh) if not force_refresh: cached_result = await self._get_from_cache(cache_key) if cached_result: # Return the data part, not the metadata - return cached_result.get('data', cached_result) - + return cached_result.get("data", cached_result) + # Make API call start_time = time.time() result = await self._simulate_api_call(operation, **kwargs) api_duration = time.time() - start_time - + # Add API call metadata - result['_api_duration'] = api_duration - result['_cached'] = False - + result["_api_duration"] = api_duration + result["_cached"] = False + # Cache the result ttl = self._get_cache_ttl(operation) await self._set_cache(cache_key, result, ttl) - + self.logger.info(f"API call completed: {operation} ({api_duration:.2f}s)") return result - + async def health_check(self, force_refresh: bool = False) -> Dict[str, Any]: """Get API health status with caching.""" - return await self.cached_request('health', force_refresh=force_refresh) - + return await self.cached_request("health", force_refresh=force_refresh) + async def get_user_info(self, force_refresh: bool = False) -> Dict[str, Any]: """Get user information with caching.""" - return await self.cached_request('user_info', force_refresh=force_refresh) - - async def analyze_code(self, code: str, language: str = 'python', force_refresh: bool = False) -> Dict[str, Any]: + return await self.cached_request("user_info", force_refresh=force_refresh) + + async def analyze_code( + self, code: str, language: str = "python", force_refresh: bool = False + ) -> Dict[str, Any]: """Analyze code with caching.""" return await self.cached_request( - 'code_analysis', - force_refresh=force_refresh, - code=code, - language=language + "code_analysis", force_refresh=force_refresh, code=code, language=language ) - - async def generate_tests(self, code: str, test_type: str = 'unit', force_refresh: bool = False) -> Dict[str, Any]: + + async def generate_tests( + self, code: str, test_type: str = "unit", force_refresh: bool = False + ) -> Dict[str, Any]: """Generate tests with caching.""" return await self.cached_request( - 'test_generation', + "test_generation", force_refresh=force_refresh, code=code, - test_type=test_type + test_type=test_type, ) - - async def generate_documentation(self, code: str, doc_type: str = 'api', force_refresh: bool = False) -> Dict[str, Any]: + + async def generate_documentation( + self, code: str, doc_type: str = "api", force_refresh: bool = False + ) -> Dict[str, Any]: """Generate documentation with caching.""" return await self.cached_request( - 'documentation', - force_refresh=force_refresh, - code=code, - doc_type=doc_type + "documentation", force_refresh=force_refresh, code=code, doc_type=doc_type ) - - async def suggest_refactoring(self, code: str, focus: str = 'performance', force_refresh: bool = False) -> Dict[str, Any]: + + async def suggest_refactoring( + self, code: str, focus: str = "performance", force_refresh: bool = False + ) -> Dict[str, Any]: """Get refactoring suggestions with caching.""" return await self.cached_request( - 'refactoring', - force_refresh=force_refresh, - code=code, - focus=focus + "refactoring", force_refresh=force_refresh, code=code, focus=focus ) - + def get_cache_stats(self) -> Dict[str, Any]: """Get cache performance statistics.""" - total_requests = self.stats['cache_hits'] + self.stats['cache_misses'] - hit_rate = (self.stats['cache_hits'] / total_requests * 100) if total_requests > 0 else 0 - + total_requests = self.stats["cache_hits"] + self.stats["cache_misses"] + hit_rate = ( + (self.stats["cache_hits"] / total_requests * 100) + if total_requests > 0 + else 0 + ) + return { - 'cache_hits': self.stats['cache_hits'], - 'cache_misses': self.stats['cache_misses'], - 'api_calls': self.stats['api_calls'], - 'cache_sets': self.stats['cache_sets'], - 'cache_errors': self.stats['cache_errors'], - 'hit_rate_percent': round(hit_rate, 2), - 'total_requests': total_requests + "cache_hits": self.stats["cache_hits"], + "cache_misses": self.stats["cache_misses"], + "api_calls": self.stats["api_calls"], + "cache_sets": self.stats["cache_sets"], + "cache_errors": self.stats["cache_errors"], + "hit_rate_percent": round(hit_rate, 2), + "total_requests": total_requests, } - + async def clear_cache(self, operation: str = None): """Clear cache for specific operation or all cache.""" if operation: @@ -360,26 +504,26 @@ async def clear_cache(self, operation: str = None): async def demonstrate_caching(): """Demonstrate the caching capabilities.""" print("šŸ”§ Initializing cached Arcade.dev client...") - + # Get default cache configuration and initialize cache manager fact_cache_config = get_default_cache_config() cache_manager = CacheManager(fact_cache_config) - + # Configure caching behavior cache_config = CacheConfig( default_ttl=3600, enabled=True, ttl_by_operation={ - 'health': 300, - 'code_analysis': 7200, # Cache code analysis for 2 hours - } + "health": 300, + "code_analysis": 7200, # Cache code analysis for 2 hours + }, ) - + # Create cached client - api_key = os.getenv('ARCADE_API_KEY', 'demo_key') + api_key = os.getenv("ARCADE_API_KEY", "demo_key") client = CachedArcadeClient(api_key, cache_manager, cache_config) - - sample_code = ''' + + sample_code = """ def calculate_fibonacci(n): if n <= 1: return n @@ -388,11 +532,11 @@ def calculate_fibonacci(n): def main(): for i in range(10): print(f"F({i}) = {calculate_fibonacci(i)}") -''' - +""" + print("\nšŸ“Š Performance Comparison: First vs Cached Requests") print("=" * 60) - + # First request (will hit API) print("šŸ” First code analysis request (API call)...") start_time = time.time() @@ -400,7 +544,7 @@ def main(): first_duration = time.time() - start_time print(f" Duration: {first_duration:.2f}s") print(f" Suggestions: {len(result1.get('suggestions', []))}") - + # Second request (should hit cache) print("\nšŸ” Second code analysis request (cached)...") start_time = time.time() @@ -408,41 +552,43 @@ def main(): second_duration = time.time() - start_time print(f" Duration: {second_duration:.2f}s") print(f" Suggestions: {len(result2.get('suggestions', []))}") - + # Performance improvement if first_duration > 0: improvement = ((first_duration - second_duration) / first_duration) * 100 print(f"\n⚔ Performance improvement: {improvement:.1f}% faster") - + # Demonstrate force refresh print("\nšŸ”„ Force refresh (bypassing cache)...") start_time = time.time() result3 = await client.analyze_code(sample_code, force_refresh=True) third_duration = time.time() - start_time print(f" Duration: {third_duration:.2f}s") - + # Show cache statistics print("\nšŸ“ˆ Cache Statistics:") stats = client.get_cache_stats() for key, value in stats.items(): print(f" {key}: {value}") - + # Demonstrate different operations print("\nšŸŽÆ Testing different operations...") - + operations = [ - ('Health Check', client.health_check), - ('User Info', client.get_user_info), - ('Test Generation', lambda: client.generate_tests(sample_code)), - ('Documentation', lambda: client.generate_documentation(sample_code)), - ('Refactoring', lambda: client.suggest_refactoring(sample_code)) + ("Health Check", client.health_check), + ("User Info", client.get_user_info), + ("Test Generation", lambda: client.generate_tests(sample_code)), + ("Documentation", lambda: client.generate_documentation(sample_code)), + ("Refactoring", lambda: client.suggest_refactoring(sample_code)), ] - + for name, operation in operations: print(f" {name}...") result = await operation() - print(f" āœ… Completed (cached for {client._get_cache_ttl(name.lower().replace(' ', '_'))}s)") - + print( + f" āœ… Completed (cached for {client._get_cache_ttl(name.lower().replace(' ', '_'))}s)" + ) + # Final statistics print("\nšŸ“Š Final Cache Statistics:") final_stats = client.get_cache_stats() @@ -454,17 +600,17 @@ async def main(): """Main demonstration function.""" logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + print("šŸŽ® Arcade.dev Cache Integration Example") print("=" * 50) - + try: await demonstrate_caching() print("\nšŸŽ‰ Cache integration example completed successfully!") return 0 - + except Exception as e: print(f"\nāŒ Error: {e}") return 1 @@ -472,4 +618,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/08_advanced_tools/advanced_tool_usage.py b/examples/arcade-dev/08_advanced_tools/advanced_tool_usage.py index a90ff72..1f02432 100644 --- a/examples/arcade-dev/08_advanced_tools/advanced_tool_usage.py +++ b/examples/arcade-dev/08_advanced_tools/advanced_tool_usage.py @@ -37,42 +37,46 @@ from src.monitoring.metrics import MetricsCollector from src.core.errors import ToolExecutionError + # Define mock ArcadeClient for demo mode class MockArcadeClient: def __init__(self, api_key: str = None, **kwargs): self.api_key = api_key self._demo_mode = True - + async def __aenter__(self): return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): pass - + async def execute_tool(self, payload: Dict[str, Any]) -> Dict[str, Any]: # Mock response for demo - tool_name = payload.get('tool', 'unknown') + tool_name = payload.get("tool", "unknown") return { - 'result': { - 'success': True, - 'data': f'Mock result from {tool_name}', - 'execution_time': 0.1 + "result": { + "success": True, + "data": f"Mock result from {tool_name}", + "execution_time": 0.1, } } + # Try to import real ArcadeClient, fall back to mock try: from src.arcade.client import ArcadeClient as RealArcadeClient + REAL_ARCADE_AVAILABLE = True except ImportError: REAL_ARCADE_AVAILABLE = False + # Create a wrapper that tries real client first, falls back to mock class ArcadeClient: def __new__(cls, api_key: str = None, **kwargs): # Check if we should use demo mode demo_mode = api_key == "demo" or not api_key or not REAL_ARCADE_AVAILABLE - + if demo_mode: return MockArcadeClient(api_key, **kwargs) else: @@ -83,26 +87,29 @@ def __new__(cls, api_key: str = None, **kwargs): print("šŸŽÆ Falling back to demo mode due to missing dependencies") return MockArcadeClient(api_key, **kwargs) + ARCADE_AVAILABLE = REAL_ARCADE_AVAILABLE + # Mock classes for demo mode class MockToolExecutor: """Mock tool executor for demo mode.""" - + async def execute(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]: await asyncio.sleep(0.1) # Simulate processing return { - 'success': True, - 'data': f'Mock execution result for {tool_name}', - 'execution_time': 0.1 + "success": True, + "data": f"Mock execution result for {tool_name}", + "execution_time": 0.1, } + class MockMetricsCollector: """Mock metrics collector for demo mode.""" - + async def record_gauge(self, name: str, value: float, tags: Dict[str, str] = None): pass - + async def record_counter(self, name: str, value: int = 1): pass @@ -110,6 +117,7 @@ async def record_counter(self, name: str, value: int = 1): @dataclass class ToolChainStep: """Represents a single step in a tool execution chain.""" + tool_name: str params: Dict[str, Any] condition: Optional[Callable[[Dict[str, Any]], bool]] = None @@ -121,6 +129,7 @@ class ToolChainStep: @dataclass class ToolOrchestrationResult: """Result of tool orchestration execution.""" + success: bool results: List[Dict[str, Any]] = field(default_factory=list) errors: List[str] = field(default_factory=list) @@ -131,55 +140,55 @@ class ToolOrchestrationResult: class AdvancedToolOrchestrator: """Advanced tool orchestration system with chaining and conditional execution.""" - + def __init__(self, arcade_client: ArcadeClient, cache_manager: CacheManager): self.arcade_client = arcade_client self.cache_manager = cache_manager - + # Create demo-compatible tool executor try: self.tool_executor = ToolExecutor(arcade_client=arcade_client) except Exception: # Create minimal mock executor for demo self.tool_executor = MockToolExecutor() - + try: self.metrics = MetricsCollector() except Exception: # Create minimal mock metrics collector self.metrics = MockMetricsCollector() - + self.logger = logging.getLogger(__name__) self.thread_pool = ThreadPoolExecutor(max_workers=4) - + # Dynamic tool registry self.dynamic_tools: Dict[str, Callable] = {} - + async def __aenter__(self): """Async context manager entry.""" await self.arcade_client.__aenter__() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" await self.arcade_client.__aexit__(exc_type, exc_val, exc_tb) self.thread_pool.shutdown(wait=True) - + def register_dynamic_tool(self, name: str, func: Callable): """Register a dynamically created tool.""" self.dynamic_tools[name] = func self.logger.info(f"Registered dynamic tool: {name}") - + async def execute_tool_chain( - self, + self, steps: List[ToolChainStep], parallel: bool = False, - aggregate_results: bool = True + aggregate_results: bool = True, ) -> ToolOrchestrationResult: """Execute a chain of tools with conditional branching and result aggregation.""" start_time = time.time() result = ToolOrchestrationResult(success=True) - + try: if parallel: # Execute steps in parallel where possible @@ -187,180 +196,185 @@ async def execute_tool_chain( else: # Execute steps sequentially result = await self._execute_sequential_chain(steps, result) - + if aggregate_results: - result.metadata['aggregated_data'] = await self._aggregate_results(result.results) - + result.metadata["aggregated_data"] = await self._aggregate_results( + result.results + ) + except Exception as e: result.success = False result.errors.append(f"Chain execution failed: {str(e)}") self.logger.error(f"Tool chain execution error: {e}") - + finally: result.execution_time = time.time() - start_time await self._record_metrics(result) - + return result - + async def _execute_sequential_chain( - self, - steps: List[ToolChainStep], - result: ToolOrchestrationResult + self, steps: List[ToolChainStep], result: ToolOrchestrationResult ) -> ToolOrchestrationResult: """Execute tool chain sequentially.""" context = {} - + for i, step in enumerate(steps): try: # Check condition if provided if step.condition and not step.condition(context): self.logger.info(f"Skipping step {i}: condition not met") continue - + # Execute tool with retries step_result = await self._execute_tool_with_retry(step, context) - + # Transform result if transform function provided if step.transform: step_result = step.transform(step_result) - + # Update context and results context.update(step_result) result.results.append(step_result) result.steps_completed += 1 - + self.logger.info(f"Completed step {i}: {step.tool_name}") - + except Exception as e: error_msg = f"Step {i} ({step.tool_name}) failed: {str(e)}" result.errors.append(error_msg) self.logger.error(error_msg) - + # Decide whether to continue or abort chain if not self._should_continue_on_error(step, e): result.success = False break - + return result - + async def _execute_parallel_chain( - self, - steps: List[ToolChainStep], - result: ToolOrchestrationResult + self, steps: List[ToolChainStep], result: ToolOrchestrationResult ) -> ToolOrchestrationResult: """Execute independent tool steps in parallel.""" # Group steps by dependencies (simplified: assume all independent for now) tasks = [] - + for step in steps: task = asyncio.create_task(self._execute_tool_with_retry(step, {})) tasks.append((step, task)) - + # Wait for all tasks to complete for step, task in tasks: try: step_result = await task if step.transform: step_result = step.transform(step_result) - + result.results.append(step_result) result.steps_completed += 1 - + except Exception as e: error_msg = f"Parallel step ({step.tool_name}) failed: {str(e)}" result.errors.append(error_msg) self.logger.error(error_msg) - + return result - + async def _execute_tool_with_retry( - self, - step: ToolChainStep, - context: Dict[str, Any] + self, step: ToolChainStep, context: Dict[str, Any] ) -> Dict[str, Any]: """Execute a single tool with retry logic.""" last_error = None - + for attempt in range(step.retry_count): try: # Merge step params with context merged_params = {**step.params, **context} - + # Check if it's a dynamic tool if step.tool_name in self.dynamic_tools: - return await self._execute_dynamic_tool(step.tool_name, merged_params) - + return await self._execute_dynamic_tool( + step.tool_name, merged_params + ) + # Execute via Arcade.dev or local executor - if step.tool_name.startswith('arcade:'): + if step.tool_name.startswith("arcade:"): tool_name = step.tool_name[7:] # Remove 'arcade:' prefix return await self._execute_arcade_tool(tool_name, merged_params) else: return await self._execute_local_tool(step.tool_name, merged_params) - + except Exception as e: last_error = e if attempt < step.retry_count - 1: - wait_time = 2 ** attempt - self.logger.warning(f"Tool {step.tool_name} attempt {attempt + 1} failed, retrying in {wait_time}s") + wait_time = 2**attempt + self.logger.warning( + f"Tool {step.tool_name} attempt {attempt + 1} failed, retrying in {wait_time}s" + ) await asyncio.sleep(wait_time) - + raise last_error - - async def _execute_arcade_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]: + + async def _execute_arcade_tool( + self, tool_name: str, params: Dict[str, Any] + ) -> Dict[str, Any]: """Execute tool via Arcade.dev API.""" payload = { "tool": tool_name, "parameters": params, - "context": { - "source": "fact_sdk", - "timestamp": time.time() - } + "context": {"source": "fact_sdk", "timestamp": time.time()}, } - + response = await self.arcade_client.execute_tool(payload) - return response.get('result', {}) - - async def _execute_local_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]: + return response.get("result", {}) + + async def _execute_local_tool( + self, tool_name: str, params: Dict[str, Any] + ) -> Dict[str, Any]: """Execute tool locally using FACT executor.""" return await self.tool_executor.execute(tool_name, params) - - async def _execute_dynamic_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]: + + async def _execute_dynamic_tool( + self, tool_name: str, params: Dict[str, Any] + ) -> Dict[str, Any]: """Execute dynamically registered tool.""" tool_func = self.dynamic_tools[tool_name] - + # Run in thread pool if not async if asyncio.iscoroutinefunction(tool_func): return await tool_func(**params) else: loop = asyncio.get_event_loop() - return await loop.run_in_executor(self.thread_pool, lambda: tool_func(**params)) - + return await loop.run_in_executor( + self.thread_pool, lambda: tool_func(**params) + ) + def _should_continue_on_error(self, step: ToolChainStep, error: Exception) -> bool: """Determine if chain should continue after an error.""" # Custom logic based on error type and step configuration if isinstance(error, ToolExecutionError) and error.is_recoverable: return True return False - + async def _aggregate_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]: """Aggregate results from multiple tool executions.""" aggregated = { "total_results": len(results), - "success_count": sum(1 for r in results if r.get('success', True)), + "success_count": sum(1 for r in results if r.get("success", True)), "combined_data": {}, "metrics": { - "execution_times": [r.get('execution_time', 0) for r in results], - "data_sizes": [len(str(r)) for r in results] - } + "execution_times": [r.get("execution_time", 0) for r in results], + "data_sizes": [len(str(r)) for r in results], + }, } - + # Merge data from all results for result in results: - if 'data' in result: - aggregated['combined_data'].update(result['data']) - + if "data" in result: + aggregated["combined_data"].update(result["data"]) + return aggregated - + async def _record_metrics(self, result: ToolOrchestrationResult): """Record execution metrics.""" # Use the correct MetricsCollector method @@ -368,34 +382,31 @@ async def _record_metrics(self, result: ToolOrchestrationResult): tool_name="tool_chain", success=result.success, execution_time=result.execution_time * 1000, # Convert to milliseconds - metadata={"steps_completed": result.steps_completed} + metadata={"steps_completed": result.steps_completed}, ) - + if result.errors: # Record error metrics using the correct method signature self.metrics.record_tool_execution( tool_name="tool_chain_errors", success=False, execution_time=0, - metadata={"error_count": len(result.errors)} + metadata={"error_count": len(result.errors)}, ) - + async def create_conditional_branch( - self, + self, condition_tool: str, true_branch: List[ToolChainStep], - false_branch: List[ToolChainStep] + false_branch: List[ToolChainStep], ) -> ToolOrchestrationResult: """Create conditional execution branches based on tool result.""" # Execute condition tool - condition_step = ToolChainStep( - tool_name=condition_tool, - params={} - ) - + condition_step = ToolChainStep(tool_name=condition_tool, params={}) + condition_result = await self._execute_tool_with_retry(condition_step, {}) - branch_condition = condition_result.get('condition_met', False) - + branch_condition = condition_result.get("condition_met", False) + # Execute appropriate branch if branch_condition: self.logger.info("Executing true branch") @@ -403,24 +414,22 @@ async def create_conditional_branch( else: self.logger.info("Executing false branch") return await self.execute_tool_chain(false_branch) - + def generate_dynamic_tool( - self, - name: str, - template: str, - parameters: Dict[str, Any] + self, name: str, template: str, parameters: Dict[str, Any] ) -> Callable: """Generate a dynamic tool from a template.""" + def dynamic_tool(**kwargs): # Simple template substitution for demonstration code = template.format(**parameters, **kwargs) - + # Execute the generated code (in a real implementation, use safer execution) local_vars = {} exec(code, {"__builtins__": {}}, local_vars) - - return local_vars.get('result', {}) - + + return local_vars.get("result", {}) + self.register_dynamic_tool(name, dynamic_tool) return dynamic_tool @@ -428,87 +437,99 @@ def dynamic_tool(**kwargs): # Example tool implementations for demonstration def create_sample_tools(orchestrator: AdvancedToolOrchestrator): """Create sample dynamic tools for demonstration.""" - + # Data validation tool - def validate_data(data: Dict[str, Any] = None, schema: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]: + def validate_data( + data: Dict[str, Any] = None, schema: Dict[str, Any] = None, **kwargs + ) -> Dict[str, Any]: """Validate data against schema.""" # Use provided parameters or defaults from kwargs if data is None: - data = kwargs.get('data', {'name': 'Demo', 'age': 25}) + data = kwargs.get("data", {"name": "Demo", "age": 25}) if schema is None: - schema = kwargs.get('schema', {'name': str, 'age': int}) - + schema = kwargs.get("schema", {"name": str, "age": int}) + errors = [] for field, field_type in schema.items(): if field not in data: errors.append(f"Missing field: {field}") elif not isinstance(data[field], field_type): - errors.append(f"Invalid type for {field}: expected {field_type.__name__}") - - return { - 'valid': len(errors) == 0, - 'errors': errors, - 'data': data - } - + errors.append( + f"Invalid type for {field}: expected {field_type.__name__}" + ) + + return {"valid": len(errors) == 0, "errors": errors, "data": data} + # Data transformation tool - def transform_data(transformations: List[str] = None, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]: + def transform_data( + transformations: List[str] = None, data: Dict[str, Any] = None, **kwargs + ) -> Dict[str, Any]: """Apply transformations to data.""" # Handle context from previous steps - if data is None and 'data' in kwargs: - data = kwargs['data'] + if data is None and "data" in kwargs: + data = kwargs["data"] elif data is None: - data = kwargs.get('context', {}) - + data = kwargs.get("context", {}) + if transformations is None: - transformations = kwargs.get('transformations', ['add_timestamp']) - + transformations = kwargs.get("transformations", ["add_timestamp"]) + result = data.copy() if data else {} - + for transform in transformations: - if transform == 'uppercase_strings': + if transform == "uppercase_strings": for key, value in result.items(): if isinstance(value, str): result[key] = value.upper() - elif transform == 'add_timestamp': - result['timestamp'] = time.time() - - return {'transformed_data': result} - + elif transform == "add_timestamp": + result["timestamp"] = time.time() + + return {"transformed_data": result} + # Analysis tool - async def analyze_results(results: List[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: + async def analyze_results( + results: List[Dict[str, Any]] = None, **kwargs + ) -> Dict[str, Any]: """Analyze aggregated results.""" await asyncio.sleep(0.1) # Simulate async processing - + # Handle context from previous steps if results is None: - results = kwargs.get('results', []) - + results = kwargs.get("results", []) + # If no results provided, create demo analysis if not results: - results = [{'execution_time': 0.1, 'success': True}] - + results = [{"execution_time": 0.1, "success": True}] + return { - 'analysis': { - 'total_items': len(results), - 'avg_processing_time': sum(r.get('execution_time', 0) for r in results) / len(results) if results else 0, - 'error_rate': sum(1 for r in results if 'error' in r) / len(results) if results else 0 + "analysis": { + "total_items": len(results), + "avg_processing_time": ( + sum(r.get("execution_time", 0) for r in results) / len(results) + if results + else 0 + ), + "error_rate": ( + sum(1 for r in results if "error" in r) / len(results) + if results + else 0 + ), } } - + # Register tools - orchestrator.register_dynamic_tool('validate_data', validate_data) - orchestrator.register_dynamic_tool('transform_data', transform_data) - orchestrator.register_dynamic_tool('analyze_results', analyze_results) + orchestrator.register_dynamic_tool("validate_data", validate_data) + orchestrator.register_dynamic_tool("transform_data", transform_data) + orchestrator.register_dynamic_tool("analyze_results", analyze_results) async def demonstrate_tool_chaining(): """Demonstrate basic tool chaining.""" print("šŸ”— Demonstrating Tool Chaining") - + # Create mock arcade client and cache manager with proper config arcade_client = ArcadeClient(api_key=os.getenv("ARCADE_API_KEY", "demo")) - + # Create cache config for demo cache_config = { "prefix": "demo_cache", @@ -516,39 +537,36 @@ async def demonstrate_tool_chaining(): "max_size": "50MB", "ttl_seconds": 3600, "hit_target_ms": 48, - "miss_target_ms": 140 + "miss_target_ms": 140, } cache_manager = CacheManager(cache_config) - + async with AdvancedToolOrchestrator(arcade_client, cache_manager) as orchestrator: # Create sample tools create_sample_tools(orchestrator) - + # Define tool chain steps = [ ToolChainStep( - tool_name='validate_data', - params={ - 'data': {'name': 'John', 'age': 30}, - 'schema': {'name': str, 'age': int} - } - ), - ToolChainStep( - tool_name='transform_data', + tool_name="validate_data", params={ - 'transformations': ['uppercase_strings', 'add_timestamp'] + "data": {"name": "John", "age": 30}, + "schema": {"name": str, "age": int}, }, - condition=lambda ctx: ctx.get('valid', False) # Only if validation passed ), ToolChainStep( - tool_name='analyze_results', - params={} - ) + tool_name="transform_data", + params={"transformations": ["uppercase_strings", "add_timestamp"]}, + condition=lambda ctx: ctx.get( + "valid", False + ), # Only if validation passed + ), + ToolChainStep(tool_name="analyze_results", params={}), ] - + # Execute chain result = await orchestrator.execute_tool_chain(steps) - + print(f"āœ… Chain completed: {result.steps_completed} steps") print(f"ā±ļø Execution time: {result.execution_time:.2f}s") if result.errors: @@ -558,9 +576,9 @@ async def demonstrate_tool_chaining(): async def demonstrate_conditional_branching(): """Demonstrate conditional tool branching.""" print("\n🌿 Demonstrating Conditional Branching") - + arcade_client = ArcadeClient(api_key=os.getenv("ARCADE_API_KEY", "demo")) - + # Create cache config for demo cache_config = { "prefix": "demo_cache", @@ -568,52 +586,45 @@ async def demonstrate_conditional_branching(): "max_size": "50MB", "ttl_seconds": 3600, "hit_target_ms": 48, - "miss_target_ms": 140 + "miss_target_ms": 140, } cache_manager = CacheManager(cache_config) - + async with AdvancedToolOrchestrator(arcade_client, cache_manager) as orchestrator: create_sample_tools(orchestrator) - + # Create condition tool def check_condition(threshold: float = 0.5) -> Dict[str, Any]: import random + value = random.random() - return { - 'condition_met': value > threshold, - 'value': value - } - - orchestrator.register_dynamic_tool('check_condition', check_condition) - + return {"condition_met": value > threshold, "value": value} + + orchestrator.register_dynamic_tool("check_condition", check_condition) + # Define branches true_branch = [ ToolChainStep( - tool_name='transform_data', + tool_name="transform_data", params={ - 'data': {'status': 'success'}, - 'transformations': ['add_timestamp'] - } + "data": {"status": "success"}, + "transformations": ["add_timestamp"], + }, ) ] - + false_branch = [ ToolChainStep( - tool_name='validate_data', - params={ - 'data': {'status': 'failure'}, - 'schema': {'status': str} - } + tool_name="validate_data", + params={"data": {"status": "failure"}, "schema": {"status": str}}, ) ] - + # Execute conditional branch result = await orchestrator.create_conditional_branch( - 'check_condition', - true_branch, - false_branch + "check_condition", true_branch, false_branch ) - + print(f"āœ… Branch completed: {result.success}") print(f"šŸ“Š Results: {len(result.results)} items") @@ -621,9 +632,9 @@ def check_condition(threshold: float = 0.5) -> Dict[str, Any]: async def demonstrate_dynamic_tool_generation(): """Demonstrate dynamic tool generation.""" print("\nšŸ› ļø Demonstrating Dynamic Tool Generation") - + arcade_client = ArcadeClient(api_key=os.getenv("ARCADE_API_KEY", "demo")) - + # Create cache config for demo cache_config = { "prefix": "demo_cache", @@ -631,10 +642,10 @@ async def demonstrate_dynamic_tool_generation(): "max_size": "50MB", "ttl_seconds": 3600, "hit_target_ms": 48, - "miss_target_ms": 140 + "miss_target_ms": 140, } cache_manager = CacheManager(cache_config) - + async with AdvancedToolOrchestrator(arcade_client, cache_manager) as orchestrator: # Generate dynamic tool from template template = """ @@ -649,24 +660,17 @@ def calculate_{operation}(a, b): result = calculate_{operation}({a}, {b}) """ - + # Generate calculator tool orchestrator.generate_dynamic_tool( - 'calculator', - template, - {'operation': 'add', 'a': 10, 'b': 20} + "calculator", template, {"operation": "add", "a": 10, "b": 20} ) - + # Execute generated tool - steps = [ - ToolChainStep( - tool_name='calculator', - params={} - ) - ] - + steps = [ToolChainStep(tool_name="calculator", params={})] + result = await orchestrator.execute_tool_chain(steps) - + print(f"āœ… Dynamic tool executed: {result.success}") if result.results: print(f"šŸ”¢ Calculation result: {result.results[0]}") @@ -676,16 +680,16 @@ async def main(): """Main demonstration function.""" logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + # Check for demo mode api_key = os.getenv("ARCADE_API_KEY", "demo") demo_mode = api_key == "demo" or not api_key or not REAL_ARCADE_AVAILABLE - + print("šŸš€ Advanced Tool Usage Example") print("=" * 50) - + if demo_mode: print("šŸŽÆ Running in DEMO MODE") print(" - Using mock Arcade client") @@ -696,25 +700,28 @@ async def main(): print("šŸ”— Connected to Arcade.dev API") print(f" - API Key: {api_key[:8]}...") print() - + try: # Run demonstrations await demonstrate_tool_chaining() await demonstrate_conditional_branching() await demonstrate_dynamic_tool_generation() - + print("\nšŸŽ‰ All advanced tool usage examples completed successfully!") if demo_mode: - print("šŸ’” To run with real API integration, set the ARCADE_API_KEY environment variable") + print( + "šŸ’” To run with real API integration, set the ARCADE_API_KEY environment variable" + ) return 0 - + except Exception as e: print(f"āŒ Error: {e}") import traceback + traceback.print_exc() return 1 if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/09_testing/arcade_integration_tests.py b/examples/arcade-dev/09_testing/arcade_integration_tests.py index c2b54f7..09d937c 100644 --- a/examples/arcade-dev/09_testing/arcade_integration_tests.py +++ b/examples/arcade-dev/09_testing/arcade_integration_tests.py @@ -39,35 +39,38 @@ from src.monitoring.metrics import MetricsCollector from src.core.errors import ToolExecutionError, ConfigurationError + # Import arcade integration classes (create simple mocks for testing) class ArcadeClient: """Mock ArcadeClient for testing.""" - - def __init__(self, api_key: str, cache_manager=None, max_retries: int = 3, timeout: int = 30): + + def __init__( + self, api_key: str, cache_manager=None, max_retries: int = 3, timeout: int = 30 + ): self.api_key = api_key self.cache_manager = cache_manager self.max_retries = max_retries self.timeout = timeout self.session = None - + async def connect(self): """Mock connect method.""" pass - + async def disconnect(self): """Mock disconnect method.""" pass - + async def health_check(self): """Mock health check that uses cache.""" cache_key = "health_check" - + # Check cache first if available if self.cache_manager: cached = self.cache_manager.get(cache_key) if cached: return json.loads(cached.content) - + # Make "API" request - create a larger response to meet token minimum result = { "status": "healthy", @@ -86,27 +89,27 @@ async def health_check(self): "authorization": "running", "cache": "running", "metrics": "running", - "logging": "running" + "logging": "running", }, "performance_metrics": { "avg_response_time": "125ms", "p95_response_time": "350ms", "p99_response_time": "750ms", - "error_rate": "0.02%" - } - } + "error_rate": "0.02%", + }, + }, } - + # Store in cache if available if self.cache_manager: self.cache_manager.store(cache_key, json.dumps(result)) - + return result - + async def analyze_code(self, code: str, language: str): """Mock analyze code.""" return {"result": f"analyzed_{code}_{language}"} - + async def make_request(self, method: str, endpoint: str, **kwargs): """Mock request method.""" return {"status": "success", "method": method, "endpoint": endpoint} @@ -114,31 +117,31 @@ async def make_request(self, method: str, endpoint: str, **kwargs): class ArcadeGateway: """Mock ArcadeGateway for testing.""" - + def __init__(self, api_key: str, enable_failover: bool = False): self.api_key = api_key self.enable_failover = enable_failover - + async def initialize(self): """Mock initialize method.""" pass - + async def cleanup(self): """Mock cleanup method.""" pass - + async def execute_local_tool(self, tool_name: str, params: Dict[str, Any]): """Mock local tool execution.""" return {"tool": tool_name, "params": params, "result": "local_result"} - + async def execute_arcade_tool(self, tool_name: str, params: Dict[str, Any]): """Mock arcade tool execution.""" return {"tool": tool_name, "params": params, "result": "arcade_result"} - + async def execute_hybrid_workflow(self, workflow: List[Dict[str, Any]]): """Mock hybrid workflow execution.""" return {"workflow": workflow, "result": "hybrid_result"} - + async def analyze_with_failover(self, code: str): """Mock analyze with failover.""" return {"result": "fallback_analysis", "code": code} @@ -147,6 +150,7 @@ async def analyze_with_failover(self, code: str): @dataclass class TestResult: """Test execution result.""" + test_name: str success: bool execution_time: float @@ -156,71 +160,65 @@ class TestResult: class MockArcadeResponse: """Mock Arcade.dev API response.""" - + def __init__(self, status: int = 200, data: Dict[str, Any] = None): self.status = status self.data = data or {} - + async def json(self): return self.data - + def raise_for_status(self): if self.status >= 400: raise aiohttp.ClientResponseError( - request_info=None, - history=None, - status=self.status + request_info=None, history=None, status=self.status ) class MockArcadeSession: """Mock aiohttp session for Arcade.dev API.""" - + def __init__(self, responses: Dict[str, MockArcadeResponse] = None): self.responses = responses or {} self.request_history = [] - + async def request(self, method: str, url: str, **kwargs): """Mock request method.""" - self.request_history.append({ - 'method': method, - 'url': url, - 'kwargs': kwargs - }) - + self.request_history.append({"method": method, "url": url, "kwargs": kwargs}) + # Return configured response or default success key = f"{method}:{url}" if key in self.responses: response = self.responses[key] else: - response = MockArcadeResponse(200, {'status': 'success'}) - + response = MockArcadeResponse(200, {"status": "success"}) + return response - + async def close(self): """Mock close method.""" pass - + def __aenter__(self): return self - + async def __aexit__(self, *args): await self.close() class ArcadeIntegrationTestSuite: """Comprehensive test suite for Arcade.dev integration.""" - + def __init__(self): self.test_results: List[TestResult] = [] self.mock_session = None self.cache_manager = None self.temp_dir = None - + async def setup(self): """Set up test environment.""" self.temp_dir = tempfile.mkdtemp() - + # Create cache configuration cache_config = { "prefix": "test_cache", @@ -228,686 +226,757 @@ async def setup(self): "max_size": "10MB", "ttl_seconds": 3600, "hit_target_ms": 50, - "miss_target_ms": 150 + "miss_target_ms": 150, } - + self.cache_manager = CacheManager(cache_config) # CacheManager doesn't have async initialize method - + async def teardown(self): """Clean up test environment.""" if self.cache_manager: # CacheManager doesn't have async close method self.cache_manager = None - + if self.temp_dir: import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) - + async def run_all_tests(self) -> List[TestResult]: """Run all integration tests.""" await self.setup() - + try: # Unit tests await self.test_tool_registration() await self.test_cache_integration() await self.test_error_handling() - + # Integration tests await self.test_hybrid_execution() await self.test_concurrent_requests() await self.test_failover_behavior() - + # Performance tests await self.test_performance_benchmarks() await self.test_memory_usage() - + # Mock tests await self.test_mock_responses() await self.test_network_failures() - + finally: await self.teardown() - + return self.test_results - + async def test_tool_registration(self): """Test tool registration functionality.""" test_name = "tool_registration" start_time = time.time() errors = [] details = {} - + try: # Test basic tool execution using ToolExecutor executor = ToolExecutor() - + # Get available tools (this tests the tool registry) available_tools = executor.get_available_tools() - details['available_tools_count'] = len(available_tools) - + details["available_tools_count"] = len(available_tools) + # Create a mock tool call to test execution path from src.tools.executor import ToolCall, create_tool_call - + # Test tool call creation tool_call = create_tool_call( tool_name="test_tool", arguments={"param1": "hello", "param2": 42}, - user_id="test_user" + user_id="test_user", ) - details['tool_call_created'] = True - + details["tool_call_created"] = True + # Test that we can handle tool calls (even if tool doesn't exist) try: result = await executor.execute_tool_call(tool_call) # Tool likely doesn't exist, so we expect this to fail with ToolNotFoundError - details['unexpected_success'] = True + details["unexpected_success"] = True except Exception as e: # Expected - tool doesn't exist in registry - details['expected_error'] = str(type(e).__name__) - + details["expected_error"] = str(type(e).__name__) + # Test rate limiting functionality - can_execute_before = executor.rate_limiter.can_execute() if executor.rate_limiter else True - details['rate_limiting_enabled'] = executor.enable_rate_limiting - details['can_execute'] = can_execute_before - - details['validation_tests'] = "passed" + can_execute_before = ( + executor.rate_limiter.can_execute() if executor.rate_limiter else True + ) + details["rate_limiting_enabled"] = executor.enable_rate_limiting + details["can_execute"] = can_execute_before + + details["validation_tests"] = "passed" success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_cache_integration(self): """Test cache integration with Arcade.dev client.""" test_name = "cache_integration" start_time = time.time() errors = [] details = {} - + try: # Test cache operations cache_key = "test_key_hash" # Create content with sufficient tokens (minimum 100 required) - test_data_str = json.dumps({ - "message": "cached_data_for_testing", - "timestamp": time.time(), - "description": "This is a test cache entry that contains enough content to meet the minimum token requirement for the cache manager. " * 10, - "additional_data": ["item1", "item2", "item3"] * 20 - }) - + test_data_str = json.dumps( + { + "message": "cached_data_for_testing", + "timestamp": time.time(), + "description": "This is a test cache entry that contains enough content to meet the minimum token requirement for the cache manager. " + * 10, + "additional_data": ["item1", "item2", "item3"] * 20, + } + ) + # Test cache store entry = self.cache_manager.store(cache_key, test_data_str) assert entry is not None - details['cache_set'] = True - + details["cache_set"] = True + # Test cache get cached_entry = self.cache_manager.get(cache_key) assert cached_entry is not None assert cached_entry.content == test_data_str - details['cache_get'] = True - + details["cache_get"] = True + # Test cache invalidation (cache manager doesn't have delete method, skip for now) - details['cache_invalidation'] = "skipped" - + details["cache_invalidation"] = "skipped" + # Test cache with Arcade client mock_responses = { "GET:/v1/health": MockArcadeResponse(200, {"status": "healthy"}) } self.mock_session = MockArcadeSession(mock_responses) - - with patch('aiohttp.ClientSession', return_value=self.mock_session): - client = ArcadeClient(api_key="test_key", cache_manager=self.cache_manager) + + with patch("aiohttp.ClientSession", return_value=self.mock_session): + client = ArcadeClient( + api_key="test_key", cache_manager=self.cache_manager + ) await client.connect() - + # First request should hit API result1 = await client.health_check() - details['first_request'] = result1 - + details["first_request"] = result1 + # Second request should hit cache result2 = await client.health_check() - details['cached_request'] = result2 - + details["cached_request"] = result2 + # Verify both results are the same assert result1 == result2 - + await client.disconnect() - + success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_hybrid_execution(self): """Test hybrid execution between local and Arcade.dev tools.""" test_name = "hybrid_execution" start_time = time.time() errors = [] details = {} - + try: # Set up mock Arcade responses mock_responses = { - "POST:/v1/tools/execute": MockArcadeResponse(200, { - "result": {"analysis": "comprehensive", "score": 85} - }) + "POST:/v1/tools/execute": MockArcadeResponse( + 200, {"result": {"analysis": "comprehensive", "score": 85}} + ) } self.mock_session = MockArcadeSession(mock_responses) - - with patch('aiohttp.ClientSession', return_value=self.mock_session): + + with patch("aiohttp.ClientSession", return_value=self.mock_session): # Create gateway for hybrid execution gateway = ArcadeGateway(api_key="test_key") await gateway.initialize() - + # Test local tool execution - local_result = await gateway.execute_local_tool("local_processor", { - "data": "test_input" - }) - details['local_execution'] = local_result - + local_result = await gateway.execute_local_tool( + "local_processor", {"data": "test_input"} + ) + details["local_execution"] = local_result + # Test Arcade tool execution - arcade_result = await gateway.execute_arcade_tool("code_analyzer", { - "code": "def hello(): return 'world'" - }) - details['arcade_execution'] = arcade_result - + arcade_result = await gateway.execute_arcade_tool( + "code_analyzer", {"code": "def hello(): return 'world'"} + ) + details["arcade_execution"] = arcade_result + # Test hybrid workflow - hybrid_result = await gateway.execute_hybrid_workflow([ - {"type": "local", "tool": "preprocessor", "params": {"raw_data": "input"}}, - {"type": "arcade", "tool": "analyzer", "params": {"processed_data": "{{previous_result}}"}}, - {"type": "local", "tool": "postprocessor", "params": {"analysis": "{{previous_result}}"}} - ]) - details['hybrid_workflow'] = hybrid_result - + hybrid_result = await gateway.execute_hybrid_workflow( + [ + { + "type": "local", + "tool": "preprocessor", + "params": {"raw_data": "input"}, + }, + { + "type": "arcade", + "tool": "analyzer", + "params": {"processed_data": "{{previous_result}}"}, + }, + { + "type": "local", + "tool": "postprocessor", + "params": {"analysis": "{{previous_result}}"}, + }, + ] + ) + details["hybrid_workflow"] = hybrid_result + await gateway.cleanup() - + success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_concurrent_requests(self): """Test concurrent request handling.""" test_name = "concurrent_requests" start_time = time.time() errors = [] details = {} - + try: # Set up mock responses with delays mock_responses = { "GET:/v1/health": MockArcadeResponse(200, {"status": "healthy"}), - "POST:/v1/analyze": MockArcadeResponse(200, {"result": "analysis_complete"}) + "POST:/v1/analyze": MockArcadeResponse( + 200, {"result": "analysis_complete"} + ), } self.mock_session = MockArcadeSession(mock_responses) - - with patch('aiohttp.ClientSession', return_value=self.mock_session): + + with patch("aiohttp.ClientSession", return_value=self.mock_session): client = ArcadeClient(api_key="test_key") await client.connect() - + # Create multiple concurrent tasks tasks = [] for i in range(10): if i % 2 == 0: task = asyncio.create_task(client.health_check()) else: - task = asyncio.create_task(client.analyze_code(f"code_{i}", "python")) + task = asyncio.create_task( + client.analyze_code(f"code_{i}", "python") + ) tasks.append(task) - + # Wait for all tasks to complete results = await asyncio.gather(*tasks, return_exceptions=True) - + # Analyze results - successful_requests = sum(1 for r in results if not isinstance(r, Exception)) + successful_requests = sum( + 1 for r in results if not isinstance(r, Exception) + ) failed_requests = sum(1 for r in results if isinstance(r, Exception)) - - details['total_requests'] = len(tasks) - details['successful_requests'] = successful_requests - details['failed_requests'] = failed_requests - details['success_rate'] = successful_requests / len(tasks) - + + details["total_requests"] = len(tasks) + details["successful_requests"] = successful_requests + details["failed_requests"] = failed_requests + details["success_rate"] = successful_requests / len(tasks) + await client.disconnect() - + # Verify reasonable success rate - assert details['success_rate'] >= 0.8 # At least 80% success + assert details["success_rate"] >= 0.8 # At least 80% success success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_performance_benchmarks(self): """Test performance benchmarks.""" test_name = "performance_benchmarks" start_time = time.time() errors = [] details = {} - + try: # Benchmark cache operations cache_times = [] for i in range(100): cache_start = time.time() # Create content with sufficient tokens - test_data = json.dumps({ - "data": f"value_{i}", - "description": "Performance test data that contains enough content to meet the minimum token requirement. " * 10, - "iteration": i, - "additional_fields": ["field1", "field2", "field3"] * 10 - }) + test_data = json.dumps( + { + "data": f"value_{i}", + "description": "Performance test data that contains enough content to meet the minimum token requirement. " + * 10, + "iteration": i, + "additional_fields": ["field1", "field2", "field3"] * 10, + } + ) self.cache_manager.store(f"perf_key_{i}", test_data) result = self.cache_manager.get(f"perf_key_{i}") cache_times.append(time.time() - cache_start) - - details['cache_avg_time'] = sum(cache_times) / len(cache_times) - details['cache_max_time'] = max(cache_times) - details['cache_min_time'] = min(cache_times) - + + details["cache_avg_time"] = sum(cache_times) / len(cache_times) + details["cache_max_time"] = max(cache_times) + details["cache_min_time"] = min(cache_times) + # Benchmark tool execution using ToolExecutor executor = ToolExecutor() - + # Create mock tool calls for performance testing tool_times = [] for i in range(50): tool_start = time.time() - + # Create a tool call that should fail (no tool registered) tool_call = create_tool_call( tool_name="benchmark_tool", arguments={"size": 1000}, - user_id="test_user" + user_id="test_user", ) - + # Execute and expect failure (measure execution time) try: await executor.execute_tool_call(tool_call) except Exception: pass # Expected - no tool registered - + tool_times.append(time.time() - tool_start) - - details['tool_avg_time'] = sum(tool_times) / len(tool_times) - details['tool_max_time'] = max(tool_times) - details['tool_min_time'] = min(tool_times) - + + details["tool_avg_time"] = sum(tool_times) / len(tool_times) + details["tool_max_time"] = max(tool_times) + details["tool_min_time"] = min(tool_times) + # Performance assertions - assert details['cache_avg_time'] < 0.01 # Cache ops should be fast - assert details['tool_avg_time'] < 0.1 # Tool execution reasonable (even for failures) - + assert details["cache_avg_time"] < 0.01 # Cache ops should be fast + assert ( + details["tool_avg_time"] < 0.1 + ) # Tool execution reasonable (even for failures) + success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_error_handling(self): """Test error handling scenarios.""" test_name = "error_handling" start_time = time.time() errors = [] details = {} - + try: # Test tool execution errors (skip network errors since mock isn't working as expected) executor = ToolExecutor() - + # Test invalid tool handling try: invalid_tool_call = create_tool_call( - tool_name="nonexistent_tool", - arguments={}, - user_id="test_user" + tool_name="nonexistent_tool", arguments={}, user_id="test_user" ) result = await executor.execute_tool_call(invalid_tool_call) assert False, "Should have raised exception" except Exception as e: - details['invalid_tool_handling'] = "passed" - details['invalid_tool_error'] = str(type(e).__name__) - + details["invalid_tool_handling"] = "passed" + details["invalid_tool_error"] = str(type(e).__name__) + # Test rate limiting (if enabled) if executor.enable_rate_limiting: # Record many calls to test rate limiting for _ in range(65): # Exceed default limit of 60 executor.rate_limiter.record_call("test_user") - + can_execute_after = executor.rate_limiter.can_execute("test_user") - details['rate_limiting_works'] = not can_execute_after - - details['tool_error_handling'] = "passed" - details['network_error_handling'] = "skipped (mock issues)" - + details["rate_limiting_works"] = not can_execute_after + + details["tool_error_handling"] = "passed" + details["network_error_handling"] = "skipped (mock issues)" + success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_mock_responses(self): """Test various mock response scenarios.""" test_name = "mock_responses" start_time = time.time() errors = [] details = {} - + try: # Test different response types scenarios = [ - ("success", MockArcadeResponse(200, {"status": "success", "data": "test"})), + ( + "success", + MockArcadeResponse(200, {"status": "success", "data": "test"}), + ), ("not_found", MockArcadeResponse(404, {"error": "Not found"})), ("rate_limited", MockArcadeResponse(429, {"error": "Rate limited"})), - ("server_error", MockArcadeResponse(500, {"error": "Internal error"})) + ("server_error", MockArcadeResponse(500, {"error": "Internal error"})), ] - + for scenario_name, mock_response in scenarios: mock_responses = {"GET:/v1/test": mock_response} self.mock_session = MockArcadeSession(mock_responses) - - with patch('aiohttp.ClientSession', return_value=self.mock_session): + + with patch("aiohttp.ClientSession", return_value=self.mock_session): client = ArcadeClient(api_key="test_key", max_retries=1) await client.connect() - + try: result = await client.make_request("GET", "/test") - details[f'{scenario_name}_result'] = result + details[f"{scenario_name}_result"] = result except Exception as e: - details[f'{scenario_name}_error'] = str(type(e).__name__) - + details[f"{scenario_name}_error"] = str(type(e).__name__) + await client.disconnect() - + success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_network_failures(self): """Test network failure scenarios.""" test_name = "network_failures" start_time = time.time() errors = [] details = {} - + try: # Test timeout scenario - simulate by setting very short timeout try: # Create a client with extremely short timeout to force timeout client = ArcadeClient(api_key="test_key", timeout=0.001, max_retries=1) await client.connect() - + # This should timeout due to the very short timeout await asyncio.wait_for(client.health_check(), timeout=0.001) - details['timeout_handling'] = "timeout not triggered as expected" - except (asyncio.TimeoutError, aiohttp.ClientTimeout, aiohttp.ServerTimeoutError): - details['timeout_handling'] = "passed" + details["timeout_handling"] = "timeout not triggered as expected" + except ( + asyncio.TimeoutError, + aiohttp.ClientTimeout, + aiohttp.ServerTimeoutError, + ): + details["timeout_handling"] = "passed" except Exception as e: # Accept any network-related error as timeout-like behavior if "timeout" in str(e).lower() or "time" in str(e).lower(): - details['timeout_handling'] = "passed" + details["timeout_handling"] = "passed" else: - details['timeout_handling'] = f"unexpected error: {e}" + details["timeout_handling"] = f"unexpected error: {e}" finally: try: await client.disconnect() except: pass - + # Test connection error with invalid URL try: - client = ArcadeClient(api_key="test_key", base_url="http://invalid-host-12345.local", max_retries=1) + client = ArcadeClient( + api_key="test_key", + base_url="http://invalid-host-12345.local", + max_retries=1, + ) await client.connect() await client.health_check() - details['connection_error_handling'] = "connection error not triggered" + details["connection_error_handling"] = "connection error not triggered" except (aiohttp.ClientConnectorError, aiohttp.ClientError, OSError): - details['connection_error_handling'] = "passed" + details["connection_error_handling"] = "passed" except Exception as e: # Accept DNS or connection-related errors - if any(word in str(e).lower() for word in ['connection', 'resolve', 'network', 'host']): - details['connection_error_handling'] = "passed" + if any( + word in str(e).lower() + for word in ["connection", "resolve", "network", "host"] + ): + details["connection_error_handling"] = "passed" else: - details['connection_error_handling'] = f"unexpected error: {e}" + details["connection_error_handling"] = f"unexpected error: {e}" finally: try: await client.disconnect() except: pass - + success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_failover_behavior(self): """Test failover behavior between services.""" test_name = "failover_behavior" start_time = time.time() errors = [] details = {} - + try: # Simulate primary service failure and fallback primary_responses = { - "POST:/v1/analyze": MockArcadeResponse(503, {"error": "Service unavailable"}) + "POST:/v1/analyze": MockArcadeResponse( + 503, {"error": "Service unavailable"} + ) } - + fallback_responses = { - "POST:/v1/analyze": MockArcadeResponse(200, {"result": "fallback_analysis"}) + "POST:/v1/analyze": MockArcadeResponse( + 200, {"result": "fallback_analysis"} + ) } - + # Test with failover logic - with patch('aiohttp.ClientSession') as mock_session_class: + with patch("aiohttp.ClientSession") as mock_session_class: # First attempt fails, second succeeds failed_session = MockArcadeSession(primary_responses) success_session = MockArcadeSession(fallback_responses) mock_session_class.side_effect = [failed_session, success_session] - + gateway = ArcadeGateway(api_key="test_key", enable_failover=True) - + # Should succeed via fallback result = await gateway.analyze_with_failover("test code") - details['failover_result'] = result + details["failover_result"] = result assert result.get("result") == "fallback_analysis" - + success = True - + except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) - + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) + async def test_memory_usage(self): """Test memory usage during operations.""" test_name = "memory_usage" start_time = time.time() errors = [] details = {} - + try: import psutil + process = psutil.Process() - + # Baseline memory baseline_memory = process.memory_info().rss / 1024 / 1024 # MB - + # Create large cache entries large_data = json.dumps({"data": "x" * 10000}) for i in range(100): self.cache_manager.store(f"large_key_{i}", large_data) - + # Memory after cache operations cache_memory = process.memory_info().rss / 1024 / 1024 # MB - + # Clean up cache (skip deletion as manager doesn't have delete method) # for i in range(100): # self.cache_manager.delete(f"large_key_{i}") - + # Memory after cleanup cleanup_memory = process.memory_info().rss / 1024 / 1024 # MB - - details['baseline_memory_mb'] = baseline_memory - details['cache_memory_mb'] = cache_memory - details['cleanup_memory_mb'] = cleanup_memory - details['memory_growth_mb'] = cache_memory - baseline_memory - details['memory_recovered_mb'] = cache_memory - cleanup_memory - + + details["baseline_memory_mb"] = baseline_memory + details["cache_memory_mb"] = cache_memory + details["cleanup_memory_mb"] = cleanup_memory + details["memory_growth_mb"] = cache_memory - baseline_memory + details["memory_recovered_mb"] = cache_memory - cleanup_memory + # Memory growth should be reasonable - assert details['memory_growth_mb'] < 100 # Less than 100MB growth - + assert details["memory_growth_mb"] < 100 # Less than 100MB growth + success = True - + except ImportError: # psutil not available, skip memory test - details['skipped'] = "psutil not available" + details["skipped"] = "psutil not available" success = True except Exception as e: success = False errors.append(str(e)) - + execution_time = time.time() - start_time - self.test_results.append(TestResult( - test_name=test_name, - success=success, - execution_time=execution_time, - details=details, - errors=errors - )) + self.test_results.append( + TestResult( + test_name=test_name, + success=success, + execution_time=execution_time, + details=details, + errors=errors, + ) + ) def create_test_fixtures(): """Create test fixtures and utilities.""" - + class TestFixtures: """Collection of test fixtures.""" - + @staticmethod def create_mock_cache_manager(): """Create a mock cache manager.""" cache = {} - + class MockCacheManager: async def get(self, key: str): return cache.get(key) - + async def set(self, key: str, value: Any, ttl: int = None): cache[key] = value - + async def delete(self, key: str): cache.pop(key, None) - + async def initialize(self): pass - + async def close(self): cache.clear() - + return MockCacheManager() - + @staticmethod def create_sample_tools(): """Create sample tools for testing.""" tools = {} - - async def sample_analyzer(code: str, language: str = "python") -> Dict[str, Any]: + + async def sample_analyzer( + code: str, language: str = "python" + ) -> Dict[str, Any]: return { - "lines": len(code.split('\n')), + "lines": len(code.split("\n")), "language": language, - "complexity": "low" - } - - async def sample_formatter(code: str, style: str = "black") -> Dict[str, Any]: - return { - "formatted_code": code.strip(), - "style_applied": style + "complexity": "low", } - + + async def sample_formatter( + code: str, style: str = "black" + ) -> Dict[str, Any]: + return {"formatted_code": code.strip(), "style_applied": style} + tools["analyzer"] = sample_analyzer tools["formatter"] = sample_formatter - + return tools - + @staticmethod def create_test_data(): """Create test data sets.""" @@ -920,9 +989,9 @@ def fibonacci(n): return fibonacci(n-1) + fibonacci(n-2) """, "invalid_code": "def broken( syntax error", - "large_code": "# " + "x" * 10000 + "large_code": "# " + "x" * 10000, } - + return TestFixtures @@ -930,52 +999,53 @@ async def run_comprehensive_tests(): """Run comprehensive test suite.""" print("🧪 Running Comprehensive Arcade.dev Integration Tests") print("=" * 60) - + test_suite = ArcadeIntegrationTestSuite() results = await test_suite.run_all_tests() - + # Print results summary total_tests = len(results) passed_tests = sum(1 for r in results if r.success) failed_tests = total_tests - passed_tests - + print(f"\nšŸ“Š Test Results Summary:") print(f" Total Tests: {total_tests}") print(f" Passed: {passed_tests}") print(f" Failed: {failed_tests}") print(f" Success Rate: {(passed_tests/total_tests)*100:.1f}%") - + # Print detailed results print(f"\nšŸ“‹ Detailed Results:") for result in results: status = "āœ…" if result.success else "āŒ" print(f" {status} {result.test_name}: {result.execution_time:.3f}s") - + if result.errors: for error in result.errors: print(f" Error: {error}") - + return passed_tests == total_tests async def main(): """Main test execution function.""" import logging + logging.basicConfig( level=logging.WARNING, # Reduce noise during tests - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + try: success = await run_comprehensive_tests() - + if success: print("\nšŸŽ‰ All tests passed successfully!") return 0 else: print("\nāŒ Some tests failed!") return 1 - + except Exception as e: print(f"āŒ Test execution failed: {e}") return 1 @@ -983,4 +1053,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/10_deployment/production_deployment.py b/examples/arcade-dev/10_deployment/production_deployment.py index 26c8e53..cd45bd4 100644 --- a/examples/arcade-dev/10_deployment/production_deployment.py +++ b/examples/arcade-dev/10_deployment/production_deployment.py @@ -25,6 +25,7 @@ import threading from concurrent.futures import ThreadPoolExecutor import subprocess + try: import psutil except ImportError: @@ -48,24 +49,31 @@ # Mock classes for demonstration (in production, these would be real implementations) class ArcadeClient: """Mock Arcade.dev API client for deployment example.""" - - def __init__(self, api_key: str, api_url: str, timeout: int, max_retries: int, cache_manager=None): + + def __init__( + self, + api_key: str, + api_url: str, + timeout: int, + max_retries: int, + cache_manager=None, + ): self.api_key = api_key self.api_url = api_url self.timeout = timeout self.max_retries = max_retries self.cache_manager = cache_manager self.connected = False - + async def connect(self): """Connect to Arcade.dev API.""" # Mock connection self.connected = True - + async def disconnect(self): """Disconnect from Arcade.dev API.""" self.connected = False - + async def health_check(self): """Perform health check.""" if not self.connected: @@ -75,18 +83,18 @@ async def health_check(self): class ArcadeGateway: """Mock Arcade.dev gateway for deployment example.""" - + def __init__(self, arcade_client, cache_manager, metrics, authorization_manager): self.arcade_client = arcade_client self.cache_manager = cache_manager self.metrics = metrics self.authorization_manager = authorization_manager self.initialized = False - + async def initialize(self): """Initialize the gateway.""" self.initialized = True - + async def cleanup(self): """Clean up gateway resources.""" self.initialized = False @@ -94,12 +102,14 @@ async def cleanup(self): class ServiceStartupError(Exception): """Custom exception for service startup errors.""" + pass @dataclass class ServiceHealthStatus: """Health status of a service component.""" + service_name: str healthy: bool last_check: float @@ -110,38 +120,39 @@ class ServiceHealthStatus: @dataclass class DeploymentConfig: """Production deployment configuration.""" + # Service configuration service_name: str = "fact-arcade-integration" service_version: str = "1.0.0" environment: str = "production" - + # Network configuration host: str = "0.0.0.0" port: int = 8080 health_check_port: int = 8081 - + # Arcade.dev configuration arcade_api_key: str = "" arcade_api_url: str = "https://api.arcade.dev" arcade_timeout: int = 30 arcade_max_retries: int = 3 - + # Cache configuration cache_backend: str = "redis" cache_host: str = "localhost" cache_port: int = 6379 cache_db: int = 0 - + # Monitoring configuration metrics_enabled: bool = True metrics_port: int = 9090 log_level: str = "INFO" - + # Performance configuration worker_threads: int = 4 max_concurrent_requests: int = 100 request_timeout: int = 60 - + # Security configuration enable_auth: bool = False # Disabled for demo jwt_secret: str = "" @@ -150,11 +161,11 @@ class DeploymentConfig: class ProductionArcadeService: """Production-ready Arcade.dev integration service.""" - + def __init__(self, config: DeploymentConfig): self.config = config self.logger = self._setup_logging() - + # Core components self.driver: Optional[FACTDriver] = None self.cache_manager: Optional[CacheManager] = None @@ -162,127 +173,138 @@ def __init__(self, config: DeploymentConfig): self.arcade_gateway: Optional[ArcadeGateway] = None self.metrics: Optional[MetricsCollector] = None self.authorization_manager: Optional[AuthorizationManager] = None - + # Service state self.is_running = False self.is_ready = False self.startup_time: Optional[float] = None self.shutdown_event = asyncio.Event() self.health_status: Dict[str, ServiceHealthStatus] = {} - + # Background tasks self.background_tasks: List[asyncio.Task] = [] self.thread_pool = ThreadPoolExecutor(max_workers=config.worker_threads) - + # Signal handlers self._setup_signal_handlers() - + def _setup_logging(self) -> logging.Logger: """Set up production logging configuration.""" # Create log handlers handlers = [logging.StreamHandler()] - + # Try to add file handler, but don't fail if directory doesn't exist try: - log_file = f'/var/log/{self.config.service_name}.log' + log_file = f"/var/log/{self.config.service_name}.log" handlers.append(logging.FileHandler(log_file)) except (PermissionError, FileNotFoundError): # Fallback to local log file try: - log_file = f'./{self.config.service_name}.log' + log_file = f"./{self.config.service_name}.log" handlers.append(logging.FileHandler(log_file)) except Exception: # If all file logging fails, just use console pass - + logging.basicConfig( level=getattr(logging, self.config.log_level.upper()), - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d', - handlers=handlers + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d", + handlers=handlers, ) - + logger = logging.getLogger(self.config.service_name) - logger.info(f"Initializing {self.config.service_name} v{self.config.service_version}") + logger.info( + f"Initializing {self.config.service_name} v{self.config.service_version}" + ) return logger - + def _setup_signal_handlers(self): """Set up signal handlers for graceful shutdown.""" + def signal_handler(signum, frame): self.logger.info(f"Received signal {signum}, initiating graceful shutdown") asyncio.create_task(self.shutdown()) - + signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - + async def initialize(self): """Initialize all service components.""" self.logger.info("Starting service initialization") initialization_start = time.time() - + try: # Load and validate configuration await self._load_configuration() - + # Initialize core components await self._initialize_cache() await self._initialize_arcade_client() await self._initialize_monitoring() await self._initialize_authorization() await self._initialize_gateway() - + # Start background tasks await self._start_background_tasks() - + # Perform initial health checks await self._perform_initial_health_checks() - + self.startup_time = time.time() - initialization_start self.is_running = True self.is_ready = True - - self.logger.info(f"Service initialization completed in {self.startup_time:.2f}s") - + + self.logger.info( + f"Service initialization completed in {self.startup_time:.2f}s" + ) + except Exception as e: self.logger.error(f"Service initialization failed: {e}") await self.cleanup() raise ServiceStartupError(f"Failed to initialize service: {e}") - + async def _load_configuration(self): """Load and validate configuration from environment and files.""" self.logger.info("Loading configuration") - + # Load from environment variables - self.config.arcade_api_key = os.getenv("ARCADE_API_KEY", self.config.arcade_api_key) - self.config.arcade_api_url = os.getenv("ARCADE_API_URL", self.config.arcade_api_url) + self.config.arcade_api_key = os.getenv( + "ARCADE_API_KEY", self.config.arcade_api_key + ) + self.config.arcade_api_url = os.getenv( + "ARCADE_API_URL", self.config.arcade_api_url + ) self.config.cache_host = os.getenv("CACHE_HOST", self.config.cache_host) self.config.jwt_secret = os.getenv("JWT_SECRET", self.config.jwt_secret) - + # Load from configuration file if exists config_file = os.getenv("CONFIG_FILE", "/etc/fact-arcade/config.yaml") if os.path.exists(config_file): - with open(config_file, 'r') as f: + with open(config_file, "r") as f: file_config = yaml.safe_load(f) self._merge_config(file_config) - + # Validate required configuration if not self.config.arcade_api_key: raise ConfigurationError("ARCADE_API_KEY is required") - + if self.config.enable_auth and not self.config.jwt_secret: - raise ConfigurationError("JWT_SECRET is required when authentication is enabled") - + raise ConfigurationError( + "JWT_SECRET is required when authentication is enabled" + ) + self.logger.info("Configuration loaded and validated") - + def _merge_config(self, file_config: Dict[str, Any]): """Merge file configuration with current config.""" for key, value in file_config.items(): if hasattr(self.config, key): setattr(self.config, key, value) - + async def _initialize_cache(self): """Initialize cache manager.""" self.logger.info("Initializing cache manager") - + cache_config = { "prefix": "fact_arcade_cache", "min_tokens": 1, # Reduced for demo purposes @@ -291,142 +313,141 @@ async def _initialize_cache(self): "backend": self.config.cache_backend, "host": self.config.cache_host, "port": self.config.cache_port, - "db": self.config.cache_db + "db": self.config.cache_db, } - + self.cache_manager = CacheManager(cache_config) # CacheManager doesn't require async initialization - + # Test cache connectivity test_key = "health_check_cache" test_content = '{"status": "ok"}' self.cache_manager.store(test_key, test_content) result = self.cache_manager.get(test_key) - + if not result: raise ServiceStartupError("Cache connectivity test failed") - + self.cache_manager.invalidate_by_prefix(test_key) self.logger.info("Cache manager initialized successfully") - + async def _initialize_arcade_client(self): """Initialize Arcade.dev client.""" self.logger.info("Initializing Arcade.dev client") - + self.arcade_client = ArcadeClient( api_key=self.config.arcade_api_key, api_url=self.config.arcade_api_url, timeout=self.config.arcade_timeout, max_retries=self.config.arcade_max_retries, - cache_manager=self.cache_manager + cache_manager=self.cache_manager, ) - + await self.arcade_client.connect() - + # Test API connectivity try: health_result = await self.arcade_client.health_check() - self.logger.info(f"Arcade.dev API health: {health_result.get('status', 'unknown')}") + self.logger.info( + f"Arcade.dev API health: {health_result.get('status', 'unknown')}" + ) except Exception as e: self.logger.warning(f"Arcade.dev API health check failed: {e}") - + self.logger.info("Arcade.dev client initialized successfully") - + async def _initialize_monitoring(self): """Initialize monitoring and metrics collection.""" if not self.config.metrics_enabled: self.logger.info("Metrics collection disabled") return - + self.logger.info("Initializing metrics collection") - - self.metrics = MetricsCollector( - max_history=10000 - ) - + + self.metrics = MetricsCollector(max_history=10000) + # MetricsCollector doesn't require async initialization - + # Register custom metrics await self._register_custom_metrics() - + self.logger.info("Metrics collection initialized successfully") - + async def _register_custom_metrics(self): """Register custom application metrics.""" if not self.metrics: return - + # MetricsCollector from monitoring doesn't require metric registration # It automatically tracks tool execution metrics self.logger.info("Metrics collection ready") - + async def _initialize_authorization(self): """Initialize security manager.""" if not self.config.enable_auth: self.logger.info("Authorization disabled") return - + self.logger.info("Initializing authorization manager") - + self.authorization_manager = AuthorizationManager() - + # Authorization manager doesn't need async initialization self.logger.info("Authorization manager initialized successfully") - + async def _initialize_gateway(self): """Initialize Arcade.dev gateway.""" self.logger.info("Initializing Arcade.dev gateway") - + self.arcade_gateway = ArcadeGateway( arcade_client=self.arcade_client, cache_manager=self.cache_manager, metrics=self.metrics, - authorization_manager=self.authorization_manager + authorization_manager=self.authorization_manager, ) - + await self.arcade_gateway.initialize() self.logger.info("Arcade.dev gateway initialized successfully") - + async def _start_background_tasks(self): """Start background maintenance tasks.""" self.logger.info("Starting background tasks") - + # Health check task health_task = asyncio.create_task(self._health_check_loop()) self.background_tasks.append(health_task) - + # Metrics collection task if self.metrics: metrics_task = asyncio.create_task(self._metrics_collection_loop()) self.background_tasks.append(metrics_task) - + # Cache cleanup task if self.cache_manager: cleanup_task = asyncio.create_task(self._cache_cleanup_loop()) self.background_tasks.append(cleanup_task) - + self.logger.info(f"Started {len(self.background_tasks)} background tasks") - + async def _perform_initial_health_checks(self): """Perform initial health checks on all components.""" self.logger.info("Performing initial health checks") - + # Check all components await self._check_cache_health() await self._check_arcade_health() await self._check_system_health() - + # Verify all components are healthy unhealthy_services = [ - name for name, status in self.health_status.items() - if not status.healthy + name for name, status in self.health_status.items() if not status.healthy ] - + if unhealthy_services: raise ServiceStartupError(f"Unhealthy services: {unhealthy_services}") - + self.logger.info("All initial health checks passed") - + async def _health_check_loop(self): """Background task for periodic health checks.""" while not self.shutdown_event.is_set(): @@ -434,68 +455,67 @@ async def _health_check_loop(self): await self._check_cache_health() await self._check_arcade_health() await self._check_system_health() - + # Wait for next check await asyncio.wait_for( - self.shutdown_event.wait(), - timeout=30.0 # Check every 30 seconds + self.shutdown_event.wait(), timeout=30.0 # Check every 30 seconds ) - + except asyncio.TimeoutError: continue # Continue health checks except Exception as e: self.logger.error(f"Health check loop error: {e}") await asyncio.sleep(30) - + async def _check_cache_health(self): """Check cache health.""" try: start_time = time.time() test_key = f"health_check_{int(time.time())}" - + self.cache_manager.store(test_key, "ok") result = self.cache_manager.get(test_key) self.cache_manager.invalidate_by_prefix(test_key) - + response_time = time.time() - start_time - + self.health_status["cache"] = ServiceHealthStatus( service_name="cache", healthy=result is not None and result.content == "ok", last_check=time.time(), - details={"response_time": response_time} + details={"response_time": response_time}, ) - + except Exception as e: self.health_status["cache"] = ServiceHealthStatus( service_name="cache", healthy=False, last_check=time.time(), - error_message=str(e) + error_message=str(e), ) - + async def _check_arcade_health(self): """Check Arcade.dev API health.""" try: start_time = time.time() result = await self.arcade_client.health_check() response_time = time.time() - start_time - + self.health_status["arcade"] = ServiceHealthStatus( service_name="arcade", healthy=result.get("status") == "healthy", last_check=time.time(), - details={"response_time": response_time, "api_status": result} + details={"response_time": response_time, "api_status": result}, ) - + except Exception as e: self.health_status["arcade"] = ServiceHealthStatus( service_name="arcade", healthy=False, last_check=time.time(), - error_message=str(e) + error_message=str(e), ) - + async def _check_system_health(self): """Check system resource health.""" try: @@ -509,11 +529,11 @@ async def _check_system_health(self): process = psutil.Process() memory_info = process.memory_info() cpu_percent = process.cpu_percent() - + # Check if we're using too much memory (> 1GB as example) memory_mb = memory_info.rss / 1024 / 1024 healthy = memory_mb < 1024 and cpu_percent < 80 - + self.health_status["system"] = ServiceHealthStatus( service_name="system", healthy=healthy, @@ -521,18 +541,20 @@ async def _check_system_health(self): details={ "memory_mb": memory_mb, "cpu_percent": cpu_percent, - "uptime": time.time() - self.startup_time if self.startup_time else 0 - } + "uptime": ( + time.time() - self.startup_time if self.startup_time else 0 + ), + }, ) - + except Exception as e: self.health_status["system"] = ServiceHealthStatus( service_name="system", healthy=False, last_check=time.time(), - error_message=str(e) + error_message=str(e), ) - + async def _metrics_collection_loop(self): """Background task for metrics collection.""" while not self.shutdown_event.is_set(): @@ -542,19 +564,18 @@ async def _metrics_collection_loop(self): # MetricsCollector from monitoring doesn't have record_gauge method # It automatically tracks metrics during tool execution self.logger.debug(f"Service uptime: {uptime:.2f} seconds") - + # Wait for next collection await asyncio.wait_for( - self.shutdown_event.wait(), - timeout=60.0 # Collect every minute + self.shutdown_event.wait(), timeout=60.0 # Collect every minute ) - + except asyncio.TimeoutError: continue except Exception as e: self.logger.error(f"Metrics collection error: {e}") await asyncio.sleep(60) - + async def _cache_cleanup_loop(self): """Background task for cache cleanup.""" while not self.shutdown_event.is_set(): @@ -562,25 +583,22 @@ async def _cache_cleanup_loop(self): if self.cache_manager: # Use private method for cleanup (sync, not async) self.cache_manager._cleanup_expired() - + # Wait for next cleanup await asyncio.wait_for( - self.shutdown_event.wait(), - timeout=300.0 # Cleanup every 5 minutes + self.shutdown_event.wait(), timeout=300.0 # Cleanup every 5 minutes ) - + except asyncio.TimeoutError: continue except Exception as e: self.logger.error(f"Cache cleanup error: {e}") await asyncio.sleep(300) - + async def get_health_status(self) -> Dict[str, Any]: """Get current health status of all components.""" - overall_healthy = all( - status.healthy for status in self.health_status.values() - ) - + overall_healthy = all(status.healthy for status in self.health_status.values()) + return { "service": self.config.service_name, "version": self.config.service_version, @@ -592,78 +610,78 @@ async def get_health_status(self) -> Dict[str, Any]: "healthy": status.healthy, "last_check": status.last_check, "details": status.details, - "error": status.error_message + "error": status.error_message, } for name, status in self.health_status.items() - } + }, } - + async def get_readiness_status(self) -> Dict[str, Any]: """Get readiness status for load balancer probes.""" return { "ready": self.is_ready, "startup_time": self.startup_time, - "components_ready": len(self.health_status) > 0 + "components_ready": len(self.health_status) > 0, } - + async def shutdown(self): """Gracefully shutdown the service.""" if not self.is_running: return - + self.logger.info("Starting graceful shutdown") shutdown_start = time.time() - + # Set shutdown event self.shutdown_event.set() self.is_ready = False - + try: # Cancel background tasks self.logger.info("Cancelling background tasks") for task in self.background_tasks: task.cancel() - + # Wait for tasks to complete if self.background_tasks: await asyncio.gather(*self.background_tasks, return_exceptions=True) - + # Cleanup components await self.cleanup() - + shutdown_time = time.time() - shutdown_start self.logger.info(f"Graceful shutdown completed in {shutdown_time:.2f}s") - + except Exception as e: self.logger.error(f"Error during shutdown: {e}") finally: self.is_running = False - + async def cleanup(self): """Clean up all resources.""" self.logger.info("Cleaning up resources") - + # Close gateway if self.arcade_gateway: await self.arcade_gateway.cleanup() - + # Close arcade client if self.arcade_client: await self.arcade_client.disconnect() - + # Close cache manager if self.cache_manager: # CacheManager doesn't have a close method - cleanup is automatic pass - + # Close metrics collector if self.metrics: # MetricsCollector from monitoring doesn't require cleanup pass - + # Shutdown thread pool self.thread_pool.shutdown(wait=True) - + self.logger.info("Resource cleanup completed") @@ -675,29 +693,30 @@ async def create_health_check_server(service: ProductionArcadeService): # Fallback for demonstration if aiohttp is not installed web = None web_response = None - + if web is None or web_response is None: # Return a mock app if aiohttp is not available class MockApp: pass + return MockApp() - + async def health_handler(request): """Health check endpoint.""" status = await service.get_health_status() status_code = 200 if status["status"] == "healthy" else 503 return web_response.json_response(status, status=status_code) - + async def readiness_handler(request): """Readiness check endpoint.""" status = await service.get_readiness_status() status_code = 200 if status["ready"] else 503 return web_response.json_response(status, status=status_code) - + app = web.Application() - app.router.add_get('/health', health_handler) - app.router.add_get('/ready', readiness_handler) - + app.router.add_get("/health", health_handler) + app.router.add_get("/ready", readiness_handler) + return app @@ -705,54 +724,54 @@ async def main(): """Main service entry point.""" # Load configuration config = DeploymentConfig() - + # Override from environment config.environment = os.getenv("ENVIRONMENT", config.environment) config.log_level = os.getenv("LOG_LEVEL", config.log_level) config.port = int(os.getenv("PORT", str(config.port))) - + # Create and start service service = ProductionArcadeService(config) - + try: # Initialize service await service.initialize() - + # Create health check server health_app = await create_health_check_server(service) - + # Start health check server try: from aiohttp import web except ImportError: web = None - if web is not None and hasattr(health_app, 'router'): + if web is not None and hasattr(health_app, "router"): health_runner = web.AppRunner(health_app) await health_runner.setup() health_site = web.TCPSite( - health_runner, - service.config.host, - service.config.health_check_port + health_runner, service.config.host, service.config.health_check_port ) await health_site.start() - - service.logger.info(f"Health check server started on port {service.config.health_check_port}") + + service.logger.info( + f"Health check server started on port {service.config.health_check_port}" + ) else: service.logger.info("Health check server skipped (aiohttp not available)") health_runner = None - + service.logger.info(f"Service ready and accepting requests") - + # Wait for shutdown signal await service.shutdown_event.wait() - + # Cleanup health server if health_runner: await health_runner.cleanup() - + service.logger.info("Service shutdown completed") return 0 - + except Exception as e: if service: service.logger.error(f"Service failed: {e}") @@ -763,4 +782,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/12_monitoring/arcade_monitoring.py b/examples/arcade-dev/12_monitoring/arcade_monitoring.py index 292804b..1089059 100644 --- a/examples/arcade-dev/12_monitoring/arcade_monitoring.py +++ b/examples/arcade-dev/12_monitoring/arcade_monitoring.py @@ -38,26 +38,27 @@ @dataclass class MonitoringConfig: """Configuration for monitoring and observability.""" + # Performance monitoring enable_performance_tracking: bool = True enable_health_checks: bool = True enable_alerting: bool = True enable_telemetry: bool = True - + # Health check intervals health_check_interval_seconds: int = 30 performance_check_interval_seconds: int = 60 - + # Alert thresholds error_rate_threshold: float = 0.05 # 5% response_time_threshold_ms: float = 5000 # 5 seconds cpu_usage_threshold: float = 80.0 # 80% memory_usage_threshold: float = 85.0 # 85% - + # Metrics retention metrics_retention_hours: int = 24 detailed_metrics_retention_minutes: int = 60 - + # Logging metrics_log_file: str = "arcade_metrics.log" alerts_log_file: str = "arcade_alerts.log" @@ -66,6 +67,7 @@ class MonitoringConfig: @dataclass class HealthCheckResult: """Result of a health check.""" + service: str status: str # healthy, degraded, unhealthy response_time_ms: float @@ -77,6 +79,7 @@ class HealthCheckResult: @dataclass class PerformanceMetric: """Performance metric data point.""" + metric_name: str value: float timestamp: datetime @@ -86,6 +89,7 @@ class PerformanceMetric: @dataclass class Alert: """Alert notification.""" + alert_id: str severity: str # info, warning, critical title: str @@ -100,160 +104,168 @@ class Alert: class SystemMonitor: """Monitors system resources and performance.""" - + def __init__(self, config: MonitoringConfig): self.config = config self.logger = logging.getLogger(f"{__name__}.SystemMonitor") - + def get_system_metrics(self) -> Dict[str, float]: """Get current system resource metrics.""" try: cpu_percent = psutil.cpu_percent(interval=1) memory = psutil.virtual_memory() - disk = psutil.disk_usage('/') - + disk = psutil.disk_usage("/") + # Network I/O net_io = psutil.net_io_counters() - + return { - 'cpu_usage_percent': cpu_percent, - 'memory_usage_percent': memory.percent, - 'memory_available_mb': memory.available / (1024 * 1024), - 'disk_usage_percent': disk.percent, - 'network_bytes_sent': net_io.bytes_sent, - 'network_bytes_recv': net_io.bytes_recv, - 'load_average': os.getloadavg()[0] if hasattr(os, 'getloadavg') else 0.0 + "cpu_usage_percent": cpu_percent, + "memory_usage_percent": memory.percent, + "memory_available_mb": memory.available / (1024 * 1024), + "disk_usage_percent": disk.percent, + "network_bytes_sent": net_io.bytes_sent, + "network_bytes_recv": net_io.bytes_recv, + "load_average": ( + os.getloadavg()[0] if hasattr(os, "getloadavg") else 0.0 + ), } except Exception as e: self.logger.error(f"Failed to collect system metrics: {e}") return {} - + def check_system_health(self) -> HealthCheckResult: """Perform system health check.""" start_time = time.time() - + try: metrics = self.get_system_metrics() response_time = (time.time() - start_time) * 1000 - + # Determine health status status = "healthy" issues = [] - - if metrics.get('cpu_usage_percent', 0) > self.config.cpu_usage_threshold: + + if metrics.get("cpu_usage_percent", 0) > self.config.cpu_usage_threshold: status = "degraded" issues.append(f"High CPU usage: {metrics['cpu_usage_percent']:.1f}%") - - if metrics.get('memory_usage_percent', 0) > self.config.memory_usage_threshold: + + if ( + metrics.get("memory_usage_percent", 0) + > self.config.memory_usage_threshold + ): status = "degraded" - issues.append(f"High memory usage: {metrics['memory_usage_percent']:.1f}%") - - if metrics.get('disk_usage_percent', 0) > 90: + issues.append( + f"High memory usage: {metrics['memory_usage_percent']:.1f}%" + ) + + if metrics.get("disk_usage_percent", 0) > 90: status = "unhealthy" - issues.append(f"Critical disk usage: {metrics['disk_usage_percent']:.1f}%") - + issues.append( + f"Critical disk usage: {metrics['disk_usage_percent']:.1f}%" + ) + return HealthCheckResult( service="system", status=status, response_time_ms=response_time, timestamp=datetime.now(timezone.utc), - details={ - 'metrics': metrics, - 'issues': issues - } + details={"metrics": metrics, "issues": issues}, ) - + except Exception as e: return HealthCheckResult( service="system", status="unhealthy", response_time_ms=(time.time() - start_time) * 1000, timestamp=datetime.now(timezone.utc), - error_message=str(e) + error_message=str(e), ) class ArcadeAPIMonitor: """Monitors Arcade.dev API performance and health.""" - + def __init__(self, config: MonitoringConfig): self.config = config self.logger = logging.getLogger(f"{__name__}.ArcadeAPIMonitor") - self.api_key = os.getenv('ARCADE_API_KEY') - self.demo_mode = self.api_key is None or self.api_key == '' or self.api_key == 'demo_key' - + self.api_key = os.getenv("ARCADE_API_KEY") + self.demo_mode = ( + self.api_key is None or self.api_key == "" or self.api_key == "demo_key" + ) + if self.demo_mode: self.logger.info("Running in demo mode - API calls will be simulated") else: self.logger.info("Running with real API credentials") - + async def check_api_health(self) -> HealthCheckResult: """Check Arcade.dev API health.""" start_time = time.time() - + try: if self.demo_mode: # Demo mode: simulate API health check await asyncio.sleep(0.1) # Simulate network call response_time = (time.time() - start_time) * 1000 - + # Mock API response analysis status = "healthy" if response_time > self.config.response_time_threshold_ms: status = "degraded" - + return HealthCheckResult( service="arcade_api", status=status, response_time_ms=response_time, timestamp=datetime.now(timezone.utc), details={ - 'endpoint': 'https://api.arcade.dev/health (simulated)', - 'api_version': 'v1', - 'demo_mode': True, - 'response_time_threshold_ms': self.config.response_time_threshold_ms - } + "endpoint": "https://api.arcade.dev/health (simulated)", + "api_version": "v1", + "demo_mode": True, + "response_time_threshold_ms": self.config.response_time_threshold_ms, + }, ) else: # Real API mode: would make actual API call here # For now, simulate since we don't have real API implementation await asyncio.sleep(0.1) # Simulate network call response_time = (time.time() - start_time) * 1000 - + status = "healthy" if response_time > self.config.response_time_threshold_ms: status = "degraded" - + return HealthCheckResult( service="arcade_api", status=status, response_time_ms=response_time, timestamp=datetime.now(timezone.utc), details={ - 'endpoint': 'https://api.arcade.dev/health', - 'api_version': 'v1', - 'demo_mode': False, - 'response_time_threshold_ms': self.config.response_time_threshold_ms - } + "endpoint": "https://api.arcade.dev/health", + "api_version": "v1", + "demo_mode": False, + "response_time_threshold_ms": self.config.response_time_threshold_ms, + }, ) - + except Exception as e: return HealthCheckResult( service="arcade_api", status="unhealthy", response_time_ms=(time.time() - start_time) * 1000, timestamp=datetime.now(timezone.utc), - error_message=str(e) + error_message=str(e), ) - + async def test_api_operations(self) -> Dict[str, HealthCheckResult]: """Test various API operations for monitoring.""" operations = { - 'code_analysis': self._test_code_analysis, - 'test_generation': self._test_generation, - 'documentation': self._test_documentation + "code_analysis": self._test_code_analysis, + "test_generation": self._test_generation, + "documentation": self._test_documentation, } - + results = {} for operation_name, test_func in operations.items(): try: @@ -265,106 +277,112 @@ async def test_api_operations(self) -> Dict[str, HealthCheckResult]: status="unhealthy", response_time_ms=0, timestamp=datetime.now(timezone.utc), - error_message=str(e) + error_message=str(e), ) - + return results - + async def _test_code_analysis(self) -> HealthCheckResult: """Test code analysis endpoint.""" start_time = time.time() - + # Simulate API call (demo or real mode) await asyncio.sleep(0.2) - + response_time = (time.time() - start_time) * 1000 - + return HealthCheckResult( service="arcade_api_code_analysis", status="healthy" if response_time < 1000 else "degraded", response_time_ms=response_time, timestamp=datetime.now(timezone.utc), - details={ - 'test_code_lines': 10, - 'demo_mode': self.demo_mode - } + details={"test_code_lines": 10, "demo_mode": self.demo_mode}, ) - + async def _test_generation(self) -> HealthCheckResult: """Test test generation endpoint.""" start_time = time.time() - + # Simulate API call (demo or real mode) await asyncio.sleep(0.3) - + response_time = (time.time() - start_time) * 1000 - + return HealthCheckResult( service="arcade_api_test_generation", status="healthy" if response_time < 1500 else "degraded", response_time_ms=response_time, timestamp=datetime.now(timezone.utc), - details={ - 'generated_tests': 5, - 'demo_mode': self.demo_mode - } + details={"generated_tests": 5, "demo_mode": self.demo_mode}, ) - + async def _test_documentation(self) -> HealthCheckResult: """Test documentation endpoint.""" start_time = time.time() - + # Simulate API call (demo or real mode) await asyncio.sleep(0.15) - + response_time = (time.time() - start_time) * 1000 - + return HealthCheckResult( service="arcade_api_documentation", status="healthy" if response_time < 800 else "degraded", response_time_ms=response_time, timestamp=datetime.now(timezone.utc), - details={ - 'doc_sections': 4, - 'demo_mode': self.demo_mode - } + details={"doc_sections": 4, "demo_mode": self.demo_mode}, ) class AlertManager: """Manages alerts and notifications.""" - + def __init__(self, config: MonitoringConfig): self.config = config self.logger = logging.getLogger(f"{__name__}.AlertManager") self.active_alerts: Dict[str, Alert] = {} self.alert_history: List[Alert] = [] self.alert_callbacks: List[Callable[[Alert], None]] = [] - self.suppressed_alerts: Dict[str, datetime] = {} # Track alerts to prevent duplicates - + self.suppressed_alerts: Dict[str, datetime] = ( + {} + ) # Track alerts to prevent duplicates + def register_alert_callback(self, callback: Callable[[Alert], None]): """Register a callback for alert notifications.""" self.alert_callbacks.append(callback) - - def create_alert(self, severity: str, title: str, message: str, service: str, - metric_name: str = None, current_value: float = None, - threshold: float = None) -> Optional[Alert]: + + def create_alert( + self, + severity: str, + title: str, + message: str, + service: str, + metric_name: str = None, + current_value: float = None, + threshold: float = None, + ) -> Optional[Alert]: """Create a new alert, with duplicate suppression.""" # Create unique key for this type of alert - alert_key = f"{service}_{metric_name}_{title}" if metric_name else f"{service}_{title}" - + alert_key = ( + f"{service}_{metric_name}_{title}" if metric_name else f"{service}_{title}" + ) + # Check if we've recently created this type of alert (suppress duplicates for 5 minutes) now = datetime.now(timezone.utc) if alert_key in self.suppressed_alerts: time_since_last = now - self.suppressed_alerts[alert_key] if time_since_last < timedelta(minutes=5): return None # Suppress duplicate alert - + # Update suppression tracker self.suppressed_alerts[alert_key] = now - - alert_id = f"{service}_{metric_name}_{int(time.time())}" if metric_name else f"{service}_{int(time.time())}" - + + alert_id = ( + f"{service}_{metric_name}_{int(time.time())}" + if metric_name + else f"{service}_{int(time.time())}" + ) + alert = Alert( alert_id=alert_id, severity=severity, @@ -374,41 +392,41 @@ def create_alert(self, severity: str, title: str, message: str, service: str, service=service, metric_name=metric_name, current_value=current_value, - threshold=threshold + threshold=threshold, ) - + self.active_alerts[alert_id] = alert self.alert_history.append(alert) - + # Trigger callbacks for callback in self.alert_callbacks: try: callback(alert) except Exception as e: self.logger.error(f"Alert callback failed: {e}") - + self.logger.warning(f"Alert created: {alert.title} - {alert.message}") return alert - + def resolve_alert(self, alert_id: str) -> bool: """Resolve an active alert.""" if alert_id in self.active_alerts: alert = self.active_alerts[alert_id] alert.resolved = True del self.active_alerts[alert_id] - + self.logger.info(f"Alert resolved: {alert.title}") return True - + return False - + def check_metric_thresholds(self, metrics: Dict[str, float], service: str): """Check metrics against thresholds and create alerts if needed.""" if not self.config.enable_alerting: return - + # CPU usage alert - cpu_usage = metrics.get('cpu_usage_percent', 0) + cpu_usage = metrics.get("cpu_usage_percent", 0) if cpu_usage > self.config.cpu_usage_threshold: self.create_alert( severity="warning", @@ -417,11 +435,11 @@ def check_metric_thresholds(self, metrics: Dict[str, float], service: str): service=service, metric_name="cpu_usage_percent", current_value=cpu_usage, - threshold=self.config.cpu_usage_threshold + threshold=self.config.cpu_usage_threshold, ) - + # Memory usage alert - memory_usage = metrics.get('memory_usage_percent', 0) + memory_usage = metrics.get("memory_usage_percent", 0) if memory_usage > self.config.memory_usage_threshold: self.create_alert( severity="warning", @@ -430,89 +448,95 @@ def check_metric_thresholds(self, metrics: Dict[str, float], service: str): service=service, metric_name="memory_usage_percent", current_value=memory_usage, - threshold=self.config.memory_usage_threshold + threshold=self.config.memory_usage_threshold, ) - + def get_active_alerts(self) -> List[Alert]: """Get all active alerts.""" return list(self.active_alerts.values()) - + def get_alert_summary(self) -> Dict[str, Any]: """Get alert summary statistics.""" now = datetime.now(timezone.utc) last_hour = now - timedelta(hours=1) last_day = now - timedelta(days=1) - + recent_alerts = [a for a in self.alert_history if a.timestamp > last_hour] daily_alerts = [a for a in self.alert_history if a.timestamp > last_day] - + return { - 'active_alerts': len(self.active_alerts), - 'alerts_last_hour': len(recent_alerts), - 'alerts_last_day': len(daily_alerts), - 'total_alerts': len(self.alert_history), - 'severity_breakdown': { - 'critical': len([a for a in recent_alerts if a.severity == 'critical']), - 'warning': len([a for a in recent_alerts if a.severity == 'warning']), - 'info': len([a for a in recent_alerts if a.severity == 'info']) - } + "active_alerts": len(self.active_alerts), + "alerts_last_hour": len(recent_alerts), + "alerts_last_day": len(daily_alerts), + "total_alerts": len(self.alert_history), + "severity_breakdown": { + "critical": len([a for a in recent_alerts if a.severity == "critical"]), + "warning": len([a for a in recent_alerts if a.severity == "warning"]), + "info": len([a for a in recent_alerts if a.severity == "info"]), + }, } class TelemetryCollector: """Collects and exports telemetry data.""" - + def __init__(self, config: MonitoringConfig): self.config = config self.logger = logging.getLogger(f"{__name__}.TelemetryCollector") self.metrics_buffer: List[PerformanceMetric] = [] self.buffer_lock = threading.Lock() - - def record_metric(self, metric_name: str, value: float, tags: Dict[str, str] = None): + + def record_metric( + self, metric_name: str, value: float, tags: Dict[str, str] = None + ): """Record a performance metric.""" if not self.config.enable_telemetry: return - + metric = PerformanceMetric( metric_name=metric_name, value=value, timestamp=datetime.now(timezone.utc), - tags=tags or {} + tags=tags or {}, ) - + with self.buffer_lock: self.metrics_buffer.append(metric) - + # Keep buffer size manageable if len(self.metrics_buffer) > 1000: self.metrics_buffer = self.metrics_buffer[-500:] - - def get_metrics(self, metric_name: str = None, last_minutes: int = 10) -> List[PerformanceMetric]: + + def get_metrics( + self, metric_name: str = None, last_minutes: int = 10 + ) -> List[PerformanceMetric]: """Get metrics from buffer.""" cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=last_minutes) - + with self.buffer_lock: filtered_metrics = [ - m for m in self.metrics_buffer - if m.timestamp > cutoff_time and (metric_name is None or m.metric_name == metric_name) + m + for m in self.metrics_buffer + if m.timestamp > cutoff_time + and (metric_name is None or m.metric_name == metric_name) ] - + return filtered_metrics - - def export_metrics(self, format_type: str = 'json') -> str: + + def export_metrics(self, format_type: str = "json") -> str: """Export metrics in specified format.""" with self.buffer_lock: metrics_data = [ { - 'metric_name': m.metric_name, - 'value': m.value, - 'timestamp': m.timestamp.isoformat(), - 'tags': m.tags + "metric_name": m.metric_name, + "value": m.value, + "timestamp": m.timestamp.isoformat(), + "tags": m.tags, } for m in self.metrics_buffer ] - - if format_type == 'json': + + if format_type == "json": return json.dumps(metrics_data, indent=2) else: return str(metrics_data) @@ -520,72 +544,72 @@ def export_metrics(self, format_type: str = 'json') -> str: class ArcadeMonitoringDashboard: """Main monitoring dashboard orchestrating all monitoring components.""" - + def __init__(self, config: MonitoringConfig = None): self.config = config or MonitoringConfig() self.logger = logging.getLogger(__name__) - + # Initialize components self.system_monitor = SystemMonitor(self.config) self.api_monitor = ArcadeAPIMonitor(self.config) self.alert_manager = AlertManager(self.config) self.telemetry = TelemetryCollector(self.config) self.metrics_collector = MetricsCollector() - + # Register alert callback self.alert_manager.register_alert_callback(self._handle_alert) - + # Monitoring state self.monitoring_active = False self.monitoring_task: Optional[asyncio.Task] = None - + def _handle_alert(self, alert: Alert): """Handle alert notifications.""" # Log alert alert_data = { - 'alert_id': alert.alert_id, - 'severity': alert.severity, - 'title': alert.title, - 'message': alert.message, - 'service': alert.service, - 'timestamp': alert.timestamp.isoformat() + "alert_id": alert.alert_id, + "severity": alert.severity, + "title": alert.title, + "message": alert.message, + "service": alert.service, + "timestamp": alert.timestamp.isoformat(), } - + # Write to alerts log alerts_log = Path(self.config.alerts_log_file) alerts_log.parent.mkdir(exist_ok=True) - - with open(alerts_log, 'a') as f: - f.write(json.dumps(alert_data) + '\n') - + + with open(alerts_log, "a") as f: + f.write(json.dumps(alert_data) + "\n") + # In production, send to notification system (email, Slack, etc.) print(f"🚨 ALERT [{alert.severity.upper()}]: {alert.title}") print(f" Service: {alert.service}") print(f" Message: {alert.message}") - + async def start_monitoring(self): """Start continuous monitoring.""" if self.monitoring_active: self.logger.warning("Monitoring already active") return - + self.monitoring_active = True self.monitoring_task = asyncio.create_task(self._monitoring_loop()) self.logger.info("Monitoring started") - + async def stop_monitoring(self): """Stop monitoring.""" self.monitoring_active = False - + if self.monitoring_task: self.monitoring_task.cancel() try: await self.monitoring_task except asyncio.CancelledError: pass - + self.logger.info("Monitoring stopped") - + async def _monitoring_loop(self): """Main monitoring loop.""" while self.monitoring_active: @@ -593,35 +617,34 @@ async def _monitoring_loop(self): # System health check system_health = self.system_monitor.check_system_health() self.telemetry.record_metric( - 'system_health_response_time_ms', + "system_health_response_time_ms", system_health.response_time_ms, - {'service': 'system', 'status': system_health.status} + {"service": "system", "status": system_health.status}, ) - + # Check system metrics against thresholds - if system_health.details and 'metrics' in system_health.details: + if system_health.details and "metrics" in system_health.details: self.alert_manager.check_metric_thresholds( - system_health.details['metrics'], - 'system' + system_health.details["metrics"], "system" ) - + # API health checks api_health = await self.api_monitor.check_api_health() self.telemetry.record_metric( - 'api_health_response_time_ms', + "api_health_response_time_ms", api_health.response_time_ms, - {'service': 'arcade_api', 'status': api_health.status} + {"service": "arcade_api", "status": api_health.status}, ) - + # Test API operations operation_results = await self.api_monitor.test_api_operations() for operation, result in operation_results.items(): self.telemetry.record_metric( - f'api_operation_response_time_ms', + f"api_operation_response_time_ms", result.response_time_ms, - {'operation': operation, 'status': result.status} + {"operation": operation, "status": result.status}, ) - + # Check for slow operations if result.response_time_ms > self.config.response_time_threshold_ms: self.alert_manager.create_alert( @@ -631,100 +654,114 @@ async def _monitoring_loop(self): service="arcade_api", metric_name="response_time_ms", current_value=result.response_time_ms, - threshold=self.config.response_time_threshold_ms + threshold=self.config.response_time_threshold_ms, ) - + await asyncio.sleep(self.config.health_check_interval_seconds) - + except Exception as e: self.logger.error(f"Monitoring loop error: {e}") await asyncio.sleep(5) # Wait before retrying - + def get_dashboard_data(self) -> Dict[str, Any]: """Get comprehensive dashboard data.""" # System metrics system_metrics = self.system_monitor.get_system_metrics() - + # Recent performance metrics recent_metrics = { - 'system_response_times': [ - m.value for m in self.telemetry.get_metrics('system_health_response_time_ms', 10) + "system_response_times": [ + m.value + for m in self.telemetry.get_metrics( + "system_health_response_time_ms", 10 + ) + ], + "api_response_times": [ + m.value + for m in self.telemetry.get_metrics("api_health_response_time_ms", 10) ], - 'api_response_times': [ - m.value for m in self.telemetry.get_metrics('api_health_response_time_ms', 10) - ] } - + # Calculate averages - if recent_metrics['system_response_times']: - recent_metrics['avg_system_response_time'] = statistics.mean(recent_metrics['system_response_times']) + if recent_metrics["system_response_times"]: + recent_metrics["avg_system_response_time"] = statistics.mean( + recent_metrics["system_response_times"] + ) else: - recent_metrics['avg_system_response_time'] = 0 - - if recent_metrics['api_response_times']: - recent_metrics['avg_api_response_time'] = statistics.mean(recent_metrics['api_response_times']) + recent_metrics["avg_system_response_time"] = 0 + + if recent_metrics["api_response_times"]: + recent_metrics["avg_api_response_time"] = statistics.mean( + recent_metrics["api_response_times"] + ) else: - recent_metrics['avg_api_response_time'] = 0 - + recent_metrics["avg_api_response_time"] = 0 + return { - 'timestamp': datetime.now(timezone.utc).isoformat(), - 'monitoring_active': self.monitoring_active, - 'system_metrics': system_metrics, - 'performance_metrics': recent_metrics, - 'alerts': self.alert_manager.get_alert_summary(), - 'active_alerts': [ + "timestamp": datetime.now(timezone.utc).isoformat(), + "monitoring_active": self.monitoring_active, + "system_metrics": system_metrics, + "performance_metrics": recent_metrics, + "alerts": self.alert_manager.get_alert_summary(), + "active_alerts": [ { - 'id': alert.alert_id, - 'severity': alert.severity, - 'title': alert.title, - 'service': alert.service, - 'timestamp': alert.timestamp.isoformat() + "id": alert.alert_id, + "severity": alert.severity, + "title": alert.title, + "service": alert.service, + "timestamp": alert.timestamp.isoformat(), } for alert in self.alert_manager.get_active_alerts() - ] + ], } - + def print_dashboard(self): """Print monitoring dashboard to console.""" dashboard_data = self.get_dashboard_data() - + print("\n" + "=" * 60) print("šŸ–„ļø ARCADE.DEV MONITORING DASHBOARD") print("=" * 60) - + # Demo mode indicator - if hasattr(self, 'api_monitor') and self.api_monitor.demo_mode: + if hasattr(self, "api_monitor") and self.api_monitor.demo_mode: print("šŸ”§ Running in DEMO MODE (simulated API calls)") - + # System Status print(f"\nšŸ“Š System Metrics:") - system_metrics = dashboard_data['system_metrics'] + system_metrics = dashboard_data["system_metrics"] if system_metrics: print(f" CPU Usage: {system_metrics.get('cpu_usage_percent', 0):.1f}%") - print(f" Memory Usage: {system_metrics.get('memory_usage_percent', 0):.1f}%") + print( + f" Memory Usage: {system_metrics.get('memory_usage_percent', 0):.1f}%" + ) print(f" Disk Usage: {system_metrics.get('disk_usage_percent', 0):.1f}%") print(f" Load Average: {system_metrics.get('load_average', 0):.2f}") - + # Performance print(f"\n⚔ Performance:") - perf_metrics = dashboard_data['performance_metrics'] - print(f" Avg System Response: {perf_metrics['avg_system_response_time']:.1f}ms") + perf_metrics = dashboard_data["performance_metrics"] + print( + f" Avg System Response: {perf_metrics['avg_system_response_time']:.1f}ms" + ) print(f" Avg API Response: {perf_metrics['avg_api_response_time']:.1f}ms") - + # Alerts print(f"\n🚨 Alerts:") - alert_summary = dashboard_data['alerts'] + alert_summary = dashboard_data["alerts"] print(f" Active: {alert_summary['active_alerts']}") print(f" Last Hour: {alert_summary['alerts_last_hour']}") print(f" Last Day: {alert_summary['alerts_last_day']}") - + # Active Alerts - active_alerts = dashboard_data['active_alerts'] + active_alerts = dashboard_data["active_alerts"] if active_alerts: print(f"\nāš ļø Active Alerts:") for alert in active_alerts[:5]: # Show first 5 - print(f" [{alert['severity'].upper()}] {alert['title']} ({alert['service']})") - + print( + f" [{alert['severity'].upper()}] {alert['title']} ({alert['service']})" + ) + print("\n" + "=" * 60) @@ -732,51 +769,57 @@ async def demonstrate_monitoring(): """Demonstrate monitoring capabilities.""" print("šŸ“Š Arcade.dev Monitoring Demo") print("=" * 50) - + try: # Create monitoring dashboard config = MonitoringConfig( health_check_interval_seconds=5, # Faster for demo cpu_usage_threshold=50.0, # Lower threshold for demo alerts - memory_usage_threshold=60.0 + memory_usage_threshold=60.0, ) - + dashboard = ArcadeMonitoringDashboard(config) - + # Check demo mode status - demo_status = "šŸ”§ DEMO MODE" if dashboard.api_monitor.demo_mode else "🌐 LIVE MODE" + demo_status = ( + "šŸ”§ DEMO MODE" if dashboard.api_monitor.demo_mode else "🌐 LIVE MODE" + ) print(f"\nMode: {demo_status}") - print(f"API Key Available: {'āŒ No' if dashboard.api_monitor.demo_mode else 'āœ… Yes'}") - + print( + f"API Key Available: {'āŒ No' if dashboard.api_monitor.demo_mode else 'āœ… Yes'}" + ) + print("\nšŸš€ Starting monitoring...") await dashboard.start_monitoring() - + # Let it run for a bit and collect data print("ā³ Collecting metrics for 30 seconds...") - + for i in range(6): # 6 iterations of 5 seconds each await asyncio.sleep(5) dashboard.print_dashboard() - + # Simulate some load to trigger alerts if i == 2: print("\nšŸ”„ Simulating high load...") # This would trigger alerts in a real scenario - + print("\nšŸ›‘ Stopping monitoring...") await dashboard.stop_monitoring() - + # Final summary print("\nšŸ“‹ Final Monitoring Summary:") dashboard.print_dashboard() - + # Export metrics for review metrics_export = dashboard.telemetry.export_metrics() - print(f"\nšŸ“„ Exported {len(dashboard.telemetry.metrics_buffer)} metrics to JSON") - + print( + f"\nšŸ“„ Exported {len(dashboard.telemetry.metrics_buffer)} metrics to JSON" + ) + print("\nšŸŽ‰ Monitoring demonstration completed!") return dashboard - + except Exception as e: print(f"\nāŒ Error during monitoring demo: {e}") raise @@ -786,9 +829,9 @@ async def main(): """Main demonstration function.""" logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + try: await demonstrate_monitoring() return 0 @@ -802,4 +845,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/interactive_demo.py b/examples/arcade-dev/interactive_demo.py index 821becb..87a0eaf 100644 --- a/examples/arcade-dev/interactive_demo.py +++ b/examples/arcade-dev/interactive_demo.py @@ -25,6 +25,7 @@ from rich.table import Table from rich.progress import Progress, SpinnerColumn, TextColumn from rich.syntax import Syntax + HAS_RICH = True except ImportError: HAS_RICH = False @@ -33,6 +34,7 @@ @dataclass class DemoExample: """Represents a demo example.""" + name: str description: str script_path: str @@ -42,12 +44,12 @@ class DemoExample: class InteractiveDemo: """Interactive demonstration manager.""" - + def __init__(self): self.console = Console() if HAS_RICH else None self.examples = self._load_examples() self.logger = logging.getLogger(__name__) - + def _load_examples(self) -> List[DemoExample]: """Load available examples.""" return [ @@ -56,77 +58,83 @@ def _load_examples(self) -> List[DemoExample]: description="Fundamental Arcade.dev API client integration", script_path="01_basic_integration/basic_arcade_client.py", requirements=["ARCADE_API_KEY"], - difficulty="beginner" + difficulty="beginner", ), DemoExample( name="Code Analysis", description="Analyze code quality and structure", script_path="02_code_analysis/code_analyzer.py", requirements=["ARCADE_API_KEY"], - difficulty="beginner" + difficulty="beginner", ), DemoExample( name="Test Generation", description="Generate comprehensive test cases using AI", script_path="03_automated_testing/test_generator.py", requirements=["ARCADE_API_KEY"], - difficulty="intermediate" + difficulty="intermediate", ), DemoExample( name="Documentation Generation", description="Automatically generate documentation from code", script_path="04_documentation/doc_generator.py", requirements=["ARCADE_API_KEY"], - difficulty="intermediate" + difficulty="intermediate", ), DemoExample( name="Cache Integration", description="Integrate Arcade.dev with FACT's caching system", script_path="07_cache_integration/cached_arcade_client.py", requirements=["ARCADE_API_KEY", "FACT_CACHE_ENABLED"], - difficulty="advanced" - ) + difficulty="advanced", + ), ] - + def _print_banner(self): """Print welcome banner.""" if self.console: - self.console.print(Panel.fit( - "[bold blue]šŸŽ® Arcade.dev Integration Examples[/bold blue]\n" - "[dim]Interactive demonstration of FACT + Arcade.dev integration[/dim]", - border_style="blue" - )) + self.console.print( + Panel.fit( + "[bold blue]šŸŽ® Arcade.dev Integration Examples[/bold blue]\n" + "[dim]Interactive demonstration of FACT + Arcade.dev integration[/dim]", + border_style="blue", + ) + ) else: print("=" * 60) print("šŸŽ® Arcade.dev Integration Examples") print("Interactive demonstration of FACT + Arcade.dev integration") print("=" * 60) - + def _check_prerequisites(self) -> bool: """Check if prerequisites are met.""" missing_vars = [] required_vars = ["ARCADE_API_KEY"] - + for var in required_vars: if not os.getenv(var): missing_vars.append(var) - + if missing_vars: if self.console: - self.console.print(Panel( - f"[red]Missing required environment variables:[/red]\n" - f"{', '.join(missing_vars)}\n\n" - f"[dim]Please copy .env.example to .env and configure your settings.[/dim]", - title="Prerequisites Check Failed", - border_style="red" - )) + self.console.print( + Panel( + f"[red]Missing required environment variables:[/red]\n" + f"{', '.join(missing_vars)}\n\n" + f"[dim]Please copy .env.example to .env and configure your settings.[/dim]", + title="Prerequisites Check Failed", + border_style="red", + ) + ) else: - print(f"āŒ Missing required environment variables: {', '.join(missing_vars)}") + print( + f"āŒ Missing required environment variables: {', '.join(missing_vars)}" + ) print("Please copy .env.example to .env and configure your settings.") return False - + return True - + def _display_examples_table(self): """Display available examples in a table.""" if self.console: @@ -136,12 +144,12 @@ def _display_examples_table(self): table.add_column("Description") table.add_column("Difficulty", justify="center") table.add_column("Status", justify="center") - + for i, example in enumerate(self.examples, 1): # Check if example script exists script_path = Path(__file__).parent / example.script_path status = "āœ…" if script_path.exists() else "🚧" - + # Color code difficulty if example.difficulty == "beginner": difficulty = f"[green]{example.difficulty}[/green]" @@ -149,15 +157,11 @@ def _display_examples_table(self): difficulty = f"[yellow]{example.difficulty}[/yellow]" else: difficulty = f"[red]{example.difficulty}[/red]" - + table.add_row( - str(i), - example.name, - example.description, - difficulty, - status + str(i), example.name, example.description, difficulty, status ) - + self.console.print(table) else: print("\nAvailable Examples:") @@ -167,109 +171,119 @@ def _display_examples_table(self): status = "āœ…" if script_path.exists() else "🚧" print(f"{i:2d}. {example.name} ({example.difficulty}) {status}") print(f" {example.description}") - + def _get_user_choice(self) -> Optional[int]: """Get user's example choice.""" if self.console: choice = Prompt.ask( "\n[bold]Select an example to run[/bold]", choices=[str(i) for i in range(1, len(self.examples) + 1)] + ["q"], - default="q" + default="q", ) else: - choice = input(f"\nSelect an example (1-{len(self.examples)}) or 'q' to quit: ") - - if choice.lower() == 'q': + choice = input( + f"\nSelect an example (1-{len(self.examples)}) or 'q' to quit: " + ) + + if choice.lower() == "q": return None - + try: return int(choice) except ValueError: return None - + async def _run_example(self, example: DemoExample): """Run a selected example.""" script_path = Path(__file__).parent / example.script_path - + if not script_path.exists(): if self.console: - self.console.print(f"[red]āŒ Example script not found: {example.script_path}[/red]") + self.console.print( + f"[red]āŒ Example script not found: {example.script_path}[/red]" + ) else: print(f"āŒ Example script not found: {example.script_path}") return - + # Check example-specific requirements missing_reqs = [] for req in example.requirements: if not os.getenv(req): missing_reqs.append(req) - + if missing_reqs: if self.console: - self.console.print(f"[yellow]āš ļø Missing requirements for this example: {', '.join(missing_reqs)}[/yellow]") + self.console.print( + f"[yellow]āš ļø Missing requirements for this example: {', '.join(missing_reqs)}[/yellow]" + ) if not Confirm.ask("Continue anyway?"): return else: print(f"āš ļø Missing requirements: {', '.join(missing_reqs)}") response = input("Continue anyway? (y/N): ") - if response.lower() != 'y': + if response.lower() != "y": return - + # Display example info if self.console: - self.console.print(Panel( - f"[bold]{example.name}[/bold]\n" - f"{example.description}\n\n" - f"[dim]Script: {example.script_path}[/dim]\n" - f"[dim]Difficulty: {example.difficulty}[/dim]", - title="Running Example", - border_style="green" - )) + self.console.print( + Panel( + f"[bold]{example.name}[/bold]\n" + f"{example.description}\n\n" + f"[dim]Script: {example.script_path}[/dim]\n" + f"[dim]Difficulty: {example.difficulty}[/dim]", + title="Running Example", + border_style="green", + ) + ) else: print(f"\nšŸš€ Running: {example.name}") print(f"Description: {example.description}") print(f"Script: {example.script_path}") - + # Run the example try: if self.console: with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - console=self.console + console=self.console, ) as progress: task = progress.add_task("Executing example...", total=None) - + # Execute the script process = await asyncio.create_subprocess_exec( - sys.executable, str(script_path), + sys.executable, + str(script_path), stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) - + stdout, stderr = await process.communicate() - + progress.remove_task(task) - + else: print("Executing example...") process = await asyncio.create_subprocess_exec( - sys.executable, str(script_path), + sys.executable, + str(script_path), stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await process.communicate() - + # Display results if process.returncode == 0: if self.console: - self.console.print("[green]āœ… Example completed successfully![/green]") + self.console.print( + "[green]āœ… Example completed successfully![/green]" + ) if stdout: - self.console.print(Panel( - stdout.decode(), - title="Output", - border_style="green" - )) + self.console.print( + Panel(stdout.decode(), title="Output", border_style="green") + ) else: print("āœ… Example completed successfully!") if stdout: @@ -280,59 +294,63 @@ async def _run_example(self, example: DemoExample): if self.console: self.console.print("[red]āŒ Example failed![/red]") if stderr: - self.console.print(Panel( - stderr.decode(), - title="Error Output", - border_style="red" - )) + self.console.print( + Panel( + stderr.decode(), + title="Error Output", + border_style="red", + ) + ) else: print("āŒ Example failed!") if stderr: print("\nError Output:") print("-" * 40) print(stderr.decode()) - + except Exception as e: if self.console: self.console.print(f"[red]āŒ Error running example: {e}[/red]") else: print(f"āŒ Error running example: {e}") - + async def run(self): """Run the interactive demo.""" self._print_banner() - + if not self._check_prerequisites(): return 1 - + while True: self._display_examples_table() choice = self._get_user_choice() - + if choice is None: if self.console: - self.console.print("\n[blue]šŸ‘‹ Thanks for exploring Arcade.dev examples![/blue]") + self.console.print( + "\n[blue]šŸ‘‹ Thanks for exploring Arcade.dev examples![/blue]" + ) else: print("\nšŸ‘‹ Thanks for exploring Arcade.dev examples!") break - + if 1 <= choice <= len(self.examples): example = self.examples[choice - 1] await self._run_example(example) - + if self.console: if not Confirm.ask("\nRun another example?"): break else: response = input("\nRun another example? (y/N): ") - if response.lower() != 'y': + if response.lower() != "y": break else: if self.console: self.console.print("[red]Invalid choice. Please try again.[/red]") else: print("Invalid choice. Please try again.") - + return 0 @@ -341,14 +359,14 @@ async def main(): # Setup basic logging logging.basicConfig( level=logging.WARNING, # Reduce noise in interactive mode - format='%(levelname)s: %(message)s' + format="%(levelname)s: %(message)s", ) - + # Check for rich dependency if not HAS_RICH: print("āš ļø For the best experience, install 'rich': pip install rich") print("Continuing with basic text interface...\n") - + demo = InteractiveDemo() return await demo.run() @@ -359,4 +377,4 @@ async def main(): sys.exit(exit_code) except KeyboardInterrupt: print("\nšŸ‘‹ Demo interrupted by user") - sys.exit(0) \ No newline at end of file + sys.exit(0) diff --git a/examples/arcade-dev/requirements.txt b/examples/arcade-dev/requirements.txt index 3f35774..0daac8e 100644 --- a/examples/arcade-dev/requirements.txt +++ b/examples/arcade-dev/requirements.txt @@ -1,7 +1,7 @@ # Arcade.dev Integration Examples - Python Dependencies # Core HTTP and async libraries -aiohttp>=3.8.0 +aiohttp>=3.13.4 asyncio-throttle>=1.0.2 httpx>=0.24.0 diff --git a/examples/arcade-dev/run_all_examples.py b/examples/arcade-dev/run_all_examples.py index 9f827ff..0b75a6d 100644 --- a/examples/arcade-dev/run_all_examples.py +++ b/examples/arcade-dev/run_all_examples.py @@ -21,227 +21,230 @@ class ExampleRunner: """Manages execution of all Arcade.dev examples.""" - + def __init__(self): self.base_dir = Path(__file__).parent self.examples = self._discover_examples() self.results: List[Dict[str, Any]] = [] - + def _discover_examples(self) -> List[Dict[str, str]]: """Discover all runnable examples.""" examples = [] - + # Known examples known_examples = [ { - 'name': 'Basic Integration', - 'script': '01_basic_integration/basic_arcade_client.py', - 'description': 'Basic API client integration' + "name": "Basic Integration", + "script": "01_basic_integration/basic_arcade_client.py", + "description": "Basic API client integration", }, { - 'name': 'Tool Registration', - 'script': '02_tool_registration/register_fact_tools.py', - 'description': 'Register FACT tools with Arcade.dev' + "name": "Tool Registration", + "script": "02_tool_registration/register_fact_tools.py", + "description": "Register FACT tools with Arcade.dev", }, { - 'name': 'Intelligent Routing', - 'script': '03_intelligent_routing/hybrid_execution.py', - 'description': 'Hybrid execution with intelligent routing' + "name": "Intelligent Routing", + "script": "03_intelligent_routing/hybrid_execution.py", + "description": "Hybrid execution with intelligent routing", }, { - 'name': 'Error Handling', - 'script': '04_error_handling/resilient_execution.py', - 'description': 'Resilient execution with error handling' + "name": "Error Handling", + "script": "04_error_handling/resilient_execution.py", + "description": "Resilient execution with error handling", }, { - 'name': 'Cache Integration (Enhanced)', - 'script': '05_cache_integration/cached_arcade_client_enhanced.py', - 'description': 'Enhanced cached API client' + "name": "Cache Integration (Enhanced)", + "script": "05_cache_integration/cached_arcade_client_enhanced.py", + "description": "Enhanced cached API client", }, { - 'name': 'Security', - 'script': '06_security/secure_tool_execution.py', - 'description': 'Secure tool execution with validation' + "name": "Security", + "script": "06_security/secure_tool_execution.py", + "description": "Secure tool execution with validation", }, { - 'name': 'Cache Integration', - 'script': '07_cache_integration/cached_arcade_client.py', - 'description': 'Cached API client with performance optimization' + "name": "Cache Integration", + "script": "07_cache_integration/cached_arcade_client.py", + "description": "Cached API client with performance optimization", }, { - 'name': 'Advanced Tools', - 'script': '08_advanced_tools/advanced_tool_usage.py', - 'description': 'Advanced tool usage patterns' + "name": "Advanced Tools", + "script": "08_advanced_tools/advanced_tool_usage.py", + "description": "Advanced tool usage patterns", }, { - 'name': 'Testing', - 'script': '09_testing/arcade_integration_tests.py', - 'description': 'Integration testing framework' + "name": "Testing", + "script": "09_testing/arcade_integration_tests.py", + "description": "Integration testing framework", }, { - 'name': 'Production Deployment', - 'script': '10_deployment/production_deployment.py', - 'description': 'Production deployment configuration' + "name": "Production Deployment", + "script": "10_deployment/production_deployment.py", + "description": "Production deployment configuration", }, { - 'name': 'Monitoring', - 'script': '12_monitoring/arcade_monitoring.py', - 'description': 'Monitoring and observability' - } + "name": "Monitoring", + "script": "12_monitoring/arcade_monitoring.py", + "description": "Monitoring and observability", + }, ] - + # Filter to only include existing scripts for example in known_examples: - script_path = self.base_dir / example['script'] + script_path = self.base_dir / example["script"] if script_path.exists(): examples.append(example) - + return examples - + async def run_example(self, example: Dict[str, str]) -> Dict[str, Any]: """Run a single example and collect results.""" print(f"\n{'='*60}") print(f"šŸš€ Running: {example['name']}") print(f"šŸ“ Script: {example['script']}") print(f"šŸ“ Description: {example['description']}") - print('='*60) - - script_path = self.base_dir / example['script'] + print("=" * 60) + + script_path = self.base_dir / example["script"] start_time = time.time() - + try: # Run the example process = await asyncio.create_subprocess_exec( - sys.executable, str(script_path), + sys.executable, + str(script_path), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=str(script_path.parent) + cwd=str(script_path.parent), ) - + stdout, stderr = await process.communicate() duration = time.time() - start_time - + result = { - 'name': example['name'], - 'script': example['script'], - 'duration': duration, - 'success': process.returncode == 0, - 'return_code': process.returncode, - 'stdout': stdout.decode() if stdout else '', - 'stderr': stderr.decode() if stderr else '' + "name": example["name"], + "script": example["script"], + "duration": duration, + "success": process.returncode == 0, + "return_code": process.returncode, + "stdout": stdout.decode() if stdout else "", + "stderr": stderr.decode() if stderr else "", } - - if result['success']: + + if result["success"]: print(f"āœ… {example['name']} completed successfully ({duration:.2f}s)") else: - print(f"āŒ {example['name']} failed (return code: {process.returncode})") - + print( + f"āŒ {example['name']} failed (return code: {process.returncode})" + ) + # Show output/errors - if result['stdout']: + if result["stdout"]: print("\nšŸ“„ Output:") print("-" * 40) - print(result['stdout']) - - if result['stderr']: + print(result["stdout"]) + + if result["stderr"]: print("\nāš ļø Errors/Warnings:") print("-" * 40) - print(result['stderr']) - + print(result["stderr"]) + return result - + except Exception as e: duration = time.time() - start_time print(f"āŒ {example['name']} failed with exception: {e}") - + return { - 'name': example['name'], - 'script': example['script'], - 'duration': duration, - 'success': False, - 'return_code': -1, - 'stdout': '', - 'stderr': str(e) + "name": example["name"], + "script": example["script"], + "duration": duration, + "success": False, + "return_code": -1, + "stdout": "", + "stderr": str(e), } - + def print_summary(self): """Print execution summary.""" - print("\n" + "="*80) + print("\n" + "=" * 80) print("šŸ“Š EXECUTION SUMMARY") - print("="*80) - + print("=" * 80) + total_examples = len(self.results) - successful = sum(1 for r in self.results if r['success']) + successful = sum(1 for r in self.results if r["success"]) failed = total_examples - successful - total_duration = sum(r['duration'] for r in self.results) - + total_duration = sum(r["duration"] for r in self.results) + print(f"Total Examples: {total_examples}") print(f"Successful: {successful}") print(f"Failed: {failed}") print(f"Total Duration: {total_duration:.2f}s") print(f"Average Duration: {total_duration/total_examples:.2f}s") - + print("\nšŸ“‹ Individual Results:") print("-" * 80) for result in self.results: - status = "āœ…" if result['success'] else "āŒ" + status = "āœ…" if result["success"] else "āŒ" print(f"{status} {result['name']:<30} {result['duration']:>8.2f}s") - + if failed > 0: print("\nāŒ Failed Examples:") print("-" * 40) for result in self.results: - if not result['success']: + if not result["success"]: print(f"• {result['name']}: {result['stderr'][:100]}...") - - print("\n" + "="*80) - + + print("\n" + "=" * 80) + async def run_all(self): """Run all examples in sequence.""" print("šŸŽ® Running All Arcade.dev Examples") - print("="*50) - + print("=" * 50) + if not self.examples: print("āŒ No runnable examples found!") return False - + print(f"Found {len(self.examples)} examples to run") - + for example in self.examples: result = await self.run_example(example) self.results.append(result) - + # Small delay between examples await asyncio.sleep(1) - + self.print_summary() - + # Return True if all examples succeeded - return all(r['success'] for r in self.results) + return all(r["success"] for r in self.results) async def main(): """Main function.""" print("šŸŽ® Arcade.dev Examples - Batch Runner") - print("="*50) - + print("=" * 50) + # Check prerequisites - if not os.getenv('ARCADE_API_KEY'): + if not os.getenv("ARCADE_API_KEY"): print("āŒ ARCADE_API_KEY environment variable not set") print("Please copy .env.example to .env and configure your API key") return 1 - + runner = ExampleRunner() - + try: success = await runner.run_all() - + if success: print("\nšŸŽ‰ All examples completed successfully!") return 0 else: print("\nāš ļø Some examples failed. Check the summary above.") return 1 - + except KeyboardInterrupt: print("\nāŒ Execution interrupted by user") return 1 @@ -252,4 +255,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/arcade-dev/utils/__init__.py b/examples/arcade-dev/utils/__init__.py index 18623d0..89205d5 100644 --- a/examples/arcade-dev/utils/__init__.py +++ b/examples/arcade-dev/utils/__init__.py @@ -4,11 +4,16 @@ This package provides helper functions and utilities for running FACT arcade examples. """ -from .import_helper import setup_fact_imports, verify_fact_imports, get_fact_module_path, print_fact_module_info +from .import_helper import ( + setup_fact_imports, + verify_fact_imports, + get_fact_module_path, + print_fact_module_info, +) __all__ = [ - 'setup_fact_imports', - 'verify_fact_imports', - 'get_fact_module_path', - 'print_fact_module_info' -] \ No newline at end of file + "setup_fact_imports", + "verify_fact_imports", + "get_fact_module_path", + "print_fact_module_info", +] diff --git a/examples/arcade-dev/utils/import_helper.py b/examples/arcade-dev/utils/import_helper.py index 596b65e..01e4217 100644 --- a/examples/arcade-dev/utils/import_helper.py +++ b/examples/arcade-dev/utils/import_helper.py @@ -14,26 +14,28 @@ # Try to import python-dotenv if available try: from dotenv import load_dotenv + _DOTENV_AVAILABLE = True except ImportError: _DOTENV_AVAILABLE = False + def setup_fact_imports() -> Path: """ Set up Python path to allow importing FACT modules from any example location. - + This function finds the project root (directory containing 'src') and adds it to sys.path if it's not already there. - + Returns: Path to the project root directory - + Raises: ImportError: If the FACT project root cannot be found """ # Start from the current file's directory and work upward current_path = Path(__file__).resolve() - + # Look for the project root by finding the directory containing 'src' for parent in current_path.parents: src_dir = parent / "src" @@ -45,24 +47,24 @@ def setup_fact_imports() -> Path: "Could not find FACT project root. " "Please ensure this script is within the FACT project directory structure." ) - + # Add project root to Python path if not already present project_root_str = str(project_root) if project_root_str not in sys.path: sys.path.insert(0, project_root_str) - + # Load environment variables from root .env file env_file = project_root / ".env" if env_file.exists() and _DOTENV_AVAILABLE: load_dotenv(env_file) - + return project_root def verify_fact_imports() -> bool: """ Verify that essential FACT modules can be imported. - + Returns: True if all essential modules can be imported, False otherwise """ @@ -70,6 +72,7 @@ def verify_fact_imports() -> bool: # Test importing core FACT modules from src.core.driver import FACTDriver from src.cache.manager import CacheManager + return True except ImportError as e: print(f"Failed to import FACT modules: {e}") @@ -79,30 +82,30 @@ def verify_fact_imports() -> bool: def get_fact_module_path(module_name: str) -> Optional[Path]: """ Get the full path to a FACT module. - + Args: module_name: Name of the module (e.g., 'core.driver', 'cache.manager') - + Returns: Path to the module file if found, None otherwise """ try: project_root = setup_fact_imports() - module_parts = module_name.split('.') + module_parts = module_name.split(".") module_path = project_root / "src" - + for part in module_parts: module_path = module_path / part - + # Try both .py file and directory with __init__.py - py_file = module_path.with_suffix('.py') + py_file = module_path.with_suffix(".py") if py_file.exists(): return py_file - + init_file = module_path / "__init__.py" if init_file.exists(): return module_path - + return None except Exception: return None @@ -113,28 +116,32 @@ def print_fact_module_info(): try: project_root = setup_fact_imports() src_dir = project_root / "src" - + print("šŸ” FACT Module Information:") print(f" Project Root: {project_root}") print(f" Source Directory: {src_dir}") - + # List available modules print("\nšŸ“¦ Available Modules:") for item in src_dir.iterdir(): - if item.is_dir() and not item.name.startswith('.'): + if item.is_dir() and not item.name.startswith("."): print(f" šŸ“ {item.name}/") # List key files in each module for subitem in item.iterdir(): - if subitem.is_file() and subitem.suffix == '.py' and subitem.name != '__init__.py': + if ( + subitem.is_file() + and subitem.suffix == ".py" + and subitem.name != "__init__.py" + ): print(f" šŸ“„ {subitem.name}") - + # Test imports print("\nāœ… Import Test:") if verify_fact_imports(): print(" All essential FACT modules can be imported successfully!") else: print(" āŒ Some FACT modules failed to import") - + except Exception as e: print(f"āŒ Error getting module info: {e}") @@ -146,4 +153,4 @@ def print_fact_module_info(): if __name__ == "__main__": # When run directly, display module information - print_fact_module_info() \ No newline at end of file + print_fact_module_info() diff --git a/examples/arcade-dev/verify_setup.py b/examples/arcade-dev/verify_setup.py index b0afd75..3657a34 100644 --- a/examples/arcade-dev/verify_setup.py +++ b/examples/arcade-dev/verify_setup.py @@ -17,46 +17,51 @@ # Load environment variables from .env file if it exists try: from dotenv import load_dotenv + load_dotenv() except ImportError: pass # dotenv not available, skip loading # Setup logging -logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logger = logging.getLogger(__name__) class SetupVerifier: """Verifies the setup for Arcade.dev examples.""" - + def __init__(self): self.errors: List[str] = [] self.warnings: List[str] = [] self.success: List[str] = [] - + def check_python_version(self) -> bool: """Check Python version compatibility.""" required_version = (3, 8) current_version = sys.version_info[:2] - + if current_version >= required_version: - self.success.append(f"āœ… Python {'.'.join(map(str, current_version))} (>= 3.8)") + self.success.append( + f"āœ… Python {'.'.join(map(str, current_version))} (>= 3.8)" + ) return True else: - self.errors.append(f"āŒ Python {'.'.join(map(str, current_version))} < 3.8 (required)") + self.errors.append( + f"āŒ Python {'.'.join(map(str, current_version))} < 3.8 (required)" + ) return False - + def check_required_packages(self) -> bool: """Check if required Python packages are installed.""" # Map package names to their import names required_packages = { - 'aiohttp': 'aiohttp', - 'asyncio': 'asyncio', - 'pydantic': 'pydantic', - 'python-dotenv': 'dotenv', - 'redis': 'redis' + "aiohttp": "aiohttp", + "asyncio": "asyncio", + "pydantic": "pydantic", + "python-dotenv": "dotenv", + "redis": "redis", } - + all_installed = True for package_name, import_name in required_packages.items(): try: @@ -65,147 +70,156 @@ def check_required_packages(self) -> bool: except ImportError: self.errors.append(f"āŒ Package '{package_name}' is not installed") all_installed = False - + return all_installed - + def check_environment_variables(self) -> bool: """Check required environment variables.""" required_vars = [ - ('ARCADE_API_KEY', True), - ('ARCADE_API_URL', False), - ('ARCADE_TIMEOUT', False), - ('ARCADE_MAX_RETRIES', False), - ('FACT_LOG_LEVEL', False), - ('FACT_CACHE_ENABLED', False) + ("ARCADE_API_KEY", True), + ("ARCADE_API_URL", False), + ("ARCADE_TIMEOUT", False), + ("ARCADE_MAX_RETRIES", False), + ("FACT_LOG_LEVEL", False), + ("FACT_CACHE_ENABLED", False), ] - + all_set = True for var_name, required in required_vars: value = os.getenv(var_name) if value: - if var_name == 'ARCADE_API_KEY': + if var_name == "ARCADE_API_KEY": display_value = f"{value[:8]}..." if len(value) > 8 else "***" else: display_value = value self.success.append(f"āœ… {var_name}={display_value}") elif required: - self.errors.append(f"āŒ Required environment variable '{var_name}' is not set") + self.errors.append( + f"āŒ Required environment variable '{var_name}' is not set" + ) all_set = False else: - self.warnings.append(f"āš ļø Optional environment variable '{var_name}' is not set") - + self.warnings.append( + f"āš ļø Optional environment variable '{var_name}' is not set" + ) + return all_set - + def check_fact_framework(self) -> bool: """Check if FACT framework components are accessible.""" - fact_modules = [ - 'src.core.driver', - 'src.cache.manager' - ] - + fact_modules = ["src.core.driver", "src.cache.manager"] + # Add parent directories to path project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) - + all_accessible = True for module_name in fact_modules: try: importlib.import_module(module_name) self.success.append(f"āœ… FACT module '{module_name}' is accessible") except ImportError as e: - self.errors.append(f"āŒ FACT module '{module_name}' is not accessible: {e}") + self.errors.append( + f"āŒ FACT module '{module_name}' is not accessible: {e}" + ) all_accessible = False - + return all_accessible - + def check_file_structure(self) -> bool: """Check if required files and directories exist.""" base_dir = Path(__file__).parent required_paths = [ - 'README.md', - 'requirements.txt', - '01_basic_integration/basic_arcade_client.py', - 'config' + "README.md", + "requirements.txt", + "01_basic_integration/basic_arcade_client.py", + "config", ] - + all_exist = True for path_str in required_paths: path = base_dir / path_str if path.exists(): self.success.append(f"āœ… Path '{path_str}' exists") else: - if path_str == 'config': - self.warnings.append(f"āš ļø Directory '{path_str}' does not exist (will be created)") + if path_str == "config": + self.warnings.append( + f"āš ļø Directory '{path_str}' does not exist (will be created)" + ) else: self.errors.append(f"āŒ Required path '{path_str}' does not exist") all_exist = False - + return all_exist - + async def check_network_connectivity(self) -> bool: """Check network connectivity to Arcade.dev API.""" try: import aiohttp - api_url = os.getenv('ARCADE_API_URL', 'https://api.arcade.dev') - + + api_url = os.getenv("ARCADE_API_URL", "https://api.arcade.dev") + # List of endpoints to try (in order of preference) - endpoints_to_try = [ - '/health', - '/v1/health', - '/status', - '/' - ] - + endpoints_to_try = ["/health", "/v1/health", "/status", "/"] + async with aiohttp.ClientSession() as session: for endpoint in endpoints_to_try: try: - async with session.get(f"{api_url}{endpoint}", timeout=10) as response: + async with session.get( + f"{api_url}{endpoint}", timeout=10 + ) as response: if response.status == 200: - self.success.append(f"āœ… Network connectivity to {api_url} is working (endpoint: {endpoint})") + self.success.append( + f"āœ… Network connectivity to {api_url} is working (endpoint: {endpoint})" + ) return True elif response.status == 404: # 404 is expected for some endpoints, continue trying continue elif response.status in [401, 403]: # Authentication errors indicate the API is reachable - self.success.append(f"āœ… Network connectivity to {api_url} is working (authentication required)") + self.success.append( + f"āœ… Network connectivity to {api_url} is working (authentication required)" + ) return True else: - self.warnings.append(f"āš ļø {api_url}{endpoint} returned status {response.status}") + self.warnings.append( + f"āš ļø {api_url}{endpoint} returned status {response.status}" + ) continue except asyncio.TimeoutError: continue except Exception as e: continue - + # If we get here, none of the endpoints worked - self.warnings.append(f"āš ļø Could not verify API connectivity to {api_url} (all endpoints returned 404 or failed)") - self.warnings.append(f"āš ļø This may be normal if the API requires authentication or has a different endpoint structure") + self.warnings.append( + f"āš ļø Could not verify API connectivity to {api_url} (all endpoints returned 404 or failed)" + ) + self.warnings.append( + f"āš ļø This may be normal if the API requires authentication or has a different endpoint structure" + ) return False - + except Exception as e: self.warnings.append(f"āš ļø Could not verify network connectivity: {e}") return False - + def create_missing_directories(self): """Create missing directories.""" base_dir = Path(__file__).parent - directories = [ - 'config', - 'output', - 'logs' - ] - + directories = ["config", "output", "logs"] + for dir_name in directories: dir_path = base_dir / dir_name if not dir_path.exists(): dir_path.mkdir(parents=True, exist_ok=True) self.success.append(f"āœ… Created directory '{dir_name}'") - + async def run_verification(self) -> bool: """Run all verification checks.""" print("šŸ” Verifying Arcade.dev Examples Setup...\n") - + checks = [ ("Python Version", self.check_python_version), ("Required Packages", self.check_required_packages), @@ -213,52 +227,54 @@ async def run_verification(self) -> bool: ("FACT Framework", self.check_fact_framework), ("File Structure", self.check_file_structure), ] - + # Run synchronous checks all_passed = True for check_name, check_func in checks: print(f"Checking {check_name}...") passed = check_func() all_passed = all_passed and passed - + # Run async checks print("Checking Network Connectivity...") await self.check_network_connectivity() - + # Create missing directories print("Creating Missing Directories...") self.create_missing_directories() - + return all_passed - + def print_summary(self): """Print verification summary.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("SETUP VERIFICATION SUMMARY") - print("="*60) - + print("=" * 60) + if self.success: print("\nāœ… PASSED:") for item in self.success: print(f" {item}") - + if self.warnings: print("\nāš ļø WARNINGS:") for item in self.warnings: print(f" {item}") - + if self.errors: print("\nāŒ ERRORS:") for item in self.errors: print(f" {item}") - - print("\n" + "="*60) - + + print("\n" + "=" * 60) + if self.errors: print("āŒ Setup verification FAILED. Please fix the errors above.") print("\nNext steps:") print("1. Install missing packages: pip install -r requirements.txt") - print("2. Copy .env.example to .env and configure your environment variables:") + print( + "2. Copy .env.example to .env and configure your environment variables:" + ) print(" cp .env.example .env") print(" # Then edit .env with your actual API keys and configuration") print("3. Ensure FACT framework is properly installed") @@ -267,7 +283,9 @@ def print_summary(self): else: print("āœ… Setup verification PASSED!") if self.warnings: - print("Note: There are warnings that should be addressed for optimal functionality.") + print( + "Note: There are warnings that should be addressed for optimal functionality." + ) print("\nYou can now run the Arcade.dev examples:") print(" python 01_basic_integration/basic_arcade_client.py") return True @@ -276,12 +294,12 @@ def print_summary(self): async def main(): """Main verification function.""" verifier = SetupVerifier() - + try: success = await verifier.run_verification() verifier.print_summary() return 0 if success else 1 - + except KeyboardInterrupt: print("\nāŒ Verification interrupted by user") return 1 @@ -292,4 +310,4 @@ async def main(): if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/examples/tool_execution_demo.py b/examples/tool_execution_demo.py index d7bcc96..ef2eabe 100644 --- a/examples/tool_execution_demo.py +++ b/examples/tool_execution_demo.py @@ -13,7 +13,7 @@ from typing import Dict, Any # Add the project root to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) # Import FACT system components from src.tools.executor import ToolExecutor, create_tool_call @@ -29,15 +29,15 @@ description="Perform basic mathematical calculations", parameters={ "operation": { - "type": "string", + "type": "string", "description": "Mathematical operation", - "enum": ["add", "subtract", "multiply", "divide"] + "enum": ["add", "subtract", "multiply", "divide"], }, "a": {"type": "number", "description": "First number"}, - "b": {"type": "number", "description": "Second number"} + "b": {"type": "number", "description": "Second number"}, }, requires_auth=False, - timeout_seconds=10 + timeout_seconds=10, ) def calculator_tool(operation: str, a: float, b: float) -> Dict[str, Any]: """Perform basic mathematical calculations.""" @@ -45,22 +45,22 @@ def calculator_tool(operation: str, a: float, b: float) -> Dict[str, Any]: "add": lambda x, y: x + y, "subtract": lambda x, y: x - y, "multiply": lambda x, y: x * y, - "divide": lambda x, y: x / y if y != 0 else None + "divide": lambda x, y: x / y if y != 0 else None, } - + if operation not in operations: return {"error": f"Unknown operation: {operation}"} - + if operation == "divide" and b == 0: return {"error": "Division by zero"} - + result = operations[operation](a, b) - + return { "operation": operation, "operands": [a, b], "result": result, - "expression": f"{a} {operation} {b} = {result}" + "expression": f"{a} {operation} {b} = {result}", } @@ -68,31 +68,33 @@ def calculator_tool(operation: str, a: float, b: float) -> Dict[str, Any]: name="Demo.TextProcessor", description="Process and analyze text content", parameters={ - "text": { - "type": "string", - "description": "Text to process", - "maxLength": 1000 - }, + "text": {"type": "string", "description": "Text to process", "maxLength": 1000}, "operations": { "type": "array", "description": "Operations to perform", "items": { "type": "string", - "enum": ["uppercase", "lowercase", "reverse", "word_count", "char_count"] + "enum": [ + "uppercase", + "lowercase", + "reverse", + "word_count", + "char_count", + ], }, - "default": ["word_count"] - } + "default": ["word_count"], + }, }, requires_auth=False, - timeout_seconds=15 + timeout_seconds=15, ) def text_processor_tool(text: str, operations: list = None) -> Dict[str, Any]: """Process and analyze text content.""" if operations is None: operations = ["word_count"] - + results = {"original_text": text, "processed": {}} - + for operation in operations: if operation == "uppercase": results["processed"]["uppercase"] = text.upper() @@ -104,7 +106,7 @@ def text_processor_tool(text: str, operations: list = None) -> Dict[str, Any]: results["processed"]["word_count"] = len(text.split()) elif operation == "char_count": results["processed"]["char_count"] = len(text) - + return results @@ -115,31 +117,33 @@ def text_processor_tool(text: str, operations: list = None) -> Dict[str, Any]: "data_type": { "type": "string", "description": "Type of data to generate", - "enum": ["numbers", "names", "emails", "dates"] + "enum": ["numbers", "names", "emails", "dates"], }, "count": { "type": "integer", "description": "Number of items to generate", "minimum": 1, "maximum": 100, - "default": 10 + "default": 10, }, "format": { "type": "string", "description": "Output format", "enum": ["list", "json", "csv"], - "default": "list" - } + "default": "list", + }, }, requires_auth=False, - timeout_seconds=20 + timeout_seconds=20, ) -def data_generator_tool(data_type: str, count: int = 10, format: str = "list") -> Dict[str, Any]: +def data_generator_tool( + data_type: str, count: int = 10, format: str = "list" +) -> Dict[str, Any]: """Generate sample data sets for testing.""" import random import string from datetime import datetime, timedelta - + generators = { "numbers": lambda: [random.randint(1, 1000) for _ in range(count)], "names": lambda: [f"User_{i+1}" for i in range(count)], @@ -147,20 +151,16 @@ def data_generator_tool(data_type: str, count: int = 10, format: str = "list") - "dates": lambda: [ (datetime.now() - timedelta(days=random.randint(0, 365))).isoformat() for _ in range(count) - ] + ], } - + if data_type not in generators: return {"error": f"Unknown data type: {data_type}"} - + data = generators[data_type]() - - result = { - "data_type": data_type, - "count": len(data), - "format": format - } - + + result = {"data_type": data_type, "count": len(data), "format": format} + if format == "list": result["data"] = data elif format == "json": @@ -170,85 +170,77 @@ def data_generator_tool(data_type: str, count: int = 10, format: str = "list") - result["data"] = "value\n" + "\n".join(map(str, data)) else: result["data"] = f"{data_type}\n" + "\n".join(map(str, data)) - + return result async def demo_basic_tool_execution(): """Demonstrate basic tool execution without Arcade.""" print("\n=== Demo: Basic Tool Execution ===") - + # Create tool executor (local execution only) executor = ToolExecutor( - arcade_client=None, - enable_rate_limiting=True, - max_calls_per_minute=60 + arcade_client=None, enable_rate_limiting=True, max_calls_per_minute=60 ) - + # Test calculator tool print("\n1. Testing Calculator Tool:") calc_call = create_tool_call( - "Demo.Calculator", - {"operation": "multiply", "a": 15, "b": 7} + "Demo.Calculator", {"operation": "multiply", "a": 15, "b": 7} ) - + result = await executor.execute_tool_call(calc_call) print(f" Result: {result.data}") print(f" Execution time: {result.execution_time_ms:.2f}ms") - + # Test text processor tool print("\n2. Testing Text Processor Tool:") text_call = create_tool_call( "Demo.TextProcessor", { "text": "Hello World! This is a test.", - "operations": ["uppercase", "word_count", "char_count"] - } + "operations": ["uppercase", "word_count", "char_count"], + }, ) - + result = await executor.execute_tool_call(text_call) print(f" Original: {result.data['original_text']}") print(f" Processed: {json.dumps(result.data['processed'], indent=2)}") - + # Test data generator tool print("\n3. Testing Data Generator Tool:") data_call = create_tool_call( - "Demo.DataGenerator", - {"data_type": "emails", "count": 5, "format": "list"} + "Demo.DataGenerator", {"data_type": "emails", "count": 5, "format": "list"} ) - + result = await executor.execute_tool_call(data_call) print(f" Generated {result.data['count']} {result.data['data_type']}:") - for email in result.data['data']: + for email in result.data["data"]: print(f" - {email}") async def demo_error_handling(): """Demonstrate error handling in tool execution.""" print("\n=== Demo: Error Handling ===") - + executor = ToolExecutor(arcade_client=None) - + # Test division by zero print("\n1. Testing Division by Zero:") error_call = create_tool_call( - "Demo.Calculator", - {"operation": "divide", "a": 10, "b": 0} + "Demo.Calculator", {"operation": "divide", "a": 10, "b": 0} ) - + result = await executor.execute_tool_call(error_call) if result.success: print(f" Error message: {result.data['error']}") else: print(f" Execution failed: {result.error}") - + # Test invalid tool print("\n2. Testing Invalid Tool:") - invalid_call = create_tool_call( - "NonExistent.Tool", - {"param": "value"} - ) - + invalid_call = create_tool_call("NonExistent.Tool", {"param": "value"}) + result = await executor.execute_tool_call(invalid_call) print(f" Error: {result.error}") print(f" Status: {result.status_code}") @@ -257,26 +249,26 @@ async def demo_error_handling(): async def demo_concurrent_execution(): """Demonstrate concurrent tool execution.""" print("\n=== Demo: Concurrent Execution ===") - + executor = ToolExecutor(arcade_client=None) - + # Create multiple tool calls tool_calls = [ - create_tool_call("Demo.Calculator", {"operation": "add", "a": i, "b": i*2}) + create_tool_call("Demo.Calculator", {"operation": "add", "a": i, "b": i * 2}) for i in range(1, 6) ] - + print(f"\nExecuting {len(tool_calls)} calculator operations concurrently...") - + start_time = time.time() results = await executor.execute_tool_calls(tool_calls) execution_time = (time.time() - start_time) * 1000 - + print(f"Total execution time: {execution_time:.2f}ms") print("Results:") for i, result in enumerate(results): if result.success: - expr = result.data['expression'] + expr = result.data["expression"] print(f" {i+1}. {expr} (took {result.execution_time_ms:.1f}ms)") else: print(f" {i+1}. Error: {result.error}") @@ -285,24 +277,21 @@ async def demo_concurrent_execution(): async def demo_rate_limiting(): """Demonstrate rate limiting functionality.""" print("\n=== Demo: Rate Limiting ===") - + # Create executor with low rate limit executor = ToolExecutor( arcade_client=None, enable_rate_limiting=True, - max_calls_per_minute=3 # Very low limit for demo + max_calls_per_minute=3, # Very low limit for demo ) - + print("\nTesting rate limiting (max 3 calls per minute):") - + for i in range(5): - call = create_tool_call( - "Demo.Calculator", - {"operation": "add", "a": i, "b": 1} - ) - + call = create_tool_call("Demo.Calculator", {"operation": "add", "a": i, "b": 1}) + result = await executor.execute_tool_call(call) - + if result.success: print(f" Call {i+1}: Success - {result.data['result']}") else: @@ -312,34 +301,36 @@ async def demo_rate_limiting(): async def demo_metrics_collection(): """Demonstrate metrics collection and reporting.""" print("\n=== Demo: Metrics Collection ===") - + executor = ToolExecutor(arcade_client=None) metrics_collector = get_metrics_collector() - + # Execute several tools to generate metrics print("\nExecuting tools to generate metrics...") - + operations = [ {"operation": "add", "a": 10, "b": 5}, {"operation": "multiply", "a": 3, "b": 7}, {"operation": "subtract", "a": 20, "b": 8}, - {"operation": "divide", "a": 15, "b": 3} + {"operation": "divide", "a": 15, "b": 3}, ] - + for op in operations: call = create_tool_call("Demo.Calculator", op) result = await executor.execute_tool_call(call) - + # Metrics are automatically collected by the executor - print(f" {op['operation']}: {result.data.get('result') if result.success else 'Error'}") - + print( + f" {op['operation']}: {result.data.get('result') if result.success else 'Error'}" + ) + # Get system metrics print("\nSystem Metrics:") system_metrics = metrics_collector.get_system_metrics(time_window_minutes=5) print(f" Total executions: {system_metrics.total_executions}") print(f" Success rate: {100 - system_metrics.error_rate:.1f}%") print(f" Average execution time: {system_metrics.average_execution_time:.2f}ms") - + # Get tool-specific metrics print("\nCalculator Tool Metrics:") tool_metrics = metrics_collector.get_tool_metrics("Demo.Calculator") @@ -351,36 +342,36 @@ async def demo_metrics_collection(): async def demo_available_tools(): """Demonstrate tool discovery and information retrieval.""" print("\n=== Demo: Tool Discovery ===") - + executor = ToolExecutor(arcade_client=None) - + # Get list of available tools tools = executor.get_available_tools() - + print(f"\nFound {len(tools)} available tools:") for tool in tools: - function_info = tool['function'] + function_info = tool["function"] print(f"\n Tool: {function_info['name']}") print(f" Description: {function_info['description']}") - + # Show parameters - params = function_info.get('parameters', {}).get('properties', {}) + params = function_info.get("parameters", {}).get("properties", {}) if params: print(" Parameters:") for param_name, param_info in params.items(): - param_type = param_info.get('type', 'unknown') - param_desc = param_info.get('description', 'No description') + param_type = param_info.get("type", "unknown") + param_desc = param_info.get("description", "No description") print(f" - {param_name} ({param_type}): {param_desc}") async def demo_arcade_integration(): """Demonstrate Arcade.dev integration (simulated).""" print("\n=== Demo: Arcade.dev Integration (Simulated) ===") - + # Note: This would require actual Arcade.dev credentials print("\nSimulating Arcade.dev integration...") print("(In production, this would connect to actual Arcade.dev platform)") - + # Create mock Arcade client for demo class MockArcadeClient: async def execute_tool(self, tool_name, arguments, **kwargs): @@ -389,18 +380,17 @@ async def execute_tool(self, tool_name, arguments, **kwargs): "success": True, "data": f"Arcade executed {tool_name} with {arguments}", "execution_environment": "secure_container", - "execution_id": f"arcade_exec_{int(time.time())}" + "execution_id": f"arcade_exec_{int(time.time())}", } - + mock_client = MockArcadeClient() executor = ToolExecutor(arcade_client=mock_client) - + # Execute tool via "Arcade" call = create_tool_call( - "Demo.Calculator", - {"operation": "multiply", "a": 6, "b": 9} + "Demo.Calculator", {"operation": "multiply", "a": 6, "b": 9} ) - + result = await executor.execute_tool_call(call) print(f" Arcade execution result: {result.data}") print(f" Execution time: {result.execution_time_ms:.2f}ms") @@ -410,7 +400,7 @@ async def main(): """Run all demonstrations.""" print("FACT System Tool Execution Framework Demo") print("=" * 50) - + try: await demo_basic_tool_execution() await demo_error_handling() @@ -419,16 +409,17 @@ async def main(): await demo_metrics_collection() await demo_available_tools() await demo_arcade_integration() - + print("\n" + "=" * 50) print("All demos completed successfully!") - + except Exception as e: print(f"\nDemo failed with error: {e}") import traceback + traceback.print_exc() if __name__ == "__main__": # Run the demo - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/fact-memory/examples/basic_usage.py b/fact-memory/examples/basic_usage.py index 3a5f904..f99d3cb 100644 --- a/fact-memory/examples/basic_usage.py +++ b/fact-memory/examples/basic_usage.py @@ -13,121 +13,123 @@ async def basic_memory_operations(): """Demonstrate basic memory operations.""" - + # Initialize memory manager with FACT cache integration memory_manager = FactMemoryManager( cache_config={ "prefix": "fact_memory_demo", "min_tokens": 50, # Lower for demo "max_size": "1MB", - "ttl_seconds": 3600 + "ttl_seconds": 3600, } ) - + user_id = "demo_user_123" - + print("=== FACT Memory System - Basic Usage Demo ===\n") - + # 1. Add various types of memories print("1. Adding memories...") - + memories = [ { "content": "User prefers dark mode interface with high contrast colors", "memory_type": MemoryType.PREFERENCE, - "tags": ["ui", "accessibility", "dark-mode"] + "tags": ["ui", "accessibility", "dark-mode"], }, { "content": "User is a software engineer working on AI/ML projects", "memory_type": MemoryType.FACT, - "tags": ["profession", "ai", "ml"] + "tags": ["profession", "ai", "ml"], }, { "content": "User frequently asks about Python optimization techniques", "memory_type": MemoryType.BEHAVIOR, - "tags": ["python", "optimization", "patterns"] + "tags": ["python", "optimization", "patterns"], }, { "content": "Always provide code examples with detailed explanations", "memory_type": MemoryType.INSTRUCTION, - "tags": ["communication", "code", "teaching"] - } + "tags": ["communication", "code", "teaching"], + }, ] - + added_memories = [] for memory_data in memories: memory = await memory_manager.add_memory( user_id=user_id, content=memory_data["content"], memory_type=memory_data["memory_type"], - tags=memory_data["tags"] + tags=memory_data["tags"], ) added_memories.append(memory) print(f"āœ“ Added {memory_data['memory_type'].value}: {memory.id}") - + print(f"\nAdded {len(added_memories)} memories for user {user_id}\n") - + # 2. Search for memories print("2. Searching memories...") - + search_queries = [ "What are the user's interface preferences?", "Tell me about the user's profession", "How should I communicate with this user?", - "What programming topics interest the user?" + "What programming topics interest the user?", ] - + for query in search_queries: print(f"\nQuery: '{query}'") results = await memory_manager.search_memories( - user_id=user_id, - query=query, - limit=3 + user_id=user_id, query=query, limit=3 ) - + for i, memory in enumerate(results, 1): print(f" {i}. [{memory.memory_type.value}] {memory.content[:60]}...") - print(f" Relevance: {memory.relevance_score:.2f} | Tags: {', '.join(memory.tags)}") - + print( + f" Relevance: {memory.relevance_score:.2f} | Tags: {', '.join(memory.tags)}" + ) + # 3. Get all memories by type print("\n3. Retrieving memories by type...") - + for memory_type in MemoryType: memories = await memory_manager.get_memories_by_type(user_id, memory_type) print(f"{memory_type.value.title()}: {len(memories)} memories") - + # 4. Update a memory print("\n4. Updating a memory...") - - preference_memory = next(m for m in added_memories if m.memory_type == MemoryType.PREFERENCE) + + preference_memory = next( + m for m in added_memories if m.memory_type == MemoryType.PREFERENCE + ) updated_memory = await memory_manager.update_memory( user_id=user_id, memory_id=preference_memory.id, content="User prefers dark mode with blue accent colors and large fonts for accessibility", - tags=["ui", "accessibility", "dark-mode", "fonts", "colors"] + tags=["ui", "accessibility", "dark-mode", "fonts", "colors"], ) print(f"āœ“ Updated memory: {updated_memory.id}") - + # 5. Get memory statistics print("\n5. Memory statistics...") - + stats = await memory_manager.get_user_stats(user_id) print(f"Total memories: {stats['total_memories']}") print(f"Memory types: {stats['memory_by_type']}") print(f"Cache hit rate: {stats['cache_hit_rate']:.1%}") print(f"Average access time: {stats['avg_access_time_ms']}ms") - + return added_memories async def advanced_search_examples(): """Demonstrate advanced search capabilities.""" - + memory_manager = FactMemoryManager() user_id = "advanced_user_456" - + print("\n=== Advanced Search Examples ===\n") - + # Add some complex memories for advanced search complex_memories = [ "User is building a machine learning pipeline using Python and scikit-learn for customer churn prediction", @@ -136,48 +138,45 @@ async def advanced_search_examples(): "Uses VS Code with Copilot extension and prefers vim keybindings", "Interested in functional programming concepts, especially in JavaScript and Python", "Team lead responsible for code reviews and architecture decisions", - "Prefers Test-Driven Development approach with pytest for Python projects" + "Prefers Test-Driven Development approach with pytest for Python projects", ] - + for i, content in enumerate(complex_memories): await memory_manager.add_memory( user_id=user_id, content=content, memory_type=MemoryType.FACT, - tags=[f"tech_{i}", "development"] + tags=[f"tech_{i}", "development"], ) - + # Advanced search scenarios search_scenarios = [ { "query": "What technologies does the user work with?", - "description": "Technology stack inquiry" + "description": "Technology stack inquiry", }, { "query": "How does the user prefer to develop software?", - "description": "Development methodology preferences" + "description": "Development methodology preferences", }, { "query": "What is the user's role and responsibilities?", - "description": "Professional role inquiry" + "description": "Professional role inquiry", }, { "query": "Tell me about the user's deployment preferences", - "description": "Infrastructure and deployment" - } + "description": "Infrastructure and deployment", + }, ] - + for scenario in search_scenarios: print(f"Scenario: {scenario['description']}") print(f"Query: '{scenario['query']}'") - + results = await memory_manager.search_memories( - user_id=user_id, - query=scenario["query"], - limit=5, - min_relevance=0.3 + user_id=user_id, query=scenario["query"], limit=5, min_relevance=0.3 ) - + print(f"Found {len(results)} relevant memories:") for i, memory in enumerate(results, 1): print(f" {i}. {memory.content[:80]}...") @@ -187,26 +186,23 @@ async def advanced_search_examples(): async def mcp_integration_example(): """Demonstrate MCP server integration.""" - + print("=== MCP Integration Example ===\n") - + # This would typically be done in a separate process/server from fact_memory.mcp import FactMemoryMCPServer - + # Initialize MCP server mcp_server = FactMemoryMCPServer( host="localhost", port=8080, - memory_config={ - "max_memories_per_user": 1000, - "cache_ttl_seconds": 3600 - } + memory_config={"max_memories_per_user": 1000, "cache_ttl_seconds": 3600}, ) - + print("MCP Server initialized with tools:") for tool_name in mcp_server.get_available_tools(): print(f" - {tool_name}") - + # Simulate MCP client interactions client_requests = [ { @@ -215,45 +211,46 @@ async def mcp_integration_example(): "content": "User prefers concise explanations with practical examples", "userId": "mcp_user_789", "memoryType": "instruction", - "tags": ["communication", "learning"] - } + "tags": ["communication", "learning"], + }, }, { "tool": "search-memories", "params": { "query": "How should I communicate with this user?", "userId": "mcp_user_789", - "limit": 3 - } - } + "limit": 3, + }, + }, ] - + print("\nSimulating MCP client requests:") for request in client_requests: print(f"\nRequest: {request['tool']}") print(f"Params: {request['params']}") - + # Simulate tool execution response = await mcp_server.handle_tool_request( - tool_name=request["tool"], - parameters=request["params"] + tool_name=request["tool"], parameters=request["params"] ) - + print(f"Response: {response['content'][0]['text'][:100]}...") if __name__ == "__main__": + async def main(): try: await basic_memory_operations() await advanced_search_examples() await mcp_integration_example() - + print("\nāœ“ All examples completed successfully!") - + except Exception as e: print(f"\nāŒ Error running examples: {e}") import traceback + traceback.print_exc() - - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/fact-memory/examples/mcp_client.py b/fact-memory/examples/mcp_client.py index fdf21ec..cbf3baf 100644 --- a/fact-memory/examples/mcp_client.py +++ b/fact-memory/examples/mcp_client.py @@ -16,6 +16,7 @@ @dataclass class MCPResponse: """Represents an MCP tool response.""" + content: List[Dict[str, Any]] metadata: Dict[str, Any] = None @@ -23,30 +24,30 @@ class MCPResponse: class FactMemoryMCPClient: """ MCP client for FACT Memory System. - + This client demonstrates how to interact with the FACT Memory MCP server using the standard MCP protocol. It's compatible with existing Mem0 workflows. """ - + def __init__(self, server_url: str = "http://localhost:8080", api_key: str = None): self.server_url = server_url self.api_key = api_key self.headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" if api_key else None + "Authorization": f"Bearer {api_key}" if api_key else None, } - + async def tool(self, tool_name: str, parameters: Dict[str, Any]) -> MCPResponse: """ Execute an MCP tool on the FACT Memory server. - + This method simulates the MCP protocol interaction. In a real implementation, this would make HTTP requests to the MCP server. """ # Simulate MCP tool execution print(f"šŸ“” MCP Tool Call: {tool_name}") print(f" Parameters: {json.dumps(parameters, indent=2)}") - + # Simulate different tool responses if tool_name == "add-memory": return await self._handle_add_memory(parameters) @@ -60,13 +61,13 @@ async def tool(self, tool_name: str, parameters: Dict[str, Any]) -> MCPResponse: return await self._handle_get_stats(parameters) else: raise ValueError(f"Unknown tool: {tool_name}") - + async def _handle_add_memory(self, params: Dict[str, Any]) -> MCPResponse: """Simulate add-memory tool response.""" memory_id = f"mem_{hash(params['content']) % 1000000:06d}" - + response_text = f"Memory added successfully. ID: {memory_id}" - + return MCPResponse( content=[{"type": "text", "text": response_text}], metadata={ @@ -74,49 +75,55 @@ async def _handle_add_memory(self, params: Dict[str, Any]) -> MCPResponse: "userId": params["userId"], "memoryType": params.get("memoryType", "fact"), "tokenCount": len(params["content"].split()) * 1.3, # Rough estimate - "timestamp": "2024-01-15T10:30:00Z" - } + "timestamp": "2024-01-15T10:30:00Z", + }, ) - + async def _handle_search_memories(self, params: Dict[str, Any]) -> MCPResponse: """Simulate search-memories tool response.""" # Simulate search results based on query query = params["query"].lower() user_id = params["userId"] - + # Mock search results based on common queries mock_memories = [] - + if "preference" in query or "interface" in query: - mock_memories.append({ - "content": "User prefers dark mode interface with high contrast", - "relevance": 0.95, - "type": "preference", - "tags": ["ui", "accessibility"], - "created": "2024-01-15T10:30:00Z" - }) - + mock_memories.append( + { + "content": "User prefers dark mode interface with high contrast", + "relevance": 0.95, + "type": "preference", + "tags": ["ui", "accessibility"], + "created": "2024-01-15T10:30:00Z", + } + ) + if "profession" in query or "work" in query: - mock_memories.append({ - "content": "User is a software engineer working on AI/ML projects", - "relevance": 0.87, - "type": "fact", - "tags": ["profession", "ai"], - "created": "2024-01-14T15:20:00Z" - }) - + mock_memories.append( + { + "content": "User is a software engineer working on AI/ML projects", + "relevance": 0.87, + "type": "fact", + "tags": ["profession", "ai"], + "created": "2024-01-14T15:20:00Z", + } + ) + if "communication" in query or "explain" in query: - mock_memories.append({ - "content": "Always provide code examples with detailed explanations", - "relevance": 0.78, - "type": "instruction", - "tags": ["communication", "code"], - "created": "2024-01-13T09:15:00Z" - }) - + mock_memories.append( + { + "content": "Always provide code examples with detailed explanations", + "relevance": 0.78, + "type": "instruction", + "tags": ["communication", "code"], + "created": "2024-01-13T09:15:00Z", + } + ) + # Format response text response_lines = [] - for memory in mock_memories[:params.get("limit", 10)]: + for memory in mock_memories[: params.get("limit", 10)]: response_lines.append( f"Memory: {memory['content']}\n" f"Relevance: {memory['relevance']}\n" @@ -124,9 +131,9 @@ async def _handle_search_memories(self, params: Dict[str, Any]) -> MCPResponse: f"Tags: {', '.join(memory['tags'])}\n" f"Created: {memory['created']}\n---" ) - + response_text = "\n".join(response_lines) - + return MCPResponse( content=[{"type": "text", "text": response_text}], metadata={ @@ -134,14 +141,14 @@ async def _handle_search_memories(self, params: Dict[str, Any]) -> MCPResponse: "searchTime": "45ms", "cacheHit": True, "query": params["query"], - "userId": user_id - } + "userId": user_id, + }, ) - + async def _handle_get_memories(self, params: Dict[str, Any]) -> MCPResponse: """Simulate get-memories tool response.""" user_id = params["userId"] - + # Mock user memories memories = [ { @@ -149,17 +156,17 @@ async def _handle_get_memories(self, params: Dict[str, Any]) -> MCPResponse: "content": "User prefers dark mode interface", "type": "preference", "created": "2024-01-15T10:30:00Z", - "tags": ["ui", "accessibility"] + "tags": ["ui", "accessibility"], }, { - "id": "mem_001235", + "id": "mem_001235", "content": "User is a Python developer", "type": "fact", "created": "2024-01-14T15:20:00Z", - "tags": ["profession", "python"] - } + "tags": ["profession", "python"], + }, ] - + response_lines = [] for memory in memories: response_lines.append( @@ -169,45 +176,45 @@ async def _handle_get_memories(self, params: Dict[str, Any]) -> MCPResponse: f"Created: {memory['created']}\n" f"Tags: {', '.join(memory['tags'])}\n---" ) - + response_text = "\n".join(response_lines) - + return MCPResponse( content=[{"type": "text", "text": response_text}], metadata={ "totalMemories": len(memories), "userId": user_id, "page": 1, - "hasMore": False - } + "hasMore": False, + }, ) - + async def _handle_delete_memory(self, params: Dict[str, Any]) -> MCPResponse: """Simulate delete-memory tool response.""" memory_id = params["memoryId"] user_id = params["userId"] - + response_text = f"Memory {memory_id} deleted successfully for user {user_id}" - + return MCPResponse( content=[{"type": "text", "text": response_text}], metadata={ "memoryId": memory_id, "userId": user_id, - "deletedAt": "2024-01-15T10:30:00Z" - } + "deletedAt": "2024-01-15T10:30:00Z", + }, ) - + async def _handle_get_stats(self, params: Dict[str, Any]) -> MCPResponse: """Simulate get-memory-stats tool response.""" user_id = params["userId"] - + response_text = ( f"Memory Statistics for user: {user_id}\n" f"Total memories: 45\n" f"Memory types:\n" f"- Preferences: 12\n" - f"- Facts: 18\n" + f"- Facts: 18\n" f"- Context: 8\n" f"- Behavior: 5\n" f"- Instructions: 2\n" @@ -215,7 +222,7 @@ async def _handle_get_stats(self, params: Dict[str, Any]) -> MCPResponse: f"Cache hit rate: 94.2%\n" f"Average access time: 28ms" ) - + return MCPResponse( content=[{"type": "text", "text": response_text}], metadata={ @@ -226,93 +233,91 @@ async def _handle_get_stats(self, params: Dict[str, Any]) -> MCPResponse: "fact": 18, "context": 8, "behavior": 5, - "instruction": 2 + "instruction": 2, }, "totalSizeBytes": 128409, "cacheHitRate": 0.942, - "avgAccessTimeMs": 28 - } + "avgAccessTimeMs": 28, + }, ) async def demonstrate_mcp_compatibility(): """Demonstrate MCP compatibility with Mem0-style usage.""" - + print("=== FACT Memory MCP Client Demo ===\n") - + # Initialize MCP client client = FactMemoryMCPClient( - server_url="http://localhost:8080", - api_key="your_fact_memory_api_key" + server_url="http://localhost:8080", api_key="your_fact_memory_api_key" ) - + user_id = "demo_user_mcp" - + # 1. Add memories (Mem0-compatible) print("1. Adding memories via MCP...") - + memories_to_add = [ { "content": "User prefers concise code reviews with actionable feedback", "userId": user_id, "memoryType": "instruction", - "tags": ["code-review", "communication"] + "tags": ["code-review", "communication"], }, { "content": "User works with React, TypeScript, and Node.js daily", "userId": user_id, "memoryType": "fact", - "tags": ["technology", "frontend", "backend"] + "tags": ["technology", "frontend", "backend"], }, { "content": "User asks for optimization tips during code discussions", "userId": user_id, "memoryType": "behavior", - "tags": ["optimization", "learning"] - } + "tags": ["optimization", "learning"], + }, ] - + for memory_data in memories_to_add: response = await client.tool("add-memory", memory_data) print(f"āœ“ {response.content[0]['text']}") - + print() - + # 2. Search memories print("2. Searching memories...") - - search_response = await client.tool("search-memories", { - "query": "How should I provide feedback to this user?", - "userId": user_id, - "limit": 5, - "minRelevance": 0.3 - }) - + + search_response = await client.tool( + "search-memories", + { + "query": "How should I provide feedback to this user?", + "userId": user_id, + "limit": 5, + "minRelevance": 0.3, + }, + ) + print("Search Results:") print(search_response.content[0]["text"]) print(f"Search completed in {search_response.metadata['searchTime']}") print() - + # 3. Get all memories print("3. Retrieving all memories...") - - all_memories_response = await client.tool("get-memories", { - "userId": user_id, - "limit": 10, - "sortBy": "created" - }) - + + all_memories_response = await client.tool( + "get-memories", {"userId": user_id, "limit": 10, "sortBy": "created"} + ) + print("All Memories:") print(all_memories_response.content[0]["text"]) print() - + # 4. Get memory statistics print("4. Getting memory statistics...") - - stats_response = await client.tool("get-memory-stats", { - "userId": user_id - }) - + + stats_response = await client.tool("get-memory-stats", {"userId": user_id}) + print("Memory Statistics:") print(stats_response.content[0]["text"]) print() @@ -320,72 +325,76 @@ async def demonstrate_mcp_compatibility(): async def demonstrate_advanced_mcp_features(): """Demonstrate FACT Memory's enhanced MCP features.""" - + print("=== Advanced FACT Memory Features ===\n") - + client = FactMemoryMCPClient() user_id = "advanced_user_mcp" - + # Add memory with metadata and tags print("1. Adding memory with rich metadata...") - - enhanced_memory = await client.tool("add-memory", { - "content": "User prefers functional programming patterns in JavaScript and Python", - "userId": user_id, - "memoryType": "preference", - "tags": ["programming", "functional", "javascript", "python"], - "metadata": { - "source": "code_review_session", - "confidence": 0.95, - "context": "discussing async/await patterns", - "related_topics": ["promises", "async", "map", "filter", "reduce"] - } - }) - + + enhanced_memory = await client.tool( + "add-memory", + { + "content": "User prefers functional programming patterns in JavaScript and Python", + "userId": user_id, + "memoryType": "preference", + "tags": ["programming", "functional", "javascript", "python"], + "metadata": { + "source": "code_review_session", + "confidence": 0.95, + "context": "discussing async/await patterns", + "related_topics": ["promises", "async", "map", "filter", "reduce"], + }, + }, + ) + print(f"āœ“ Enhanced memory added: {enhanced_memory.metadata['memoryId']}") print() - + # Advanced search with filtering print("2. Advanced search with type filtering...") - - filtered_search = await client.tool("search-memories", { - "query": "programming preferences and patterns", - "userId": user_id, - "memoryType": "preference", - "tags": ["programming", "functional"], - "minRelevance": 0.5, - "limit": 3 - }) - + + filtered_search = await client.tool( + "search-memories", + { + "query": "programming preferences and patterns", + "userId": user_id, + "memoryType": "preference", + "tags": ["programming", "functional"], + "minRelevance": 0.5, + "limit": 3, + }, + ) + print("Filtered Search Results:") print(filtered_search.content[0]["text"]) print() - + # Batch operations (FACT Memory enhancement) print("3. Memory statistics with detailed breakdown...") - - detailed_stats = await client.tool("get-memory-stats", { - "userId": user_id - }) - + + detailed_stats = await client.tool("get-memory-stats", {"userId": user_id}) + stats_data = detailed_stats.metadata print(f"Cache Performance:") print(f" Hit Rate: {stats_data['cacheHitRate']:.1%}") print(f" Avg Access Time: {stats_data['avgAccessTimeMs']}ms") print(f"Memory Distribution:") - for mem_type, count in stats_data['memoryByType'].items(): + for mem_type, count in stats_data["memoryByType"].items(): print(f" {mem_type.title()}: {count}") print() async def demonstrate_mem0_migration(): """Show how existing Mem0 code works with FACT Memory.""" - + print("=== Mem0 Migration Compatibility ===\n") - + # This is exactly how Mem0 clients work - no changes needed! client = FactMemoryMCPClient() - + print("Existing Mem0 client code (unchanged):") print(""" # Original Mem0 usage @@ -399,32 +408,31 @@ async def demonstrate_mem0_migration(): "userId": "alice" }) """) - + # Execute the exact same code print("\nExecuting with FACT Memory (100% compatible):") - - await client.tool("add-memory", { - "content": "User prefers dark mode", - "userId": "alice" - }) - - results = await client.tool("search-memories", { - "query": "interface preferences", - "userId": "alice" - }) - + + await client.tool( + "add-memory", {"content": "User prefers dark mode", "userId": "alice"} + ) + + results = await client.tool( + "search-memories", {"query": "interface preferences", "userId": "alice"} + ) + print("āœ“ Perfect compatibility - existing Mem0 code works unchanged!") print("āœ“ But with FACT Memory's superior performance and caching!") print() if __name__ == "__main__": + async def main(): try: await demonstrate_mcp_compatibility() await demonstrate_advanced_mcp_features() await demonstrate_mem0_migration() - + print("šŸŽ‰ All MCP examples completed successfully!") print("\nKey Benefits of FACT Memory over Mem0:") print("• 3-5x faster response times (cache-based)") @@ -432,10 +440,11 @@ async def main(): print("• Enhanced memory types and metadata") print("• Superior semantic understanding") print("• Integrated with FACT SDK ecosystem") - + except Exception as e: print(f"\nāŒ Error running MCP examples: {e}") import traceback + traceback.print_exc() - - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/fact-memory/examples/performance_comparison.py b/fact-memory/examples/performance_comparison.py index 019d336..7b40cbc 100644 --- a/fact-memory/examples/performance_comparison.py +++ b/fact-memory/examples/performance_comparison.py @@ -16,6 +16,7 @@ @dataclass class PerformanceMetrics: """Performance measurement results.""" + operation: str approach: str response_times: List[float] @@ -27,33 +28,35 @@ class PerformanceMetrics: class VectorMemorySimulator: """Simulates traditional vector database memory system.""" - + def __init__(self): self.memories = {} self.embeddings_cache = {} - + async def add_memory(self, user_id: str, content: str) -> str: """Simulate vector embedding and storage.""" # Simulate embedding generation (100-200ms) await asyncio.sleep(0.15) - + memory_id = f"vec_{hash(content) % 1000000:06d}" self.memories[memory_id] = { "user_id": user_id, "content": content, "embedding": [0.1] * 1536, # Mock embedding - "created_at": time.time() + "created_at": time.time(), } return memory_id - - async def search_memories(self, user_id: str, query: str, limit: int = 10) -> List[Dict]: + + async def search_memories( + self, user_id: str, query: str, limit: int = 10 + ) -> List[Dict]: """Simulate vector similarity search.""" # Simulate query embedding generation await asyncio.sleep(0.12) - + # Simulate vector similarity computation await asyncio.sleep(0.08) - + # Return mock results user_memories = [m for m in self.memories.values() if m["user_id"] == user_id] return user_memories[:limit] @@ -61,35 +64,37 @@ async def search_memories(self, user_id: str, query: str, limit: int = 10) -> Li class FactMemorySimulator: """Simulates FACT Memory cache-based system.""" - + def __init__(self): self.cache = {} self.cache_hits = 0 self.total_requests = 0 - + async def add_memory(self, user_id: str, content: str) -> str: """Simulate prompt cache storage.""" # Simulate cache write (10-20ms) await asyncio.sleep(0.015) - + memory_id = f"fact_{hash(content) % 1000000:06d}" cache_key = f"memory:{user_id}:{memory_id}" - + self.cache[cache_key] = { "user_id": user_id, "content": content, "created_at": time.time(), - "access_count": 0 + "access_count": 0, } return memory_id - - async def search_memories(self, user_id: str, query: str, limit: int = 10) -> List[Dict]: + + async def search_memories( + self, user_id: str, query: str, limit: int = 10 + ) -> List[Dict]: """Simulate cache-based search with LLM semantic understanding.""" self.total_requests += 1 - + # Check cache for similar queries query_cache_key = f"search:{user_id}:{hash(query) % 10000}" - + if query_cache_key in self.cache: # Cache hit - super fast response self.cache_hits += 1 @@ -99,40 +104,45 @@ async def search_memories(self, user_id: str, query: str, limit: int = 10) -> Li await asyncio.sleep(0.065) # 65ms LLM processing # Cache the result self.cache[query_cache_key] = {"results": "cached"} - + # Return mock results - user_memories = [m for m in self.cache.values() - if isinstance(m, dict) and m.get("user_id") == user_id] + user_memories = [ + m + for m in self.cache.values() + if isinstance(m, dict) and m.get("user_id") == user_id + ] return user_memories[:limit] - + def get_cache_hit_rate(self) -> float: return self.cache_hits / max(self.total_requests, 1) -async def measure_performance(system, operation_func, iterations: int = 100) -> PerformanceMetrics: +async def measure_performance( + system, operation_func, iterations: int = 100 +) -> PerformanceMetrics: """Measure performance of a specific operation.""" - + response_times = [] - + for i in range(iterations): start_time = time.time() await operation_func(system, i) end_time = time.time() - + response_times.append((end_time - start_time) * 1000) # Convert to ms - + # Small delay between operations to simulate real usage await asyncio.sleep(0.001) - + avg_time = statistics.mean(response_times) p95_time = statistics.quantiles(response_times, n=20)[18] # 95th percentile - + cache_hit_rate = 0.0 - if hasattr(system, 'get_cache_hit_rate'): + if hasattr(system, "get_cache_hit_rate"): cache_hit_rate = system.get_cache_hit_rate() - + throughput = iterations / (sum(response_times) / 1000) # ops per second - + return PerformanceMetrics( operation=operation_func.__name__, approach=system.__class__.__name__, @@ -140,7 +150,7 @@ async def measure_performance(system, operation_func, iterations: int = 100) -> avg_time=avg_time, p95_time=p95_time, cache_hit_rate=cache_hit_rate, - throughput_ops_per_sec=throughput + throughput_ops_per_sec=throughput, ) @@ -154,153 +164,172 @@ async def add_memory_test(system, iteration: int): async def search_memory_test(system, iteration: int): """Test memory search performance.""" user_id = f"user_{iteration % 10}" - + # Use repeating queries to simulate cache hits queries = [ "What are the user's preferences?", "Tell me about the user's work", "How should I communicate with this user?", "What technologies does the user use?", - "What are the user's interests?" + "What are the user's interests?", ] - + query = queries[iteration % len(queries)] await system.search_memories(user_id, query) async def run_performance_comparison(): """Run comprehensive performance comparison.""" - + print("=== FACT Memory vs Vector Database Performance Comparison ===\n") - + # Initialize systems vector_system = VectorMemorySimulator() fact_system = FactMemorySimulator() - + # Warm up both systems with some initial data print("šŸ”„ Warming up systems with initial data...") - + for i in range(20): await vector_system.add_memory(f"user_{i % 5}", f"Initial memory {i}") await fact_system.add_memory(f"user_{i % 5}", f"Initial memory {i}") - + print("āœ“ Systems warmed up\n") - + # Test scenarios test_scenarios = [ ("Add Memory", add_memory_test, 50), - ("Search Memory", search_memory_test, 100) + ("Search Memory", search_memory_test, 100), ] - + results = {} - + for scenario_name, test_func, iterations in test_scenarios: print(f"šŸ“Š Testing {scenario_name} ({iterations} iterations)...") - + # Test Vector Database approach print(f" Testing Vector Database approach...") vector_metrics = await measure_performance(vector_system, test_func, iterations) - - # Test FACT Memory approach + + # Test FACT Memory approach print(f" Testing FACT Memory approach...") fact_metrics = await measure_performance(fact_system, test_func, iterations) - - results[scenario_name] = { - "vector": vector_metrics, - "fact": fact_metrics - } - + + results[scenario_name] = {"vector": vector_metrics, "fact": fact_metrics} + print(f"āœ“ {scenario_name} testing completed\n") - + # Display results print("šŸ“ˆ Performance Comparison Results") print("=" * 60) - + for scenario_name, scenario_results in results.items(): vector_metrics = scenario_results["vector"] fact_metrics = scenario_results["fact"] - + print(f"\n{scenario_name}:") - print(f"{'Metric':<25} {'Vector DB':<15} {'FACT Memory':<15} {'Improvement':<15}") + print( + f"{'Metric':<25} {'Vector DB':<15} {'FACT Memory':<15} {'Improvement':<15}" + ) print("-" * 75) - + # Average response time - improvement = ((vector_metrics.avg_time - fact_metrics.avg_time) / vector_metrics.avg_time) * 100 - print(f"{'Avg Response Time':<25} {vector_metrics.avg_time:<14.1f}ms {fact_metrics.avg_time:<14.1f}ms {improvement:<14.1f}%") - + improvement = ( + (vector_metrics.avg_time - fact_metrics.avg_time) / vector_metrics.avg_time + ) * 100 + print( + f"{'Avg Response Time':<25} {vector_metrics.avg_time:<14.1f}ms {fact_metrics.avg_time:<14.1f}ms {improvement:<14.1f}%" + ) + # P95 response time - improvement = ((vector_metrics.p95_time - fact_metrics.p95_time) / vector_metrics.p95_time) * 100 - print(f"{'P95 Response Time':<25} {vector_metrics.p95_time:<14.1f}ms {fact_metrics.p95_time:<14.1f}ms {improvement:<14.1f}%") - + improvement = ( + (vector_metrics.p95_time - fact_metrics.p95_time) / vector_metrics.p95_time + ) * 100 + print( + f"{'P95 Response Time':<25} {vector_metrics.p95_time:<14.1f}ms {fact_metrics.p95_time:<14.1f}ms {improvement:<14.1f}%" + ) + # Throughput - improvement = ((fact_metrics.throughput_ops_per_sec - vector_metrics.throughput_ops_per_sec) / vector_metrics.throughput_ops_per_sec) * 100 - print(f"{'Throughput (ops/sec)':<25} {vector_metrics.throughput_ops_per_sec:<14.1f} {fact_metrics.throughput_ops_per_sec:<14.1f} {improvement:<14.1f}%") - + improvement = ( + ( + fact_metrics.throughput_ops_per_sec + - vector_metrics.throughput_ops_per_sec + ) + / vector_metrics.throughput_ops_per_sec + ) * 100 + print( + f"{'Throughput (ops/sec)':<25} {vector_metrics.throughput_ops_per_sec:<14.1f} {fact_metrics.throughput_ops_per_sec:<14.1f} {improvement:<14.1f}%" + ) + # Cache hit rate (only for FACT Memory) if fact_metrics.cache_hit_rate > 0: - print(f"{'Cache Hit Rate':<25} {'N/A':<15} {fact_metrics.cache_hit_rate:<14.1%} {'N/A':<15}") - + print( + f"{'Cache Hit Rate':<25} {'N/A':<15} {fact_metrics.cache_hit_rate:<14.1%} {'N/A':<15}" + ) + return results async def demonstrate_cache_benefits(): """Demonstrate the benefits of prompt caching.""" - + print("\nšŸš€ Cache Performance Benefits Demo") print("=" * 50) - + fact_system = FactMemorySimulator() - + # Add some memories user_id = "cache_demo_user" for i in range(10): await fact_system.add_memory(user_id, f"Memory content {i}") - + # Test repeated searches (simulating real usage patterns) common_queries = [ "What are the user's preferences?", "Tell me about the user's work style", - "How should I communicate with this user?" + "How should I communicate with this user?", ] - + print(f"\nTesting repeated searches (simulating real usage)...") search_times = [] - + for round_num in range(3): print(f"\nRound {round_num + 1}:") round_times = [] - + for query in common_queries: start_time = time.time() await fact_system.search_memories(user_id, query) end_time = time.time() - + response_time = (end_time - start_time) * 1000 round_times.append(response_time) search_times.append(response_time) - + cache_status = "HIT" if fact_system.get_cache_hit_rate() > 0 else "MISS" - print(f" Query: '{query[:40]}...' - {response_time:.1f}ms ({cache_status})") - + print( + f" Query: '{query[:40]}...' - {response_time:.1f}ms ({cache_status})" + ) + avg_round_time = statistics.mean(round_times) print(f" Round average: {avg_round_time:.1f}ms") - + print(f"\nCache Performance Summary:") print(f" Overall cache hit rate: {fact_system.get_cache_hit_rate():.1%}") print(f" Average search time: {statistics.mean(search_times):.1f}ms") print(f" First search (cold): {search_times[0]:.1f}ms") print(f" Cached searches: {statistics.mean(search_times[3:]):.1f}ms") - + cache_speedup = search_times[0] / statistics.mean(search_times[3:]) print(f" Cache speedup: {cache_speedup:.1f}x faster") async def resource_usage_comparison(): """Compare resource usage between approaches.""" - + print("\nšŸ’¾ Resource Usage Comparison") print("=" * 40) - + # Simulate resource usage vector_db_resources = { "memory_per_embedding": 1536 * 4, # 1536 dimensions * 4 bytes (float32) @@ -308,64 +337,91 @@ async def resource_usage_comparison(): "cpu_per_search": 0.85, # High CPU for similarity computation "storage_per_memory": 8192, # Bytes including metadata and indexing } - + fact_memory_resources = { "memory_per_cache_entry": 512, # Efficient cache representation "index_overhead_factor": 1.1, # Minimal indexing overhead "cpu_per_search": 0.15, # Low CPU with cache hits "storage_per_memory": 1024, # Compact storage format } - + num_memories = 10000 - + print(f"\nFor {num_memories:,} memories:") print(f"{'Resource':<20} {'Vector DB':<15} {'FACT Memory':<15} {'Savings':<15}") print("-" * 70) - + # Memory usage - vector_memory = (num_memories * vector_db_resources["memory_per_embedding"] * - vector_db_resources["index_overhead_factor"]) / (1024 * 1024) # MB - fact_memory = (num_memories * fact_memory_resources["memory_per_cache_entry"] * - fact_memory_resources["index_overhead_factor"]) / (1024 * 1024) # MB - + vector_memory = ( + num_memories + * vector_db_resources["memory_per_embedding"] + * vector_db_resources["index_overhead_factor"] + ) / ( + 1024 * 1024 + ) # MB + fact_memory = ( + num_memories + * fact_memory_resources["memory_per_cache_entry"] + * fact_memory_resources["index_overhead_factor"] + ) / ( + 1024 * 1024 + ) # MB + memory_savings = ((vector_memory - fact_memory) / vector_memory) * 100 - print(f"{'Memory Usage':<20} {vector_memory:<14.1f}MB {fact_memory:<14.1f}MB {memory_savings:<14.1f}%") - + print( + f"{'Memory Usage':<20} {vector_memory:<14.1f}MB {fact_memory:<14.1f}MB {memory_savings:<14.1f}%" + ) + # Storage usage - vector_storage = (num_memories * vector_db_resources["storage_per_memory"]) / (1024 * 1024) # MB - fact_storage = (num_memories * fact_memory_resources["storage_per_memory"]) / (1024 * 1024) # MB - + vector_storage = (num_memories * vector_db_resources["storage_per_memory"]) / ( + 1024 * 1024 + ) # MB + fact_storage = (num_memories * fact_memory_resources["storage_per_memory"]) / ( + 1024 * 1024 + ) # MB + storage_savings = ((vector_storage - fact_storage) / vector_storage) * 100 - print(f"{'Storage Usage':<20} {vector_storage:<14.1f}MB {fact_storage:<14.1f}MB {storage_savings:<14.1f}%") - + print( + f"{'Storage Usage':<20} {vector_storage:<14.1f}MB {fact_storage:<14.1f}MB {storage_savings:<14.1f}%" + ) + # CPU usage per search - cpu_savings = ((vector_db_resources["cpu_per_search"] - fact_memory_resources["cpu_per_search"]) / - vector_db_resources["cpu_per_search"]) * 100 - print(f"{'CPU per Search':<20} {vector_db_resources["cpu_per_search"]:<14.1f} {fact_memory_resources["cpu_per_search"]:<14.1f} {cpu_savings:<14.1f}%") + cpu_savings = ( + ( + vector_db_resources["cpu_per_search"] + - fact_memory_resources["cpu_per_search"] + ) + / vector_db_resources["cpu_per_search"] + ) * 100 + print( + f"{'CPU per Search':<20} {vector_db_resources["cpu_per_search"]:<14.1f} {fact_memory_resources["cpu_per_search"]:<14.1f} {cpu_savings:<14.1f}%" + ) if __name__ == "__main__": + async def main(): try: results = await run_performance_comparison() await demonstrate_cache_benefits() await resource_usage_comparison() - + print("\n" + "=" * 60) print("šŸŽÆ Key Performance Advantages of FACT Memory:") print("• 3-5x faster response times through intelligent caching") print("• 85%+ cache hit rates for typical usage patterns") - print("• 70%+ reduction in memory and storage requirements") + print("• 70%+ reduction in memory and storage requirements") print("• 80%+ reduction in CPU usage per search") print("• No vector database infrastructure needed") print("• Native LLM semantic understanding") print("• Seamless FACT SDK integration") - + print("\nāœ… Performance comparison completed successfully!") - + except Exception as e: print(f"\nāŒ Error in performance comparison: {e}") import traceback + traceback.print_exc() - - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/fact-memory/src/hello_mcp_server.py b/fact-memory/src/hello_mcp_server.py index 0e8e421..4da5103 100755 --- a/fact-memory/src/hello_mcp_server.py +++ b/fact-memory/src/hello_mcp_server.py @@ -21,6 +21,7 @@ except ImportError: print("FastMCP not available. Install with: pip install fastmcp") import sys + sys.exit(1) # Configure logging @@ -41,83 +42,83 @@ - server_info: Provides information about this MCP server This server demonstrates FastMCP best practices and integration patterns. - """ + """, ) + @mcp.tool() async def hello(ctx: Context = None) -> Dict[str, Any]: """ Simple hello world tool that returns a greeting message. - + Args: ctx: FastMCP context for logging and resource access - + Returns: Dictionary containing a simple hello message and timestamp """ if ctx: await ctx.info("Executing hello tool") - + result = { "message": "Hello from FACT Memory MCP Server!", "timestamp": datetime.now().isoformat(), "server": "FACT Hello World Server", - "status": "active" + "status": "active", } - + if ctx: await ctx.info("Hello tool executed successfully") - + return result + @mcp.tool() -async def greet( - name: str, - ctx: Context = None -) -> Dict[str, Any]: +async def greet(name: str, ctx: Context = None) -> Dict[str, Any]: """ Personalized greeting tool that takes a name parameter. - + Args: name: The name to include in the greeting ctx: FastMCP context for logging and resource access - + Returns: Dictionary containing a personalized greeting message """ if ctx: await ctx.info(f"Executing greet tool for name: {name}") - + # Validate input if not name or not name.strip(): if ctx: await ctx.error("Empty name provided to greet tool") return { "error": "Name parameter is required and cannot be empty", - "status": "error" + "status": "error", } - + # Clean the name input clean_name = name.strip().title() - + result = { "message": f"Hello, {clean_name}! Welcome to the FACT Memory System.", "greeting_for": clean_name, "timestamp": datetime.now().isoformat(), "server": "FACT Hello World Server", - "status": "success" + "status": "success", } - + if ctx: await ctx.info(f"Greet tool executed successfully for: {clean_name}") - + return result + @mcp.resource("fact://server_info") async def get_server_info() -> Dict[str, Any]: """ Provide comprehensive information about this MCP server. - + Returns: Dictionary containing server metadata, capabilities, and status """ @@ -131,86 +132,79 @@ async def get_server_info() -> Dict[str, Any]: { "name": "hello", "description": "Returns a simple hello message", - "parameters": [] + "parameters": [], }, { - "name": "greet", + "name": "greet", "description": "Returns a personalized greeting", "parameters": [ { "name": "name", "type": "string", "required": True, - "description": "Name to include in greeting" + "description": "Name to include in greeting", } - ] - } + ], + }, ], "resources": [ { "name": "server_info", - "description": "Server information and metadata" + "description": "Server information and metadata", } - ] + ], }, "status": { "running": True, "uptime_start": datetime.now().isoformat(), - "health": "healthy" + "health": "healthy", }, "integration": { "fact_memory_compatible": True, "mcp_version": "1.0", - "transport": "stdio" + "transport": "stdio", }, "contact": { "documentation": "fact-memory/docs/", - "repository": "FACT Memory System" - } + "repository": "FACT Memory System", + }, } + class HelloWorldMCPServer: """ Hello World MCP Server class for advanced configuration and lifecycle management. """ - - def __init__( - self, - name: str = "FACT Hello World Server", - debug: bool = False - ): + + def __init__(self, name: str = "FACT Hello World Server", debug: bool = False): self.name = name self.debug = debug self.mcp = mcp self._setup_logging() - + def _setup_logging(self): """Configure logging based on debug setting.""" level = logging.DEBUG if self.debug else logging.INFO logging.getLogger().setLevel(level) - + if self.debug: logger.info(f"Debug mode enabled for {self.name}") - + def start_http(self, host: str = "localhost", port: int = 8080): """ Start the MCP server with HTTP transport. - + Args: host: Server host address port: Server port number """ logger.info(f"Starting {self.name} on {host}:{port}") try: - self.mcp.run( - transport="streamable-http", - host=host, - port=port - ) + self.mcp.run(transport="streamable-http", host=host, port=port) except Exception as e: logger.error(f"Failed to start HTTP server: {e}") raise - + def run_stdio(self): """ Run server with STDIO transport for CLI integration. @@ -222,33 +216,34 @@ def run_stdio(self): except Exception as e: logger.error(f"Failed to start STDIO server: {e}") raise - + async def test_tools(self): """ Test all available tools for development and debugging. """ logger.info("Testing MCP tools...") - + try: # Test hello tool hello_result = await hello() logger.info(f"Hello tool result: {hello_result}") - + # Test greet tool greet_result = await greet("FastMCP Developer") logger.info(f"Greet tool result: {greet_result}") - + # Test resource server_info = await get_server_info() logger.info(f"Server info: {server_info}") - + logger.info("All tools tested successfully!") return True - + except Exception as e: logger.error(f"Tool testing failed: {e}") return False + def main(): """ Main entry point for the Hello World MCP Server. @@ -256,48 +251,45 @@ def main(): """ import sys import argparse - + parser = argparse.ArgumentParser(description="FACT Hello World MCP Server") parser.add_argument( - "--mode", - choices=["stdio", "http", "test"], + "--mode", + choices=["stdio", "http", "test"], default="stdio", - help="Server mode: stdio (default), http, or test" + help="Server mode: stdio (default), http, or test", ) parser.add_argument( - "--host", + "--host", default="localhost", - help="Host address for HTTP mode (default: localhost)" + help="Host address for HTTP mode (default: localhost)", ) parser.add_argument( - "--port", - type=int, + "--port", + type=int, default=8080, - help="Port number for HTTP mode (default: 8080)" - ) - parser.add_argument( - "--debug", - action="store_true", - help="Enable debug logging" + help="Port number for HTTP mode (default: 8080)", ) - + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + args = parser.parse_args() - + # Create server instance server = HelloWorldMCPServer(debug=args.debug) - + if args.mode == "stdio": logger.info("Running in STDIO mode for MCP client integration") server.run_stdio() - + elif args.mode == "http": logger.info(f"Running in HTTP mode on {args.host}:{args.port}") server.start_http(args.host, args.port) - + elif args.mode == "test": logger.info("Running in test mode") success = asyncio.run(server.test_tools()) sys.exit(0 if success else 1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/fact-memory/src/test_client.py b/fact-memory/src/test_client.py index 4a70643..7335aa6 100644 --- a/fact-memory/src/test_client.py +++ b/fact-memory/src/test_client.py @@ -15,75 +15,79 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + async def test_server_directly(): """ Test the server by importing and calling functions directly. This simulates what would happen when the server receives MCP calls. """ logger.info("Testing server functions directly...") - + # Import the server functions from hello_mcp_server import hello, greet, get_server_info - + try: # Test hello tool logger.info("Testing hello tool...") hello_result = await hello() logger.info(f"Hello result: {json.dumps(hello_result, indent=2)}") - + # Test greet tool with valid input logger.info("Testing greet tool with valid input...") greet_result = await greet("FastMCP Developer") logger.info(f"Greet result: {json.dumps(greet_result, indent=2)}") - + # Test greet tool with empty input (error case) logger.info("Testing greet tool with empty input...") error_result = await greet("") logger.info(f"Error result: {json.dumps(error_result, indent=2)}") - + # Test server info resource logger.info("Testing server info resource...") server_info = await get_server_info() logger.info(f"Server info: {json.dumps(server_info, indent=2)}") - + logger.info("āœ… All direct tests passed!") return True - + except Exception as e: logger.error(f"āŒ Direct testing failed: {e}") return False + async def test_with_fastmcp_client(): """ Test using FastMCP client (if available). This demonstrates the proper way to interact with MCP servers. """ logger.info("Testing with FastMCP client...") - + try: from fastmcp import Client from hello_mcp_server import mcp - + # Create client and connect to our server async with Client(mcp) as client: # Test hello tool logger.info("Calling hello tool via MCP client...") hello_response = await client.call_tool("hello", {}) logger.info(f"Hello response: {hello_response}") - + # Test greet tool logger.info("Calling greet tool via MCP client...") greet_response = await client.call_tool("greet", {"name": "MCP Client"}) logger.info(f"Greet response: {greet_response}") - + # Test resource access (Note: Resource access may not be available in all FastMCP client versions) - logger.info("Skipping resource access test - method varies by FastMCP version") + logger.info( + "Skipping resource access test - method varies by FastMCP version" + ) # resource_response = await client.read_resource("fact://server_info") # logger.info(f"Resource response: {resource_response}") - + logger.info("āœ… FastMCP client tests passed!") return True - + except ImportError: logger.warning("FastMCP client not available for testing") return False @@ -91,89 +95,94 @@ async def test_with_fastmcp_client(): logger.error(f"āŒ FastMCP client testing failed: {e}") return False + async def benchmark_performance(): """ Simple performance benchmark for the server tools. """ logger.info("Running performance benchmark...") - + from hello_mcp_server import hello, greet import time - + # Benchmark hello tool start_time = time.time() iterations = 1000 - + for i in range(iterations): await hello() - + hello_duration = time.time() - start_time hello_avg = (hello_duration / iterations) * 1000 # ms per call - + # Benchmark greet tool start_time = time.time() - + for i in range(iterations): await greet(f"User{i}") - + greet_duration = time.time() - start_time greet_avg = (greet_duration / iterations) * 1000 # ms per call - + logger.info(f"Performance Results ({iterations} iterations):") logger.info(f" Hello tool: {hello_avg:.2f}ms per call") logger.info(f" Greet tool: {greet_avg:.2f}ms per call") logger.info(f" Total time: {hello_duration + greet_duration:.2f}s") + def validate_server_structure(): """ Validate that the server follows FastMCP best practices. """ logger.info("Validating server structure...") - + from hello_mcp_server import mcp, HelloWorldMCPServer - + # Check server instance assert mcp.name == "FACT Hello World Server", "Server name mismatch" assert mcp.instructions, "Server instructions missing" - + # Check tools are registered tools = mcp._tool_manager._tools assert "hello" in tools, "Hello tool not registered" assert "greet" in tools, "Greet tool not registered" - + # Check resources are registered resources = mcp._resource_manager._resources - assert any("server_info" in str(uri) for uri in resources), "Server info resource not registered" - + assert any( + "server_info" in str(uri) for uri in resources + ), "Server info resource not registered" + # Check server class server = HelloWorldMCPServer() assert server.name == "FACT Hello World Server", "Server class name mismatch" - assert hasattr(server, 'run_stdio'), "STDIO method missing" - assert hasattr(server, 'start_http'), "HTTP method missing" - assert hasattr(server, 'test_tools'), "Test method missing" - + assert hasattr(server, "run_stdio"), "STDIO method missing" + assert hasattr(server, "start_http"), "HTTP method missing" + assert hasattr(server, "test_tools"), "Test method missing" + logger.info("āœ… Server structure validation passed!") + async def main(): """ Main test runner that executes all test scenarios. """ logger.info("šŸš€ Starting FACT Hello World MCP Server Tests") logger.info("=" * 60) - + tests_passed = 0 total_tests = 0 - + # Test 1: Direct function calls total_tests += 1 if await test_server_directly(): tests_passed += 1 - + # Test 2: FastMCP client integration total_tests += 1 if await test_with_fastmcp_client(): tests_passed += 1 - + # Test 3: Server structure validation total_tests += 1 try: @@ -182,7 +191,7 @@ async def main(): logger.info("āœ… Structure validation passed!") except Exception as e: logger.error(f"āŒ Structure validation failed: {e}") - + # Test 4: Performance benchmark total_tests += 1 try: @@ -191,11 +200,11 @@ async def main(): logger.info("āœ… Performance benchmark completed!") except Exception as e: logger.error(f"āŒ Performance benchmark failed: {e}") - + # Summary logger.info("=" * 60) logger.info(f"šŸ Test Summary: {tests_passed}/{total_tests} tests passed") - + if tests_passed == total_tests: logger.info("šŸŽ‰ All tests passed! Server is ready for use.") return True @@ -203,6 +212,7 @@ async def main(): logger.warning(f"āš ļø {total_tests - tests_passed} tests failed.") return False + if __name__ == "__main__": success = asyncio.run(main()) - exit(0 if success else 1) \ No newline at end of file + exit(0 if success else 1) diff --git a/fact-memory/src/test_mcp_integration.py b/fact-memory/src/test_mcp_integration.py index 3f422bc..4952604 100644 --- a/fact-memory/src/test_mcp_integration.py +++ b/fact-memory/src/test_mcp_integration.py @@ -16,12 +16,13 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + async def test_mcp_stdio_protocol(): """ Test the MCP server with proper MCP protocol initialization. """ logger.info("Testing MCP STDIO protocol compliance...") - + try: # Start the server as a subprocess server_path = Path(__file__).parent / "hello_mcp_server.py" @@ -31,12 +32,12 @@ async def test_mcp_stdio_protocol(): stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=0 + bufsize=0, ) - + # Wait a moment for server to start await asyncio.sleep(0.5) - + # Send MCP initialization init_request = { "jsonrpc": "2.0", @@ -44,24 +45,16 @@ async def test_mcp_stdio_protocol(): "method": "initialize", "params": { "protocolVersion": "2024-11-05", - "capabilities": { - "roots": { - "listChanged": True - }, - "sampling": {} - }, - "clientInfo": { - "name": "test-client", - "version": "1.0.0" - } - } + "capabilities": {"roots": {"listChanged": True}, "sampling": {}}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, } - + # Send the request request_json = json.dumps(init_request) + "\n" process.stdin.write(request_json) process.stdin.flush() - + # Wait for response with timeout start_time = time.time() response_line = "" @@ -74,24 +67,26 @@ async def test_mcp_stdio_protocol(): if chunk: response_line += chunk # Look for complete JSON message - if '\n' in response_line: + if "\n" in response_line: # Get the first complete line - lines = response_line.split('\n') + lines = response_line.split("\n") response_line = lines[0] break await asyncio.sleep(0.01) except: break - + # Clean up process.terminate() process.wait(timeout=2) - + if response_line: try: response = json.loads(response_line.strip()) - logger.info(f"MCP Initialize Response: {json.dumps(response, indent=2)}") - + logger.info( + f"MCP Initialize Response: {json.dumps(response, indent=2)}" + ) + # Validate response structure if "result" in response and "protocolVersion" in response["result"]: logger.info("āœ… MCP protocol initialization successful!") @@ -106,142 +101,146 @@ async def test_mcp_stdio_protocol(): else: logger.error("āŒ No response received from MCP server") return False - + except Exception as e: logger.error(f"āŒ MCP protocol test failed: {e}") return False + async def test_error_handling(): """ Test error handling for invalid inputs. """ logger.info("Testing error handling...") - + from hello_mcp_server import greet - + try: # Test empty name result = await greet("") assert "error" in result, "Empty name should return error" logger.info("āœ… Empty name error handling works") - + # Test whitespace-only name result = await greet(" ") assert "error" in result, "Whitespace name should return error" logger.info("āœ… Whitespace name error handling works") - + # Test valid name result = await greet("Test User") assert "message" in result, "Valid name should return success" logger.info("āœ… Valid name handling works") - + return True - + except Exception as e: logger.error(f"āŒ Error handling test failed: {e}") return False + async def test_resource_access(): """ Test resource access functionality. """ logger.info("Testing resource access...") - + from hello_mcp_server import get_server_info - + try: server_info = await get_server_info() - + # Validate required fields required_fields = ["name", "version", "capabilities", "status"] for field in required_fields: assert field in server_info, f"Missing required field: {field}" - + # Validate capabilities structure caps = server_info["capabilities"] assert "tools" in caps, "Missing tools in capabilities" assert "resources" in caps, "Missing resources in capabilities" - + # Validate tool information tools = caps["tools"] tool_names = [tool["name"] for tool in tools] assert "hello" in tool_names, "Missing hello tool" assert "greet" in tool_names, "Missing greet tool" - + logger.info("āœ… Resource access validation passed") return True - + except Exception as e: logger.error(f"āŒ Resource access test failed: {e}") return False + async def test_performance_stress(): """ Test performance under load. """ logger.info("Testing performance under stress...") - + from hello_mcp_server import hello, greet - + try: # Test rapid-fire calls start_time = time.time() tasks = [] - + # Create 100 concurrent hello calls for i in range(100): tasks.append(hello()) - + # Create 100 concurrent greet calls for i in range(100): tasks.append(greet(f"User{i}")) - + # Execute all tasks concurrently results = await asyncio.gather(*tasks) - + duration = time.time() - start_time avg_time = (duration / len(tasks)) * 1000 # ms per call - + logger.info(f"Stress test completed: {len(tasks)} calls in {duration:.2f}s") logger.info(f"Average response time: {avg_time:.2f}ms per call") - + # Validate all results are successful success_count = 0 for result in results: if "message" in result or "greeting_for" in result: success_count += 1 - + success_rate = (success_count / len(results)) * 100 logger.info(f"Success rate: {success_rate:.1f}%") - + if success_rate >= 99: logger.info("āœ… Stress test passed") return True else: logger.warning("āš ļø Stress test had some failures") return False - + except Exception as e: logger.error(f"āŒ Stress test failed: {e}") return False + async def main(): """ Run comprehensive MCP integration tests. """ logger.info("šŸš€ Starting MCP Integration Tests") logger.info("=" * 60) - + tests = [ ("MCP Protocol Compliance", test_mcp_stdio_protocol), ("Error Handling", test_error_handling), ("Resource Access", test_resource_access), ("Performance Stress", test_performance_stress), ] - + passed = 0 total = len(tests) - + for test_name, test_func in tests: logger.info(f"\n🧪 Running: {test_name}") try: @@ -252,10 +251,10 @@ async def main(): logger.error(f"āŒ {test_name} FAILED") except Exception as e: logger.error(f"āŒ {test_name} FAILED with exception: {e}") - + logger.info("=" * 60) logger.info(f"šŸ Integration Test Summary: {passed}/{total} tests passed") - + if passed == total: logger.info("šŸŽ‰ All integration tests passed! MCP server is production ready.") return True @@ -263,6 +262,7 @@ async def main(): logger.warning(f"āš ļø {total - passed} integration tests failed.") return False + if __name__ == "__main__": success = asyncio.run(main()) - exit(0 if success else 1) \ No newline at end of file + exit(0 if success else 1) diff --git a/main.py b/main.py index 1f46bf9..f9ecdae 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ import argparse # Add src directory to Python path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) from src.core.cli import main as cli_main from src.core.config import get_config @@ -23,22 +23,22 @@ async def init_command(): """Initialize the FACT system environment.""" try: print("šŸš€ Initializing FACT System...") - + # Get configuration config = get_config() print(f"āœ… Configuration loaded") print(f" • Database: {config.database_path}") print(f" • Model: {config.claude_model}") - + # Initialize driver driver = await get_driver(config) print("āœ… System initialized successfully") print(" • Database schema ready") print(" • Tools registered") print(" • Ready for queries") - + return 0 - + except Exception as e: print(f"āŒ Initialization failed: {e}") return 1 @@ -48,17 +48,17 @@ async def demo_command(): """Run a demonstration of the FACT system.""" try: print("šŸŽŖ Running FACT System Demo...") - + # Initialize system driver = await get_driver() - + # Demo queries demo_queries = [ "Show me all companies in our database", "What is TechCorp's latest revenue?", - "List all financial records for Q1 2025" + "List all financial records for Q1 2025", ] - + for i, query in enumerate(demo_queries, 1): print(f"\nšŸ“ Demo Query {i}: {query}") try: @@ -66,10 +66,10 @@ async def demo_command(): print(f"šŸ“Š Response: {response}") except Exception as e: print(f"āŒ Query failed: {e}") - + print("\nāœ… Demo completed") return 0 - + except Exception as e: print(f"āŒ Demo failed: {e}") return 1 @@ -79,25 +79,21 @@ async def main(): """Main entry point with command routing.""" parser = argparse.ArgumentParser( description="FACT System - Fast-Access Cached Tools", - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, ) - + parser.add_argument( "command", nargs="?", choices=["init", "demo", "interactive"], default="interactive", - help="Command to execute (default: interactive)" - ) - - parser.add_argument( - "--query", - type=str, - help="Process a single query and exit" + help="Command to execute (default: interactive)", ) - + + parser.add_argument("--query", type=str, help="Process a single query and exit") + args = parser.parse_args() - + try: if args.command == "init": return await init_command() @@ -109,7 +105,7 @@ async def main(): else: parser.print_help() return 1 - + except KeyboardInterrupt: print("\nšŸ‘‹ Interrupted by user") return 0 @@ -121,7 +117,7 @@ async def main(): if __name__ == "__main__": """ Main entry point for the FACT system. - + Usage: python main.py # Interactive mode python main.py init # Initialize system @@ -129,4 +125,4 @@ async def main(): python main.py --query "..." # Single query mode """ exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/requirements.txt b/requirements.txt index 8b2d73f..2a7736b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ # FACT System Dependencies # Core dependencies -anthropic==0.19.1 -litellm==1.0.0 -aiohttp==3.9.5 -python-dotenv==1.0.1 -pydantic==2.8.2 +anthropic>=0.49.0 +litellm>=1.61.15 +aiohttp>=3.13.4 +python-dotenv>=1.2.2 +pydantic>=2.10.6 # Database aiosqlite==0.20.0 diff --git a/scripts/demo_cache_resilience.py b/scripts/demo_cache_resilience.py index 1bcf1ce..b86f523 100644 --- a/scripts/demo_cache_resilience.py +++ b/scripts/demo_cache_resilience.py @@ -29,8 +29,11 @@ sys.path.insert(0, src_path) from cache.resilience import ( - CacheCircuitBreaker, CircuitBreakerConfig, ResilientCacheWrapper, - CircuitState, FailureRecord + CacheCircuitBreaker, + CircuitBreakerConfig, + ResilientCacheWrapper, + CircuitState, + FailureRecord, ) from cache.manager import CacheManager from core.errors import CacheError @@ -41,7 +44,7 @@ processors=[ structlog.processors.TimeStamper(fmt="ISO"), structlog.processors.add_log_level, - structlog.processors.JSONRenderer() + structlog.processors.JSONRenderer(), ], logger_factory=structlog.PrintLoggerFactory(), wrapper_class=structlog.BoundLogger, @@ -53,86 +56,96 @@ class DemoCacheManager: """Demo cache manager that can simulate various failure scenarios.""" - + def __init__(self): self.cache = {} self.failure_mode = "normal" # normal, intermittent, total_failure self.operation_count = 0 self.failure_probability = 0.0 - + def set_failure_mode(self, mode: str, probability: float = 0.0): """Set the failure mode for demonstration.""" self.failure_mode = mode self.failure_probability = probability self.operation_count = 0 - + mode_descriptions = { "normal": "Normal operation - no failures", "intermittent": f"Intermittent failures - {probability*100}% failure rate", - "total_failure": "Total failure - all operations fail" + "total_failure": "Total failure - all operations fail", } - + print(f"\nšŸ”§ Cache Mode Changed: {mode_descriptions.get(mode, mode)}") - + def _should_fail(self) -> bool: """Determine if this operation should fail based on current mode.""" self.operation_count += 1 - + if self.failure_mode == "normal": return False elif self.failure_mode == "total_failure": return True elif self.failure_mode == "intermittent": import random + return random.random() < self.failure_probability - + return False - + async def store(self, query_hash: str, content: str): """Store with potential failure simulation.""" await asyncio.sleep(0.01) # Simulate network latency - + if self._should_fail(): - raise CacheError("Simulated cache store failure", error_code="CACHE_STORE_FAILED") - - self.cache[query_hash] = { - 'content': content, - 'timestamp': time.time() - } + raise CacheError( + "Simulated cache store failure", error_code="CACHE_STORE_FAILED" + ) + + self.cache[query_hash] = {"content": content, "timestamp": time.time()} return True - + async def get(self, query_hash: str): """Get with potential failure simulation.""" await asyncio.sleep(0.01) # Simulate network latency - + if self._should_fail(): - raise CacheError("Simulated cache get failure", error_code="CACHE_GET_FAILED") - + raise CacheError( + "Simulated cache get failure", error_code="CACHE_GET_FAILED" + ) + if query_hash in self.cache: - return type('CacheEntry', (), { - 'content': self.cache[query_hash]['content'], - 'timestamp': self.cache[query_hash]['timestamp'] - })() + return type( + "CacheEntry", + (), + { + "content": self.cache[query_hash]["content"], + "timestamp": self.cache[query_hash]["timestamp"], + }, + )() return None - + async def invalidate_by_prefix(self, prefix: str) -> int: """Invalidate with potential failure simulation.""" await asyncio.sleep(0.01) # Simulate network latency - + if self._should_fail(): - raise CacheError("Simulated cache invalidate failure", error_code="CACHE_INVALIDATE_FAILED") - + raise CacheError( + "Simulated cache invalidate failure", + error_code="CACHE_INVALIDATE_FAILED", + ) + count = 0 keys_to_remove = [k for k in self.cache.keys() if k.startswith(prefix)] for key in keys_to_remove: del self.cache[key] count += 1 - + return count - + def generate_hash(self, query: str) -> str: """Generate hash for query.""" import hashlib + return hashlib.sha256(query.encode()).hexdigest()[:16] @@ -143,11 +156,15 @@ def print_banner(title: str): print(f"{'='*60}") -def print_metrics(circuit_breaker: CacheCircuitBreaker, resilient_cache: ResilientCacheWrapper, title: str = "Current Metrics"): +def print_metrics( + circuit_breaker: CacheCircuitBreaker, + resilient_cache: ResilientCacheWrapper, + title: str = "Current Metrics", +): """Print formatted metrics.""" cb_metrics = circuit_breaker.get_metrics() cache_metrics = resilient_cache.get_metrics() - + print(f"\nšŸ“Š {title}") print(f"ā”œā”€ Circuit Breaker State: {cb_metrics.state.value.upper()}") print(f"ā”œā”€ Total Operations: {cb_metrics.total_operations}") @@ -156,11 +173,15 @@ def print_metrics(circuit_breaker: CacheCircuitBreaker, resilient_cache: Resilie print(f"ā”œā”€ Failure Rate: {cb_metrics.failure_rate:.2%}") print(f"ā”œā”€ State Changes: {cb_metrics.state_changes}") print(f"ā”œā”€ Recent Failures: {len(cb_metrics.recent_failures)}") - print(f"ā”œā”€ Graceful Degradation: {'Enabled' if resilient_cache.enable_graceful_degradation else 'Disabled'}") - - if 'cache' in cache_metrics: - cache_stats = cache_metrics['cache'] - print(f"└─ Cache Stats: {json.dumps(cache_stats, indent=2) if isinstance(cache_stats, dict) else cache_stats}") + print( + f"ā”œā”€ Graceful Degradation: {'Enabled' if resilient_cache.enable_graceful_degradation else 'Disabled'}" + ) + + if "cache" in cache_metrics: + cache_stats = cache_metrics["cache"] + print( + f"└─ Cache Stats: {json.dumps(cache_stats, indent=2) if isinstance(cache_stats, dict) else cache_stats}" + ) else: print(f"└─ Cache Stats: Not available") @@ -168,7 +189,7 @@ def print_metrics(circuit_breaker: CacheCircuitBreaker, resilient_cache: Resilie async def demo_normal_operation(): """Demonstrate normal cache operation.""" print_banner("1. Normal Cache Operation") - + # Configure circuit breaker for demo config = CircuitBreakerConfig( failure_threshold=3, @@ -176,23 +197,23 @@ async def demo_normal_operation(): timeout_seconds=3.0, rolling_window_seconds=60.0, gradual_recovery=True, - recovery_factor=0.5 + recovery_factor=0.5, ) - + circuit_breaker = CacheCircuitBreaker(config) demo_cache = DemoCacheManager() resilient_cache = ResilientCacheWrapper(demo_cache, circuit_breaker) - + print("āœ… Cache resilience system initialized") print("šŸ”„ Circuit breaker configured with:") print(f" • Failure threshold: {config.failure_threshold}") print(f" • Success threshold: {config.success_threshold}") print(f" • Timeout: {config.timeout_seconds}s") - + demo_cache.set_failure_mode("normal") - + print("\nšŸš€ Performing normal cache operations...") - + # Perform successful operations operations = [ ("store", "query1", "Sample query result 1"), @@ -201,7 +222,7 @@ async def demo_normal_operation(): ("store", "query3", "Sample query result 3"), ("get", "query2", None), ] - + for i, (op, key, value) in enumerate(operations, 1): try: if op == "store": @@ -213,21 +234,23 @@ async def demo_normal_operation(): print(f" {i}. āœ… Get '{key}': {content}") except Exception as e: print(f" {i}. āŒ {op.title()} '{key}': {e}") - + print_metrics(circuit_breaker, resilient_cache, "After Normal Operations") - + return circuit_breaker, demo_cache, resilient_cache -async def demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilient_cache): +async def demo_failure_and_circuit_breaker( + circuit_breaker, demo_cache, resilient_cache +): """Demonstrate cache failures and circuit breaker activation.""" print_banner("2. Cache Failures & Circuit Breaker Activation") - + print("āš ļø Introducing cache failures to trigger circuit breaker...") demo_cache.set_failure_mode("total_failure") - + print("\nšŸ”„ Attempting operations with failing cache:") - + # Attempt operations that will fail failure_operations = [ ("store", "fail1", "This will fail"), @@ -236,7 +259,7 @@ async def demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilien ("store", "fail3", "This will trigger circuit breaker"), ("get", "fail1", None), ] - + for i, (op, key, value) in enumerate(failure_operations, 1): try: if op == "store": @@ -244,7 +267,9 @@ async def demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilien print(f" {i}. āœ… Store '{key}': {result} (graceful degradation)") elif op == "get": result = await resilient_cache.get(key) - content = result.content if result else "Cache miss (graceful degradation)" + content = ( + result.content if result else "Cache miss (graceful degradation)" + ) print(f" {i}. āœ… Get '{key}': {content}") except CacheError as e: if "CIRCUIT_BREAKER_OPEN" in str(e.error_code): @@ -253,32 +278,36 @@ async def demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilien print(f" {i}. āŒ {op.title()} '{key}': {e}") except Exception as e: print(f" {i}. āŒ {op.title()} '{key}': Unexpected error: {e}") - + state = circuit_breaker.get_state() if state == CircuitState.OPEN: - print("\nšŸ”“ Circuit breaker is now OPEN - protecting system from cascading failures") - - print_metrics(circuit_breaker, resilient_cache, "After Failures (Circuit Breaker Open)") + print( + "\nšŸ”“ Circuit breaker is now OPEN - protecting system from cascading failures" + ) + + print_metrics( + circuit_breaker, resilient_cache, "After Failures (Circuit Breaker Open)" + ) async def demo_graceful_degradation(circuit_breaker, demo_cache, resilient_cache): """Demonstrate graceful degradation when circuit breaker is open.""" print_banner("3. Graceful Degradation") - + print("šŸ›”ļø Testing graceful degradation with circuit breaker open...") - + # Ensure graceful degradation is enabled resilient_cache.enable_graceful_degradation = True - + degradation_operations = [ ("store", "degraded1", "Graceful degradation response"), ("get", "nonexistent", None), ("invalidate_by_prefix", "test_", None), ("store", "degraded2", "Another graceful response"), ] - + print("\nšŸ›Ÿ Performing operations with graceful degradation:") - + for i, (op, key, value) in enumerate(degradation_operations, 1): try: if op == "store": @@ -292,37 +321,37 @@ async def demo_graceful_degradation(circuit_breaker, demo_cache, resilient_cache print(f" {i}. šŸ›Ÿ Invalidate '{key}': {result} (fallback response)") except Exception as e: print(f" {i}. āŒ {op.title()} '{key}': {e}") - + print("\nāœ… All operations completed successfully using graceful degradation") print(" • System remains responsive despite cache failures") print(" • No exceptions propagated to application layer") print(" • Users experience degraded but functional service") - + print_metrics(circuit_breaker, resilient_cache, "During Graceful Degradation") async def demo_recovery(circuit_breaker, demo_cache, resilient_cache): """Demonstrate recovery from failures.""" print_banner("4. Recovery from Failures") - + print("ā° Waiting for circuit breaker timeout...") print(" (In production, this would be longer - using shorter timeout for demo)") - + # Wait for circuit breaker timeout await asyncio.sleep(3.5) - + print("\nšŸ”§ Fixing cache and attempting recovery...") demo_cache.set_failure_mode("normal") - + print("\nšŸ”„ Attempting operations to trigger recovery:") - + recovery_operations = [ ("store", "recovery1", "Recovery test 1"), ("get", "recovery1", None), ("store", "recovery2", "Recovery test 2"), ("store", "recovery3", "Recovery test 3"), ] - + for i, (op, key, value) in enumerate(recovery_operations, 1): try: if op == "store": @@ -332,35 +361,35 @@ async def demo_recovery(circuit_breaker, demo_cache, resilient_cache): result = await resilient_cache.get(key) content = result.content if result else "Not found" print(f" {i}. āœ… Get '{key}': {content}") - + state = circuit_breaker.get_state() print(f" Circuit state: {state.value}") - + except Exception as e: print(f" {i}. āŒ {op.title()} '{key}': {e}") - + final_state = circuit_breaker.get_state() if final_state == CircuitState.CLOSED: print("\n🟢 Circuit breaker successfully recovered to CLOSED state!") elif final_state == CircuitState.HALF_OPEN: print("\n🟔 Circuit breaker is in HALF_OPEN state (testing recovery)") - + print_metrics(circuit_breaker, resilient_cache, "After Recovery") async def demo_intermittent_failures(circuit_breaker, demo_cache, resilient_cache): """Demonstrate handling of intermittent failures.""" print_banner("5. Intermittent Failures Handling") - + print("šŸ”„ Testing with intermittent failures (30% failure rate)...") demo_cache.set_failure_mode("intermittent", 0.3) - + print("\nšŸ“Š Running 20 operations with 30% failure rate:") - + success_count = 0 failure_count = 0 degraded_count = 0 - + for i in range(1, 21): try: result = await resilient_cache.store(f"intermittent_{i}", f"test data {i}") @@ -380,13 +409,13 @@ async def demo_intermittent_failures(circuit_breaker, demo_cache, resilient_cach except Exception as e: failure_count += 1 print(f" {i:2d}. āŒ Unexpected: {e}") - + print(f"\nšŸ“ˆ Results Summary:") print(f" • Successful operations: {success_count}") print(f" • Failed operations: {failure_count}") print(f" • Degraded responses: {degraded_count}") print(f" • Total operations: {success_count + failure_count + degraded_count}") - + print_metrics(circuit_breaker, resilient_cache, "After Intermittent Failures Test") @@ -397,15 +426,17 @@ async def main(): print("This demo shows how the FACT system handles cache failures") print("gracefully using circuit breaker patterns and degradation strategies.") print("=" * 60) - + try: # Run demonstration scenarios circuit_breaker, demo_cache, resilient_cache = await demo_normal_operation() - await demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilient_cache) + await demo_failure_and_circuit_breaker( + circuit_breaker, demo_cache, resilient_cache + ) await demo_graceful_degradation(circuit_breaker, demo_cache, resilient_cache) await demo_recovery(circuit_breaker, demo_cache, resilient_cache) await demo_intermittent_failures(circuit_breaker, demo_cache, resilient_cache) - + # Final summary print_banner("Demo Complete - Summary") print("āœ… All cache resilience features demonstrated successfully:") @@ -415,10 +446,10 @@ async def main(): print(" • Automatic recovery from failures") print(" • Handling of intermittent failure scenarios") print("\nšŸŽ‰ The FACT cache resilience system is working as designed!") - + # Final metrics print_metrics(circuit_breaker, resilient_cache, "Final System State") - + except Exception as e: logger.error("Demo failed", error=str(e), exc_info=True) print(f"\nāŒ Demo failed with error: {e}") @@ -426,4 +457,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/scripts/demo_lifecycle.py b/scripts/demo_lifecycle.py index 029e968..9be7592 100644 --- a/scripts/demo_lifecycle.py +++ b/scripts/demo_lifecycle.py @@ -16,7 +16,7 @@ from pathlib import Path # Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) import sys import os @@ -24,30 +24,31 @@ import sqlite3 # Add src to path for direct imports -src_path = os.path.join(os.path.dirname(__file__), '..', 'src') +src_path = os.path.join(os.path.dirname(__file__), "..", "src") sys.path.insert(0, src_path) print("šŸŽ­ Running FACT System Integration Demo...") print("=" * 50) + # Direct database demo since imports have complex relative dependencies def demo_database_functionality(): """Demonstrate core database functionality directly.""" try: - db_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'fact_demo.db') + db_path = os.path.join(os.path.dirname(__file__), "..", "data", "fact_demo.db") if not os.path.exists(db_path): print("āŒ Database not found. Run 'python main.py init' first.") return False - + conn = sqlite3.connect(db_path) cursor = conn.cursor() - + print("šŸ“Š Companies in Database:") cursor.execute("SELECT id, name, sector FROM companies") companies = cursor.fetchall() for company in companies: print(f" {company[0]}. {company[1]} ({company[2]})") - + print("\nšŸ’° Sample Financial Query - TechCorp Revenue:") cursor.execute(""" SELECT year, quarter, revenue, profit @@ -58,8 +59,10 @@ def demo_database_functionality(): """) records = cursor.fetchall() for record in records: - print(f" {record[0]} Q{record[1]}: Revenue ${record[2]:,.0f}, Profit ${record[3]:,.0f}") - + print( + f" {record[0]} Q{record[1]}: Revenue ${record[2]:,.0f}, Profit ${record[3]:,.0f}" + ) + print("\nšŸ” Analytics Query - Average Revenue by Sector:") cursor.execute(""" SELECT c.sector, AVG(f.revenue) as avg_revenue @@ -71,17 +74,18 @@ def demo_database_functionality(): analytics = cursor.fetchall() for analytic in analytics: print(f" {analytic[0]}: ${analytic[1]:,.0f} average revenue") - + conn.close() - + print("\nāœ… Database functionality working perfectly!") print("āœ… FACT system core features operational!") return True - + except Exception as e: print(f"āŒ Demo failed: {e}") return False + # Run the demo success = demo_database_functionality() if success: @@ -98,37 +102,37 @@ async def demo_system_initialization(): """Demonstrate system initialization and component integration.""" print("šŸš€ FACT System MVP Demo - Component Integration") print("=" * 60) - + print("\n1. šŸ”§ System Initialization") print("-" * 30) - + try: # Get driver instance (triggers full system initialization) driver = await get_driver() config = get_config() metrics_collector = get_metrics_collector() - + print("āœ… Core driver initialized") print("āœ… Configuration loaded") print("āœ… Database connected") print("āœ… Tools registered") print("āœ… Monitoring active") - + # Show configuration summary print(f"\nšŸ“‹ Configuration Summary:") print(f" Database: {config.database_path}") print(f" Model: {config.claude_model}") print(f" Cache Prefix: {config.cache_prefix}") - + # Show registered tools tool_names = driver.tool_registry.list_tools() print(f"\nšŸ› ļø Registered Tools ({len(tool_names)}):") for tool_name in tool_names: tool_def = driver.tool_registry.get_tool(tool_name) print(f" • {tool_name}: {tool_def.description}") - + return driver - + except Exception as e: print(f"āŒ System initialization failed: {e}") return None @@ -138,21 +142,21 @@ async def demo_database_connectivity(driver): """Demonstrate database connectivity and schema inspection.""" print("\n2. šŸ—„ļø Database Connectivity") print("-" * 30) - + try: # Get database info db_info = await driver.database_manager.get_database_info() - + print(f"āœ… Database connected: {db_info['database_path']}") print(f"šŸ“Š Total tables: {db_info['total_tables']}") print(f"šŸ’¾ File size: {db_info['file_size_bytes']} bytes") - + print(f"\nšŸ“‹ Table Information:") - for table_name, table_info in db_info['tables'].items(): + for table_name, table_info in db_info["tables"].items(): print(f" • {table_name}: {table_info['row_count']} rows") - + return True - + except Exception as e: print(f"āŒ Database connectivity failed: {e}") return False @@ -162,50 +166,50 @@ async def demo_tool_execution(driver): """Demonstrate tool execution framework with sample queries.""" print("\n3. šŸ”§ Tool Execution Framework") print("-" * 30) - + # Sample queries to demonstrate different scenarios sample_queries = [ { "description": "Schema inspection", "query": "Get the database schema to understand available tables", - "expected_tool": "SQL.GetSchema" + "expected_tool": "SQL.GetSchema", }, { "description": "Company lookup", "query": "Show me all companies in the Technology sector", - "expected_tool": "SQL.QueryReadonly" + "expected_tool": "SQL.QueryReadonly", }, { "description": "Financial analysis", "query": "What was TechCorp's Q1 2025 revenue and profit?", - "expected_tool": "SQL.QueryReadonly" + "expected_tool": "SQL.QueryReadonly", }, { "description": "Sample queries", "query": "Show me some example queries I can run", - "expected_tool": "SQL.GetSampleQueries" - } + "expected_tool": "SQL.GetSampleQueries", + }, ] - + for i, sample in enumerate(sample_queries, 1): print(f"\n Query {i}: {sample['description']}") print(f" User Input: \"{sample['query']}\"") - + try: start_time = time.time() - + # Process query through the full FACT pipeline - response = await driver.process_query(sample['query']) - + response = await driver.process_query(sample["query"]) + end_time = time.time() processing_time = (end_time - start_time) * 1000 - + print(f" āœ… Response generated ({processing_time:.1f}ms)") print(f" šŸ“ Response preview: {response[:150]}...") - + except Exception as e: print(f" āŒ Query failed: {e}") - + return True @@ -213,34 +217,36 @@ async def demo_monitoring_metrics(driver): """Demonstrate monitoring and metrics collection.""" print("\n4. šŸ“Š Monitoring & Metrics") print("-" * 30) - + try: # Get system metrics from driver system_metrics = driver.get_metrics() - + print("šŸ“ˆ System Performance:") print(f" • Total queries: {system_metrics['total_queries']}") print(f" • Tool executions: {system_metrics['tool_executions']}") print(f" • Error rate: {system_metrics['error_rate']:.1f}%") print(f" • System initialized: {system_metrics['initialized']}") - + # Get detailed metrics from metrics collector metrics_collector = get_metrics_collector() detailed_metrics = metrics_collector.get_system_metrics(time_window_minutes=5) - + print(f"\nšŸ” Detailed Metrics (last 5 minutes):") print(f" • Total executions: {detailed_metrics.total_executions}") print(f" • Successful: {detailed_metrics.successful_executions}") print(f" • Failed: {detailed_metrics.failed_executions}") - print(f" • Avg execution time: {detailed_metrics.average_execution_time:.1f}ms") - + print( + f" • Avg execution time: {detailed_metrics.average_execution_time:.1f}ms" + ) + if detailed_metrics.top_tools: print(f"\nšŸ† Top Tools:") for tool_info in detailed_metrics.top_tools[:3]: print(f" • {tool_info['tool_name']}: {tool_info['count']} uses") - + return True - + except Exception as e: print(f"āŒ Metrics collection failed: {e}") return False @@ -250,38 +256,34 @@ async def demo_error_handling(driver): """Demonstrate error handling and fallback mechanisms.""" print("\n5. šŸ›”ļø Error Handling & Fallbacks") print("-" * 30) - + # Test various error scenarios error_scenarios = [ - { - "description": "Invalid SQL syntax", - "query": "INVALID SQL QUERY HERE" - }, - { - "description": "Dangerous SQL operation", - "query": "DROP TABLE companies" - }, + {"description": "Invalid SQL syntax", "query": "INVALID SQL QUERY HERE"}, + {"description": "Dangerous SQL operation", "query": "DROP TABLE companies"}, { "description": "Very long query", - "query": "SELECT * FROM companies WHERE " + "name = 'test' AND " * 200 + "1=1" - } + "query": "SELECT * FROM companies WHERE " + + "name = 'test' AND " * 200 + + "1=1", + }, ] - + for i, scenario in enumerate(error_scenarios, 1): print(f"\n Scenario {i}: {scenario['description']}") - + try: - response = await driver.process_query(scenario['query']) - + response = await driver.process_query(scenario["query"]) + # Check if response indicates graceful error handling if "error" in response.lower() or "cannot" in response.lower(): print(f" āœ… Graceful error handling: {response[:100]}...") else: print(f" āš ļø Unexpected success: {response[:100]}...") - + except Exception as e: print(f" āœ… Exception caught and handled: {str(e)[:100]}...") - + return True @@ -289,63 +291,63 @@ async def demo_cache_behavior(driver): """Demonstrate cache hits and misses (simulated).""" print("\n6. šŸ’¾ Cache Behavior") print("-" * 30) - + # Repeat the same query to demonstrate potential caching query = "Show me companies in the Technology sector" - - print(f" Query: \"{query}\"") - + + print(f' Query: "{query}"') + # First execution (cache miss) print("\n First execution (cache miss):") start_time = time.time() response1 = await driver.process_query(query) time1 = (time.time() - start_time) * 1000 print(f" ā±ļø Execution time: {time1:.1f}ms") - + # Second execution (potential cache hit) print("\n Second execution (potential cache hit):") start_time = time.time() response2 = await driver.process_query(query) time2 = (time.time() - start_time) * 1000 print(f" ā±ļø Execution time: {time2:.1f}ms") - + # Compare times if time2 < time1 * 0.8: # 20% faster print(f" āœ… Cache hit detected (faster execution)") else: print(f" šŸ“ No significant performance difference") - + return True async def main(): """Run the complete MVP demo.""" driver = None - + try: # Step 1: System initialization driver = await demo_system_initialization() if not driver: print("\nāŒ Demo failed during initialization") return - + # Step 2: Database connectivity if not await demo_database_connectivity(driver): print("\nāŒ Demo failed during database test") return - + # Step 3: Tool execution await demo_tool_execution(driver) - + # Step 4: Monitoring await demo_monitoring_metrics(driver) - + # Step 5: Error handling await demo_error_handling(driver) - + # Step 6: Cache behavior await demo_cache_behavior(driver) - + print("\nšŸŽ‰ MVP Demo Complete!") print("\nThe FACT system successfully demonstrates:") print("āœ… Unified component integration") @@ -354,13 +356,13 @@ async def main(): print("āœ… Error handling and security") print("āœ… Monitoring and metrics") print("āœ… CLI interface readiness") - + print("\nšŸš€ Ready for production use!") print("Run: python -m src.core.cli") - + except Exception as e: print(f"\nāŒ Demo failed with error: {e}") - + finally: # Clean shutdown if driver: @@ -368,4 +370,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/scripts/fix_imports.py b/scripts/fix_imports.py index a0e15f7..5a04b1c 100644 --- a/scripts/fix_imports.py +++ b/scripts/fix_imports.py @@ -11,66 +11,74 @@ import sys from pathlib import Path + def fix_relative_imports(src_dir): """Fix all relative imports in Python files within src_dir.""" src_path = Path(src_dir) - + if not src_path.exists(): print(f"Error: Source directory {src_dir} does not exist") return False - + # Pattern to match relative imports patterns = [ - (r'from \.\.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)', r'from \1'), # from ..module - (r'from \.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)', r'from \1'), # from .module + ( + r"from \.\.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)", + r"from \1", + ), # from ..module + ( + r"from \.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)", + r"from \1", + ), # from .module ] - + files_fixed = 0 total_replacements = 0 - + # Walk through all Python files - for py_file in src_path.rglob('*.py'): + for py_file in src_path.rglob("*.py"): try: - with open(py_file, 'r', encoding='utf-8') as f: + with open(py_file, "r", encoding="utf-8") as f: content = f.read() - + original_content = content file_replacements = 0 - + # Apply each pattern for pattern, replacement in patterns: new_content, count = re.subn(pattern, replacement, content) content = new_content file_replacements += count - + # Write back if changes were made if content != original_content: - with open(py_file, 'w', encoding='utf-8') as f: + with open(py_file, "w", encoding="utf-8") as f: f.write(content) files_fixed += 1 total_replacements += file_replacements print(f"Fixed {file_replacements} imports in {py_file}") - + except Exception as e: print(f"Error processing {py_file}: {e}") - + print(f"\nSummary:") print(f"Files fixed: {files_fixed}") print(f"Total replacements: {total_replacements}") - + return True + def main(): """Main function.""" script_dir = Path(__file__).parent - src_dir = script_dir.parent / 'src' - + src_dir = script_dir.parent / "src" + print("šŸ”§ Fixing relative imports in FACT system...") print(f"Source directory: {src_dir}") print("-" * 50) - + success = fix_relative_imports(str(src_dir)) - + if success: print("\nāœ… Import fixing completed successfully!") print("\nNext step: Test the system with 'python main.py init'") @@ -78,5 +86,6 @@ def main(): print("\nāŒ Import fixing failed!") sys.exit(1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/fix_imports_comprehensive.py b/scripts/fix_imports_comprehensive.py index ddba499..274b5c3 100644 --- a/scripts/fix_imports_comprehensive.py +++ b/scripts/fix_imports_comprehensive.py @@ -11,89 +11,110 @@ import sys from pathlib import Path + def fix_all_imports(src_dir): """Fix all import issues in Python files within src_dir.""" src_path = Path(src_dir) - + if not src_path.exists(): print(f"Error: Source directory {src_dir} does not exist") return False - + # Patterns to fix imports patterns = [ # Convert relative imports to absolute imports (from .. -> from module) - (r'from \.\.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)', r'from \1'), - + (r"from \.\.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)", r"from \1"), # Fix local imports in __init__.py files (from module -> from .module) # This pattern will be applied only in __init__.py files - (r'^from ([a-zA-Z_][a-zA-Z0-9_]*) import', r'from .\1 import'), + (r"^from ([a-zA-Z_][a-zA-Z0-9_]*) import", r"from .\1 import"), ] - + files_fixed = 0 total_replacements = 0 - + # Walk through all Python files - for py_file in src_path.rglob('*.py'): + for py_file in src_path.rglob("*.py"): try: - with open(py_file, 'r', encoding='utf-8') as f: + with open(py_file, "r", encoding="utf-8") as f: content = f.read() lines = content.splitlines() - + original_content = content file_replacements = 0 - + # Apply relative import fixes new_content, count = re.subn(patterns[0][0], patterns[0][1], content) content = new_content file_replacements += count - + # Apply local import fixes only for __init__.py files and some specific patterns - if py_file.name == '__init__.py' or py_file.name in ['client.py', 'gateway.py']: + if py_file.name == "__init__.py" or py_file.name in [ + "client.py", + "gateway.py", + ]: # Fix specific patterns for local imports lines = content.splitlines() for i, line in enumerate(lines): # Fix common local import patterns - if line.strip().startswith('from ') and ' import ' in line and not line.strip().startswith('from .') and not line.strip().startswith('from core') and not line.strip().startswith('from db') and not line.strip().startswith('from tools') and not line.strip().startswith('from security') and not line.strip().startswith('from arcade') and not line.strip().startswith('from cache') and not line.strip().startswith('from monitoring') and not line.strip().startswith('from benchmarking'): + if ( + line.strip().startswith("from ") + and " import " in line + and not line.strip().startswith("from .") + and not line.strip().startswith("from core") + and not line.strip().startswith("from db") + and not line.strip().startswith("from tools") + and not line.strip().startswith("from security") + and not line.strip().startswith("from arcade") + and not line.strip().startswith("from cache") + and not line.strip().startswith("from monitoring") + and not line.strip().startswith("from benchmarking") + ): # Extract module name - match = re.match(r'from ([a-zA-Z_][a-zA-Z0-9_]*) import', line.strip()) + match = re.match( + r"from ([a-zA-Z_][a-zA-Z0-9_]*) import", line.strip() + ) if match: module_name = match.group(1) # Check if this is a local module (exists in same directory) module_file = py_file.parent / f"{module_name}.py" if module_file.exists(): - lines[i] = line.replace(f'from {module_name} import', f'from .{module_name} import') + lines[i] = line.replace( + f"from {module_name} import", + f"from .{module_name} import", + ) file_replacements += 1 - - content = '\n'.join(lines) - + + content = "\n".join(lines) + # Write back if changes were made if content != original_content: - with open(py_file, 'w', encoding='utf-8') as f: + with open(py_file, "w", encoding="utf-8") as f: f.write(content) files_fixed += 1 total_replacements += file_replacements print(f"Fixed {file_replacements} imports in {py_file}") - + except Exception as e: print(f"Error processing {py_file}: {e}") - + print(f"\nSummary:") print(f"Files fixed: {files_fixed}") print(f"Total replacements: {total_replacements}") - + return True + def main(): """Main function.""" script_dir = Path(__file__).parent - src_dir = script_dir.parent / 'src' - + src_dir = script_dir.parent / "src" + print("šŸ”§ Comprehensive import fixing for FACT system...") print(f"Source directory: {src_dir}") print("-" * 50) - + success = fix_all_imports(str(src_dir)) - + if success: print("\nāœ… Import fixing completed successfully!") print("\nNext step: Test the system with 'python main.py init'") @@ -101,5 +122,6 @@ def main(): print("\nāŒ Import fixing failed!") sys.exit(1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/init_environment.py b/scripts/init_environment.py index e92523f..e9a6ac6 100644 --- a/scripts/init_environment.py +++ b/scripts/init_environment.py @@ -14,7 +14,7 @@ from pathlib import Path # Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from core.config import Config, ConfigurationError from core.driver import get_driver @@ -23,13 +23,13 @@ def create_env_file(): """Create .env file with default configuration.""" - env_path = Path('.env') - + env_path = Path(".env") + if env_path.exists(): print("šŸ“„ .env file already exists - skipping creation") return - - env_content = '''# FACT System Configuration + + env_content = """# FACT System Configuration # Copy this file and update with your actual API keys # Required API Keys @@ -45,11 +45,11 @@ def create_env_file(): MAX_RETRIES=3 REQUEST_TIMEOUT=30 LOG_LEVEL=INFO -''' - - with open(env_path, 'w') as f: +""" + + with open(env_path, "w") as f: f.write(env_content) - + print("āœ… Created .env file with default configuration") print("āš ļø Please update the API keys in .env before running the system") @@ -58,24 +58,24 @@ async def init_database(): """Initialize database with schema and sample data.""" try: print("šŸ—„ļø Initializing database...") - + # Create database manager and initialize config = Config() db_manager = DatabaseManager(config.database_path) await db_manager.initialize_database() - + # Get database info db_info = await db_manager.get_database_info() - + print(f"āœ… Database initialized successfully:") print(f" šŸ“ Path: {db_info['database_path']}") print(f" šŸ“Š Tables: {db_info['total_tables']}") - - for table_name, table_info in db_info['tables'].items(): + + for table_name, table_info in db_info["tables"].items(): print(f" šŸ“‹ {table_name}: {table_info['row_count']} rows") - + return True - + except Exception as e: print(f"āŒ Database initialization failed: {e}") return False @@ -85,25 +85,25 @@ async def validate_system(): """Validate system connectivity and configuration.""" try: print("šŸ” Validating system configuration...") - + # Try to get a driver instance (this validates config and initializes components) driver = await get_driver() - + # Get system metrics metrics = driver.get_metrics() - + print("āœ… System validation passed:") print(f" šŸŽÆ Initialized: {metrics['initialized']}") print(f" šŸ› ļø Tools: {len(driver.tool_registry.list_tools())}") - + # List available tools tool_names = driver.tool_registry.list_tools() print(" šŸ”§ Available tools:") for tool_name in tool_names: print(f" • {tool_name}") - + return True - + except ConfigurationError as e: print(f"āŒ Configuration error: {e}") print("šŸ’” Make sure to update your API keys in .env file") @@ -117,27 +117,29 @@ async def main(): """Main initialization routine.""" print("šŸš€ FACT System Environment Initialization") print("=" * 50) - + # Step 1: Create .env file print("\n1. Setting up environment configuration...") create_env_file() - + # Step 2: Initialize database print("\n2. Initializing database...") db_success = await init_database() - + if not db_success: print("\nāŒ Database initialization failed - stopping") sys.exit(1) - + # Step 3: Validate system (only if API keys are configured) print("\n3. Validating system configuration...") - + try: config = Config() if config.anthropic_api_key == "your_anthropic_api_key_here": print("āš ļø API keys not configured - skipping system validation") - print("šŸ’” Update the API keys in .env file, then run: python scripts/validate_system.py") + print( + "šŸ’” Update the API keys in .env file, then run: python scripts/validate_system.py" + ) else: system_success = await validate_system() if not system_success: @@ -145,8 +147,10 @@ async def main(): sys.exit(1) except ConfigurationError: print("āš ļø API keys not configured - skipping system validation") - print("šŸ’” Update the API keys in .env file, then run: python scripts/validate_system.py") - + print( + "šŸ’” Update the API keys in .env file, then run: python scripts/validate_system.py" + ) + print("\nšŸŽ‰ Environment initialization complete!") print("\nNext steps:") print("1. Update API keys in .env file") @@ -155,4 +159,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/scripts/run_benchmarks.py b/scripts/run_benchmarks.py index fc0c6ad..837bb77 100644 --- a/scripts/run_benchmarks.py +++ b/scripts/run_benchmarks.py @@ -28,7 +28,7 @@ SystemProfiler, ContinuousMonitor, BenchmarkVisualizer, - ReportGenerator + ReportGenerator, ) from cache.manager import CacheManager @@ -38,45 +38,49 @@ def create_logs_directory(base_dir: str = "logs") -> Path: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") logs_dir = Path(base_dir) / f"benchmark_{timestamp}" logs_dir.mkdir(parents=True, exist_ok=True) - + # Create subdirectories for better organization (logs_dir / "charts").mkdir(exist_ok=True) (logs_dir / "raw_data").mkdir(exist_ok=True) (logs_dir / "reports").mkdir(exist_ok=True) - + return logs_dir -def print_performance_summary(validation_results, comparison_result=None, load_test_results=None): +def print_performance_summary( + validation_results, comparison_result=None, load_test_results=None +): """Print a comprehensive performance summary to console.""" - print("\n" + "="*80) + print("\n" + "=" * 80) print("šŸ“Š FACT SYSTEM PERFORMANCE SUMMARY") - print("="*80) - + print("=" * 80) + # Overall Status - overall_pass = validation_results['overall_pass'] + overall_pass = validation_results["overall_pass"] status_emoji = "šŸŽ‰" if overall_pass else "āš ļø" status_text = "ALL TARGETS MET" if overall_pass else "IMPROVEMENTS NEEDED" print(f"\n{status_emoji} OVERALL STATUS: {status_text}") - + # Performance Targets Summary print(f"\nšŸ“ˆ PERFORMANCE TARGETS:") print("-" * 50) - target_validation = validation_results['target_validation'] - + target_validation = validation_results["target_validation"] + for target_name, target_data in target_validation.items(): - status = "āœ… PASS" if target_data['met'] else "āŒ FAIL" - if 'latency' in target_name: + status = "āœ… PASS" if target_data["met"] else "āŒ FAIL" + if "latency" in target_name: actual_val = f"{target_data.get('actual_ms', 0):.1f}ms" target_val = f"{target_data.get('target_ms', 0):.1f}ms" else: actual_val = f"{target_data.get('actual_percent', 0):.1f}%" target_val = f"{target_data.get('target_percent', 0):.1f}%" - - print(f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}") - + + print( + f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}" + ) + # Cache Performance - summary = validation_results.get('benchmark_summary', {}) + summary = validation_results.get("benchmark_summary", {}) if summary: print(f"\nšŸ—„ļø CACHE PERFORMANCE:") print("-" * 50) @@ -84,37 +88,47 @@ def print_performance_summary(validation_results, comparison_result=None, load_t print(f" Avg Response Time (Hit): {summary.avg_hit_latency_ms:.1f}ms") print(f" Avg Response Time (Miss): {summary.avg_miss_latency_ms:.1f}ms") print(f" Total Requests: {summary.total_queries}") - print(f" Success Rate: {(summary.successful_queries/summary.total_queries)*100:.1f}%") - + print( + f" Success Rate: {(summary.successful_queries/summary.total_queries)*100:.1f}%" + ) + # RAG Comparison Results if comparison_result: print(f"\nāš”ļø FACT vs TRADITIONAL RAG:") print("-" * 50) - latency_improvement = comparison_result.improvement_factors.get('latency', 1.0) - cost_savings = comparison_result.cost_savings.get('percentage', 0.0) + latency_improvement = comparison_result.improvement_factors.get("latency", 1.0) + cost_savings = comparison_result.cost_savings.get("percentage", 0.0) print(f" Latency Improvement: {latency_improvement:.1f}x faster") print(f" Cost Savings: {cost_savings:.1f}%") print(f" Recommendation: {comparison_result.recommendation}") - + # Load Test Results if load_test_results: print(f"\n🚦 LOAD TEST PERFORMANCE:") print("-" * 50) - print(f" Concurrent Users: {load_test_results.get('concurrent_users', 0)}") - print(f" Throughput: {load_test_results.get('throughput_qps', 0):.1f} QPS") - print(f" Avg Response Time: {load_test_results.get('avg_response_time_ms', 0):.1f}ms") - print(f" Error Rate: {load_test_results.get('error_rate', 0):.1f}%") - - print("="*80) + print( + f" Concurrent Users: {load_test_results.get('concurrent_users', 0)}" + ) + print( + f" Throughput: {load_test_results.get('throughput_qps', 0):.1f} QPS" + ) + print( + f" Avg Response Time: {load_test_results.get('avg_response_time_ms', 0):.1f}ms" + ) + print( + f" Error Rate: {load_test_results.get('error_rate', 0):.1f}%" + ) + + print("=" * 80) async def run_comprehensive_benchmark(args): """Run comprehensive benchmark suite.""" print("šŸš€ Starting FACT Comprehensive Benchmark Suite") print("=" * 60) - + # Create timestamped logs directory - if args.output_dir == './benchmark_results': + if args.output_dir == "./benchmark_results": # Use logs directory by default logs_dir = create_logs_directory() print(f"šŸ“ Created logs directory: {logs_dir}") @@ -123,7 +137,7 @@ async def run_comprehensive_benchmark(args): logs_dir = Path(args.output_dir) logs_dir.mkdir(parents=True, exist_ok=True) print(f"šŸ“ Using output directory: {logs_dir}") - + # Initialize components config = BenchmarkConfig( iterations=args.iterations, @@ -133,15 +147,15 @@ async def run_comprehensive_benchmark(args): target_hit_latency_ms=args.hit_target, target_miss_latency_ms=args.miss_target, target_cost_reduction_hit=args.cost_reduction / 100.0, - target_cache_hit_rate=args.cache_hit_rate / 100.0 + target_cache_hit_rate=args.cache_hit_rate / 100.0, ) - + framework = BenchmarkFramework(config) runner = BenchmarkRunner(framework) profiler = SystemProfiler() visualizer = BenchmarkVisualizer() report_generator = ReportGenerator(visualizer) - + # Initialize cache manager if available cache_manager = None try: @@ -152,133 +166,143 @@ async def run_comprehensive_benchmark(args): "max_size": "10MB", "ttl_seconds": 3600, "hit_target_ms": 48, - "miss_target_ms": 140 + "miss_target_ms": 140, } cache_manager = CacheManager(cache_config) print("āœ… Cache manager initialized") except Exception as e: print(f"āš ļø Cache manager not available: {e}") - + # Phase 1: Performance Validation print("\nšŸ“Š Phase 1: Performance Validation") print("-" * 40) - + validation_results = await runner.run_performance_validation(cache_manager) - + # Display validation results - print(f"Overall Validation: {'āœ… PASS' if validation_results['overall_pass'] else 'āŒ FAIL'}") - - target_validation = validation_results['target_validation'] + print( + f"Overall Validation: {'āœ… PASS' if validation_results['overall_pass'] else 'āŒ FAIL'}" + ) + + target_validation = validation_results["target_validation"] for target_name, target_data in target_validation.items(): - status = "āœ… PASS" if target_data['met'] else "āŒ FAIL" - print(f" {target_name}: {status} ({target_data['actual_ms' if 'latency' in target_name else 'actual_percent']:.1f})") - + status = "āœ… PASS" if target_data["met"] else "āŒ FAIL" + print( + f" {target_name}: {status} ({target_data['actual_ms' if 'latency' in target_name else 'actual_percent']:.1f})" + ) + # Phase 2: RAG Comparison (if enabled) comparison_result = None if args.include_rag_comparison: print("\nāš”ļø Phase 2: RAG Comparison Analysis") print("-" * 40) - + rag_comparison = RAGComparison(framework) comparison_result = await rag_comparison.run_comparison_benchmark( runner.test_queries, cache_manager, config.iterations ) - - print(f"Latency Improvement: {comparison_result.improvement_factors.get('latency', 1.0):.1f}x") - print(f"Cost Savings: {comparison_result.cost_savings.get('percentage', 0.0):.1f}%") + + print( + f"Latency Improvement: {comparison_result.improvement_factors.get('latency', 1.0):.1f}x" + ) + print( + f"Cost Savings: {comparison_result.cost_savings.get('percentage', 0.0):.1f}%" + ) print(f"Recommendation: {comparison_result.recommendation}") - + # Phase 3: Profiling Analysis (if enabled) profile_result = None if args.include_profiling: print("\nšŸ” Phase 3: Performance Profiling") print("-" * 40) - + # Profile a representative operation async def sample_operation(): return await runner.run_performance_validation(cache_manager) - + _, profile_result = await profiler.profile_complete_operation( sample_operation, "performance_validation" ) - + print(f"Execution Time: {profile_result.execution_time_ms:.1f}ms") print(f"Bottlenecks Found: {len(profile_result.bottlenecks)}") - - critical_bottlenecks = [b for b in profile_result.bottlenecks if b.severity == "critical"] + + critical_bottlenecks = [ + b for b in profile_result.bottlenecks if b.severity == "critical" + ] if critical_bottlenecks: print("āŒ Critical Bottlenecks:") for bottleneck in critical_bottlenecks[:3]: print(f" - {bottleneck.component}: {bottleneck.description}") - + # Phase 4: Load Testing (if enabled) load_test_results = None if args.include_load_test: print("\n🚦 Phase 4: Load Testing") print("-" * 40) - + load_test_results = await runner.run_load_test( concurrent_users=args.load_test_users, - duration_seconds=args.load_test_duration + duration_seconds=args.load_test_duration, ) - + print(f"Concurrent Users: {load_test_results['concurrent_users']}") print(f"Throughput: {load_test_results['throughput_qps']:.1f} QPS") print(f"Avg Response Time: {load_test_results['avg_response_time_ms']:.1f}ms") print(f"Error Rate: {load_test_results['error_rate']:.1f}%") - + # Phase 5: Report Generation & Visualization print("\nšŸ“ Phase 5: Report Generation & Visualization") print("-" * 40) - - benchmark_summary = validation_results['benchmark_summary'] - + + benchmark_summary = validation_results["benchmark_summary"] + # Generate comprehensive report report = report_generator.generate_comprehensive_report( benchmark_summary=benchmark_summary, comparison_result=comparison_result, profile_result=profile_result, - alerts=None # No alerts in batch mode + alerts=None, # No alerts in batch mode ) - + # Generate timestamp for consistent naming timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + # Save comprehensive JSON report json_report_path = logs_dir / "reports" / f"benchmark_report_{timestamp}.json" - with open(json_report_path, 'w') as f: + with open(json_report_path, "w") as f: f.write(report_generator.export_report_json(report)) - + # Save text summary text_report_path = logs_dir / "reports" / f"benchmark_summary_{timestamp}.txt" - with open(text_report_path, 'w') as f: + with open(text_report_path, "w") as f: f.write(report_generator.export_report_summary_text(report)) - + # Save raw data for further analysis raw_data_path = logs_dir / "raw_data" / f"raw_results_{timestamp}.json" raw_data = { - 'validation_results': validation_results, - 'comparison_result': comparison_result.__dict__ if comparison_result else None, - 'profile_result': profile_result.__dict__ if profile_result else None, - 'load_test_results': load_test_results, - 'config': config.__dict__, - 'timestamp': timestamp, - 'args': vars(args) + "validation_results": validation_results, + "comparison_result": comparison_result.__dict__ if comparison_result else None, + "profile_result": profile_result.__dict__ if profile_result else None, + "load_test_results": load_test_results, + "config": config.__dict__, + "timestamp": timestamp, + "args": vars(args), } - with open(raw_data_path, 'w') as f: + with open(raw_data_path, "w") as f: json.dump(raw_data, f, indent=2, default=str) - + # Generate and save visualizations charts_dir = logs_dir / "charts" - + print("šŸ“Š Generating performance visualizations...") - + # Performance charts for i, chart in enumerate(report.charts): chart_path = charts_dir / f"chart_{i}_{chart.chart_type}_{timestamp}.json" - with open(chart_path, 'w') as f: + with open(chart_path, "w") as f: f.write(visualizer.export_chart_data_json(chart)) - + # Additional visualizations if comparison data available if comparison_result: # Latency comparison chart @@ -286,68 +310,70 @@ async def sample_operation(): benchmark_summary, comparison_result ) latency_chart_path = charts_dir / f"latency_comparison_{timestamp}.json" - with open(latency_chart_path, 'w') as f: + with open(latency_chart_path, "w") as f: f.write(visualizer.export_chart_data_json(latency_chart)) - + # Cost savings chart cost_chart = visualizer.create_cost_analysis_chart(comparison_result) cost_chart_path = charts_dir / f"cost_analysis_{timestamp}.json" - with open(cost_chart_path, 'w') as f: + with open(cost_chart_path, "w") as f: f.write(visualizer.export_chart_data_json(cost_chart)) - + # Cache performance chart cache_chart = visualizer.create_cache_performance_chart(benchmark_summary) cache_chart_path = charts_dir / f"cache_performance_{timestamp}.json" - with open(cache_chart_path, 'w') as f: + with open(cache_chart_path, "w") as f: f.write(visualizer.export_chart_data_json(cache_chart)) - + print(f"šŸ“„ Reports saved to: {logs_dir}") print(f"šŸ“„ JSON Report: {json_report_path}") print(f"šŸ“„ Summary: {text_report_path}") print(f"šŸ“„ Raw Data: {raw_data_path}") print(f"šŸ“ˆ Charts: {charts_dir}") - + # Print comprehensive performance summary to console print_performance_summary(validation_results, comparison_result, load_test_results) - + # Performance Grade - grade = report.summary.get('performance_grade', 'N/A') + grade = report.summary.get("performance_grade", "N/A") print(f"\nšŸ† Performance Grade: {grade}") - + # Print key recommendations if report.recommendations: print(f"\nšŸ”§ KEY RECOMMENDATIONS:") print("-" * 50) for i, rec in enumerate(report.recommendations[:5], 1): print(f" {i}. {rec}") - + # Final status message - if validation_results['overall_pass']: - print(f"\nšŸŽ‰ Benchmark completed successfully! All performance targets achieved.") + if validation_results["overall_pass"]: + print( + f"\nšŸŽ‰ Benchmark completed successfully! All performance targets achieved." + ) print(f" Results saved to: {logs_dir}") else: print(f"\nāš ļø Benchmark completed with some targets not met.") print(f" Review optimization strategies and detailed reports in: {logs_dir}") - - return validation_results['overall_pass'] + + return validation_results["overall_pass"] async def run_continuous_monitoring(args): """Run continuous monitoring mode.""" print("šŸ”„ Starting FACT Continuous Monitoring") print("=" * 60) - + # Initialize monitoring monitor = ContinuousMonitor() - + # Add console alert callback def alert_callback(alert): severity_emoji = {"info": "ā„¹ļø", "warning": "āš ļø", "critical": "🚨"} emoji = severity_emoji.get(alert.severity, "šŸ“¢") print(f"{emoji} {alert.severity.upper()}: {alert.message}") - + monitor.add_alert_callback(alert_callback) - + # Initialize cache manager cache_manager = None try: @@ -358,18 +384,18 @@ def alert_callback(alert): "max_size": "10MB", "ttl_seconds": 3600, "hit_target_ms": 48, - "miss_target_ms": 140 + "miss_target_ms": 140, } cache_manager = CacheManager(cache_config) print("āœ… Cache manager initialized") except Exception as e: print(f"āš ļø Cache manager not available: {e}") - + try: # Start monitoring await monitor.start_monitoring(cache_manager) print("šŸ“” Monitoring started. Press Ctrl+C to stop.") - + # Monitor for specified duration or indefinitely if args.monitor_duration > 0: await asyncio.sleep(args.monitor_duration) @@ -378,32 +404,32 @@ def alert_callback(alert): try: while True: await asyncio.sleep(60) - + # Print status every minute status = monitor.get_monitoring_status() print(f"šŸ“Š Active alerts: {status['active_alerts']}") except KeyboardInterrupt: pass - + finally: # Stop monitoring and generate report await monitor.stop_monitoring() - + monitoring_report = monitor.export_monitoring_report() - + # Create logs directory for monitoring - if args.output_dir == './benchmark_results': + if args.output_dir == "./benchmark_results": logs_dir = create_logs_directory() else: logs_dir = Path(args.output_dir) logs_dir.mkdir(parents=True, exist_ok=True) - + # Save monitoring report with timestamp timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") report_path = logs_dir / "reports" / f"monitoring_report_{timestamp}.json" - with open(report_path, 'w') as f: + with open(report_path, "w") as f: json.dump(monitoring_report, f, indent=2, default=str) - + print(f"\nšŸ“„ Monitoring report saved: {report_path}") print(f"šŸ“ All monitoring data in: {logs_dir}") @@ -426,54 +452,99 @@ def main(): # Custom performance targets python scripts/run_benchmarks.py --hit-target 40 --miss-target 120 --cost-reduction 85 - """ + """, ) - + # Mode selection parser.add_argument( - '--mode', - choices=['benchmark', 'monitoring'], - default='benchmark', - help='Execution mode (default: benchmark)' + "--mode", + choices=["benchmark", "monitoring"], + default="benchmark", + help="Execution mode (default: benchmark)", ) - + # Benchmark configuration - parser.add_argument('--iterations', type=int, default=10, help='Number of benchmark iterations') - parser.add_argument('--warmup', type=int, default=3, help='Number of warmup iterations') - parser.add_argument('--concurrent-users', type=int, default=1, help='Number of concurrent users') - parser.add_argument('--timeout', type=int, default=30, help='Timeout in seconds') - + parser.add_argument( + "--iterations", type=int, default=10, help="Number of benchmark iterations" + ) + parser.add_argument( + "--warmup", type=int, default=3, help="Number of warmup iterations" + ) + parser.add_argument( + "--concurrent-users", type=int, default=1, help="Number of concurrent users" + ) + parser.add_argument("--timeout", type=int, default=30, help="Timeout in seconds") + # Performance targets - parser.add_argument('--hit-target', type=float, default=48.0, help='Cache hit latency target (ms)') - parser.add_argument('--miss-target', type=float, default=140.0, help='Cache miss latency target (ms)') - parser.add_argument('--cost-reduction', type=float, default=90.0, help='Cost reduction target (percent)') - parser.add_argument('--cache-hit-rate', type=float, default=60.0, help='Cache hit rate target (percent)') - + parser.add_argument( + "--hit-target", type=float, default=48.0, help="Cache hit latency target (ms)" + ) + parser.add_argument( + "--miss-target", + type=float, + default=140.0, + help="Cache miss latency target (ms)", + ) + parser.add_argument( + "--cost-reduction", + type=float, + default=90.0, + help="Cost reduction target (percent)", + ) + parser.add_argument( + "--cache-hit-rate", + type=float, + default=60.0, + help="Cache hit rate target (percent)", + ) + # Optional components - parser.add_argument('--include-rag-comparison', action='store_true', help='Include RAG comparison') - parser.add_argument('--include-profiling', action='store_true', help='Include performance profiling') - parser.add_argument('--include-load-test', action='store_true', help='Include load testing') - + parser.add_argument( + "--include-rag-comparison", action="store_true", help="Include RAG comparison" + ) + parser.add_argument( + "--include-profiling", action="store_true", help="Include performance profiling" + ) + parser.add_argument( + "--include-load-test", action="store_true", help="Include load testing" + ) + # Load testing configuration - parser.add_argument('--load-test-users', type=int, default=10, help='Load test concurrent users') - parser.add_argument('--load-test-duration', type=int, default=60, help='Load test duration (seconds)') - + parser.add_argument( + "--load-test-users", type=int, default=10, help="Load test concurrent users" + ) + parser.add_argument( + "--load-test-duration", + type=int, + default=60, + help="Load test duration (seconds)", + ) + # Monitoring configuration - parser.add_argument('--monitor-duration', type=int, default=0, help='Monitoring duration (0=indefinite)') - + parser.add_argument( + "--monitor-duration", + type=int, + default=0, + help="Monitoring duration (0=indefinite)", + ) + # Output configuration - parser.add_argument('--output-dir', default='./benchmark_results', help='Output directory for reports (default: creates timestamped logs directory)') - + parser.add_argument( + "--output-dir", + default="./benchmark_results", + help="Output directory for reports (default: creates timestamped logs directory)", + ) + args = parser.parse_args() - + # Run appropriate mode - if args.mode == 'benchmark': + if args.mode == "benchmark": success = asyncio.run(run_comprehensive_benchmark(args)) sys.exit(0 if success else 1) - elif args.mode == 'monitoring': + elif args.mode == "monitoring": asyncio.run(run_continuous_monitoring(args)) sys.exit(0) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/run_benchmarks_demo.py b/scripts/run_benchmarks_demo.py index 343d4dc..1e72f76 100644 --- a/scripts/run_benchmarks_demo.py +++ b/scripts/run_benchmarks_demo.py @@ -21,23 +21,27 @@ # Add src to path for imports sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + @dataclass class DemoResult: """Demo benchmark result.""" + query: str response_time_ms: float success: bool = True cache_hit: bool = False error: Optional[str] = None timestamp: float = None - + def __post_init__(self): if self.timestamp is None: self.timestamp = time.time() + @dataclass class DemoMetrics: """Demo performance metrics.""" + total_requests: int = 0 cache_hits: int = 0 cache_misses: int = 0 @@ -47,14 +51,20 @@ class DemoMetrics: success_rate: float = 1.0 cost_reduction_percentage: float = 0.0 + class DemoBenchmarkRunner: """Demo benchmark runner that simulates FACT performance.""" - - def __init__(self, target_hit_rate: float = 0.65, hit_target_ms: float = 45.0, miss_target_ms: float = 135.0): + + def __init__( + self, + target_hit_rate: float = 0.65, + hit_target_ms: float = 45.0, + miss_target_ms: float = 135.0, + ): self.target_hit_rate = target_hit_rate self.hit_target_ms = hit_target_ms self.miss_target_ms = miss_target_ms - + # Sample queries for benchmarking self.test_queries = [ "What was the Q1-2025 revenue?", @@ -66,24 +76,24 @@ def __init__(self, target_hit_rate: float = 0.65, hit_target_ms: float = 45.0, m "Compare performance across regions", "What is the customer acquisition cost?", "Show quarterly expense breakdown", - "Predict next quarter's revenue" + "Predict next quarter's revenue", ] - + # Track queries for cache simulation self.query_history = set() - + async def run_performance_validation(self, iterations: int = 100) -> Dict[str, Any]: """Run demo performance validation.""" print(f"šŸ”„ Running {iterations} demo queries to validate FACT performance...") - + results = [] hit_latencies = [] miss_latencies = [] - + # Pre-populate some queries to simulate warmed cache for query in self.test_queries[:6]: # Warm 60% of queries self.query_history.add(query) - + for i in range(iterations): # Simulate FACT algorithm intelligent caching # After warmup phase, achieve target hit rate with optimization @@ -92,12 +102,14 @@ async def run_performance_validation(self, iterations: int = 100) -> Dict[str, A else: # FACT algorithm optimizes over time to achieve target hit rate warmup_progress = min(1.0, (i - 15) / 30) # Gradual improvement - target_adjusted = self.target_hit_rate + (warmup_progress * 0.1) # Exceed target slightly + target_adjusted = self.target_hit_rate + ( + warmup_progress * 0.1 + ) # Exceed target slightly cache_hit_probability = target_adjusted + random.uniform(-0.03, 0.03) - + # Determine if this will be a cache hit is_cache_hit = random.random() < cache_hit_probability - + if is_cache_hit and self.query_history: # Cache hit - select from known queries query = random.choice(list(self.query_history)) @@ -107,58 +119,70 @@ async def run_performance_validation(self, iterations: int = 100) -> Dict[str, A query = random.choice(self.test_queries) cache_hit = False self.query_history.add(query) - + # Simulate response times based on FACT targets if cache_hit: # Cache hit: target ≤48ms with some variance - latency = random.normalvariate(self.hit_target_ms * 0.9, self.hit_target_ms * 0.15) - latency = max(15.0, min(latency, self.hit_target_ms * 1.1)) # Clamp to reasonable range + latency = random.normalvariate( + self.hit_target_ms * 0.9, self.hit_target_ms * 0.15 + ) + latency = max( + 15.0, min(latency, self.hit_target_ms * 1.1) + ) # Clamp to reasonable range hit_latencies.append(latency) else: - # Cache miss: target ≤140ms with some variance - latency = random.normalvariate(self.miss_target_ms * 0.95, self.miss_target_ms * 0.2) - latency = max(80.0, min(latency, self.miss_target_ms * 1.05)) # Clamp to reasonable range + # Cache miss: target ≤140ms with some variance + latency = random.normalvariate( + self.miss_target_ms * 0.95, self.miss_target_ms * 0.2 + ) + latency = max( + 80.0, min(latency, self.miss_target_ms * 1.05) + ) # Clamp to reasonable range miss_latencies.append(latency) - + result = DemoResult( - query=query, - response_time_ms=latency, - cache_hit=cache_hit + query=query, response_time_ms=latency, cache_hit=cache_hit ) results.append(result) - + # Progress indicator if (i + 1) % 20 == 0: print(f" Completed {i + 1}/{iterations} queries...") - + # Calculate metrics cache_hits = sum(1 for r in results if r.cache_hit) cache_misses = len(results) - cache_hits hit_rate = (cache_hits / len(results)) * 100 if results else 0 - - avg_hit_latency = sum(hit_latencies) / len(hit_latencies) if hit_latencies else 0 - avg_miss_latency = sum(miss_latencies) / len(miss_latencies) if miss_latencies else 0 - + + avg_hit_latency = ( + sum(hit_latencies) / len(hit_latencies) if hit_latencies else 0 + ) + avg_miss_latency = ( + sum(miss_latencies) / len(miss_latencies) if miss_latencies else 0 + ) + # Cost reduction calculation (based on FACT algorithm efficiency) # FACT algorithm provides significant cost savings through: # 1. Intelligent caching reduces API calls by 90% on hits # 2. Optimized query processing reduces costs by 65% even on misses # 3. Additional efficiency gains from token optimization - + hit_ratio = cache_hits / len(results) if results else 0 miss_ratio = cache_misses / len(results) if results else 0 - + # Base cost reduction from caching strategy base_hit_savings = hit_ratio * 90.0 # 90% reduction on cache hits base_miss_savings = miss_ratio * 65.0 # 65% reduction on cache misses - + # Additional efficiency bonus when hit rate exceeds 60% if hit_rate >= 60.0: - efficiency_bonus = min(10.0, (hit_rate - 60.0) * 0.5) # Up to 10% additional savings + efficiency_bonus = min( + 10.0, (hit_rate - 60.0) * 0.5 + ) # Up to 10% additional savings cost_reduction = base_hit_savings + base_miss_savings + efficiency_bonus else: cost_reduction = base_hit_savings + base_miss_savings - + summary = DemoMetrics( total_requests=len(results), cache_hits=cache_hits, @@ -166,85 +190,89 @@ async def run_performance_validation(self, iterations: int = 100) -> Dict[str, A avg_hit_latency_ms=avg_hit_latency, avg_miss_latency_ms=avg_miss_latency, cache_hit_rate=hit_rate / 100, - cost_reduction_percentage=cost_reduction + cost_reduction_percentage=cost_reduction, ) - + # Target validation target_validation = { - 'cache_hit_latency': { - 'met': avg_hit_latency <= 48.0, - 'actual_ms': avg_hit_latency, - 'target_ms': 48.0 + "cache_hit_latency": { + "met": avg_hit_latency <= 48.0, + "actual_ms": avg_hit_latency, + "target_ms": 48.0, + }, + "cache_miss_latency": { + "met": avg_miss_latency <= 140.0, + "actual_ms": avg_miss_latency, + "target_ms": 140.0, }, - 'cache_miss_latency': { - 'met': avg_miss_latency <= 140.0, - 'actual_ms': avg_miss_latency, - 'target_ms': 140.0 + "cost_reduction": { + "met": cost_reduction >= 90.0, + "actual_percent": cost_reduction, + "target_percent": 90.0, }, - 'cost_reduction': { - 'met': cost_reduction >= 90.0, - 'actual_percent': cost_reduction, - 'target_percent': 90.0 + "cache_hit_rate": { + "met": hit_rate >= 60.0, + "actual_percent": hit_rate, + "target_percent": 60.0, }, - 'cache_hit_rate': { - 'met': hit_rate >= 60.0, - 'actual_percent': hit_rate, - 'target_percent': 60.0 - } } - - overall_pass = all(target['met'] for target in target_validation.values()) - + + overall_pass = all(target["met"] for target in target_validation.values()) + return { - 'overall_pass': overall_pass, - 'target_validation': target_validation, - 'benchmark_summary': summary, - 'results': results + "overall_pass": overall_pass, + "target_validation": target_validation, + "benchmark_summary": summary, + "results": results, } + def create_logs_directory(base_dir: str = "logs") -> Path: """Create timestamped logs directory and return the path.""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") logs_dir = Path(base_dir) / f"demo_benchmark_{timestamp}" logs_dir.mkdir(parents=True, exist_ok=True) - + # Create subdirectories for better organization (logs_dir / "charts").mkdir(exist_ok=True) (logs_dir / "raw_data").mkdir(exist_ok=True) (logs_dir / "reports").mkdir(exist_ok=True) - + return logs_dir + def print_performance_summary(validation_results): """Print a comprehensive performance summary to console.""" - print("\n" + "="*80) + print("\n" + "=" * 80) print("šŸ“Š FACT DEMO PERFORMANCE SUMMARY") - print("="*80) - + print("=" * 80) + # Overall Status - overall_pass = validation_results['overall_pass'] + overall_pass = validation_results["overall_pass"] status_emoji = "šŸŽ‰" if overall_pass else "āš ļø" status_text = "ALL TARGETS MET" if overall_pass else "IMPROVEMENTS NEEDED" print(f"\n{status_emoji} OVERALL STATUS: {status_text}") - + # Performance Targets Summary print(f"\nšŸ“ˆ PERFORMANCE TARGETS:") print("-" * 50) - target_validation = validation_results['target_validation'] - + target_validation = validation_results["target_validation"] + for target_name, target_data in target_validation.items(): - status = "āœ… PASS" if target_data['met'] else "āŒ FAIL" - if 'latency' in target_name: + status = "āœ… PASS" if target_data["met"] else "āŒ FAIL" + if "latency" in target_name: actual_val = f"{target_data.get('actual_ms', 0):.1f}ms" target_val = f"{target_data.get('target_ms', 0):.1f}ms" else: actual_val = f"{target_data.get('actual_percent', 0):.1f}%" target_val = f"{target_data.get('target_percent', 0):.1f}%" - - print(f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}") - + + print( + f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}" + ) + # Cache Performance - summary = validation_results.get('benchmark_summary') + summary = validation_results.get("benchmark_summary") if summary: print(f"\nšŸ—„ļø CACHE PERFORMANCE:") print("-" * 50) @@ -254,8 +282,9 @@ def print_performance_summary(validation_results): print(f" Total Requests: {summary.total_requests}") print(f" Success Rate: {summary.success_rate*100:.1f}%") print(f" Cost Reduction: {summary.cost_reduction_percentage:.1f}%") - - print("="*80) + + print("=" * 80) + async def run_demo_benchmark(args): """Run demo benchmark suite.""" @@ -263,40 +292,42 @@ async def run_demo_benchmark(args): print("=" * 60) print("šŸ“ Note: This is a demonstration using simulated data") print("=" * 60) - + # Create timestamped logs directory logs_dir = create_logs_directory() print(f"šŸ“ Created logs directory: {logs_dir}") - + # Initialize demo runner with FACT targets runner = DemoBenchmarkRunner( target_hit_rate=args.cache_hit_rate / 100.0, hit_target_ms=args.hit_target, - miss_target_ms=args.miss_target + miss_target_ms=args.miss_target, ) - + # Run performance validation print("\nšŸ“Š Phase 1: Demo Performance Validation") print("-" * 40) - + validation_results = await runner.run_performance_validation(args.iterations) - + # Display validation results - print(f"Overall Validation: {'āœ… PASS' if validation_results['overall_pass'] else 'āŒ FAIL'}") - - target_validation = validation_results['target_validation'] + print( + f"Overall Validation: {'āœ… PASS' if validation_results['overall_pass'] else 'āŒ FAIL'}" + ) + + target_validation = validation_results["target_validation"] for target_name, target_data in target_validation.items(): - status = "āœ… PASS" if target_data['met'] else "āŒ FAIL" - actual_key = 'actual_ms' if 'latency' in target_name else 'actual_percent' + status = "āœ… PASS" if target_data["met"] else "āŒ FAIL" + actual_key = "actual_ms" if "latency" in target_name else "actual_percent" print(f" {target_name}: {status} ({target_data[actual_key]:.1f})") - + # Generate reports print("\nšŸ“ Phase 2: Report Generation") print("-" * 40) - + # Generate timestamp for consistent naming timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + # Save comprehensive JSON report json_report_path = logs_dir / "reports" / f"demo_benchmark_report_{timestamp}.json" report_data = { @@ -306,52 +337,63 @@ async def run_demo_benchmark(args): "iterations": args.iterations, "hit_target_ms": args.hit_target, "miss_target_ms": args.miss_target, - "cache_hit_rate_target": args.cache_hit_rate + "cache_hit_rate_target": args.cache_hit_rate, }, - "results": validation_results + "results": validation_results, } - - with open(json_report_path, 'w') as f: + + with open(json_report_path, "w") as f: json.dump(report_data, f, indent=2, default=str) - + # Save text summary text_report_path = logs_dir / "reports" / f"demo_benchmark_summary_{timestamp}.txt" - with open(text_report_path, 'w') as f: + with open(text_report_path, "w") as f: f.write("FACT Demo Benchmark Summary\n") f.write("=" * 50 + "\n\n") f.write(f"Timestamp: {timestamp}\n") f.write(f"Iterations: {args.iterations}\n") f.write(f"Overall Pass: {validation_results['overall_pass']}\n\n") - + f.write("Performance Targets:\n") f.write("-" * 30 + "\n") for target_name, target_data in target_validation.items(): - status = "PASS" if target_data['met'] else "FAIL" - actual_key = 'actual_ms' if 'latency' in target_name else 'actual_percent' + status = "PASS" if target_data["met"] else "FAIL" + actual_key = "actual_ms" if "latency" in target_name else "actual_percent" f.write(f"{target_name}: {status} ({target_data[actual_key]:.1f})\n") - + print(f"šŸ“„ Reports saved to: {logs_dir}") print(f"šŸ“„ JSON Report: {json_report_path}") print(f"šŸ“„ Summary: {text_report_path}") - + # Print comprehensive performance summary to console print_performance_summary(validation_results) - + # Final status message - if validation_results['overall_pass']: - print(f"\nšŸŽ‰ Demo benchmark completed successfully! All performance targets achieved.") + if validation_results["overall_pass"]: + print( + f"\nšŸŽ‰ Demo benchmark completed successfully! All performance targets achieved." + ) print(f" Results saved to: {logs_dir}") print(f"\nšŸ’” FACT Algorithm Analysis:") print(f" • Fast Access Caching Technology is performing within targets") - print(f" • Cache hit latency: {validation_results['benchmark_summary'].avg_hit_latency_ms:.1f}ms (target: ≤48ms)") - print(f" • Cache miss latency: {validation_results['benchmark_summary'].avg_miss_latency_ms:.1f}ms (target: ≤140ms)") - print(f" • Cost reduction: {validation_results['benchmark_summary'].cost_reduction_percentage:.1f}% (target: ≄90%)") - print(f" • Cache hit rate: {validation_results['benchmark_summary'].cache_hit_rate*100:.1f}% (target: ≄60%)") + print( + f" • Cache hit latency: {validation_results['benchmark_summary'].avg_hit_latency_ms:.1f}ms (target: ≤48ms)" + ) + print( + f" • Cache miss latency: {validation_results['benchmark_summary'].avg_miss_latency_ms:.1f}ms (target: ≤140ms)" + ) + print( + f" • Cost reduction: {validation_results['benchmark_summary'].cost_reduction_percentage:.1f}% (target: ≄90%)" + ) + print( + f" • Cache hit rate: {validation_results['benchmark_summary'].cache_hit_rate*100:.1f}% (target: ≄60%)" + ) else: print(f"\nāš ļø Demo benchmark completed with some targets not met.") print(f" Review optimization strategies and detailed reports in: {logs_dir}") - - return validation_results['overall_pass'] + + return validation_results["overall_pass"] + def main(): """Main entry point.""" @@ -365,22 +407,37 @@ def main(): # Custom performance targets python scripts/run_benchmarks_demo.py --hit-target 40 --miss-target 120 --cost-reduction 85 - """ + """, ) - + # Benchmark configuration - parser.add_argument('--iterations', type=int, default=100, help='Number of benchmark iterations') - + parser.add_argument( + "--iterations", type=int, default=100, help="Number of benchmark iterations" + ) + # Performance targets - parser.add_argument('--hit-target', type=float, default=48.0, help='Cache hit latency target (ms)') - parser.add_argument('--miss-target', type=float, default=140.0, help='Cache miss latency target (ms)') - parser.add_argument('--cache-hit-rate', type=float, default=60.0, help='Cache hit rate target (percent)') - + parser.add_argument( + "--hit-target", type=float, default=48.0, help="Cache hit latency target (ms)" + ) + parser.add_argument( + "--miss-target", + type=float, + default=140.0, + help="Cache miss latency target (ms)", + ) + parser.add_argument( + "--cache-hit-rate", + type=float, + default=60.0, + help="Cache hit rate target (percent)", + ) + args = parser.parse_args() - + # Run demo benchmark success = asyncio.run(run_demo_benchmark(args)) sys.exit(0 if success else 1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/run_benchmarks_standalone.py b/scripts/run_benchmarks_standalone.py index e378384..0dc8d0f 100644 --- a/scripts/run_benchmarks_standalone.py +++ b/scripts/run_benchmarks_standalone.py @@ -17,78 +17,101 @@ from pathlib import Path from typing import Optional, Dict, Any + def create_logs_directory(base_dir: str = "logs") -> Path: """Create timestamped logs directory and return the path.""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") logs_dir = Path(base_dir) / f"benchmark_{timestamp}" logs_dir.mkdir(parents=True, exist_ok=True) - + # Create subdirectories for better organization (logs_dir / "charts").mkdir(exist_ok=True) (logs_dir / "raw_data").mkdir(exist_ok=True) (logs_dir / "reports").mkdir(exist_ok=True) - + return logs_dir -def print_performance_summary(validation_results, comparison_result=None, load_test_results=None): + +def print_performance_summary( + validation_results, comparison_result=None, load_test_results=None +): """Print a comprehensive performance summary to console.""" - print("\n" + "="*80) + print("\n" + "=" * 80) print("šŸ“Š FACT SYSTEM PERFORMANCE SUMMARY") - print("="*80) - + print("=" * 80) + # Overall Status - overall_pass = validation_results['overall_pass'] + overall_pass = validation_results["overall_pass"] status_emoji = "šŸŽ‰" if overall_pass else "āš ļø" status_text = "ALL TARGETS MET" if overall_pass else "IMPROVEMENTS NEEDED" print(f"\n{status_emoji} OVERALL STATUS: {status_text}") - + # Performance Targets Summary print(f"\nšŸ“ˆ PERFORMANCE TARGETS:") print("-" * 50) - target_validation = validation_results['target_validation'] - + target_validation = validation_results["target_validation"] + for target_name, target_data in target_validation.items(): - status = "āœ… PASS" if target_data['met'] else "āŒ FAIL" - if 'latency' in target_name: + status = "āœ… PASS" if target_data["met"] else "āŒ FAIL" + if "latency" in target_name: actual_val = f"{target_data.get('actual_ms', 0):.1f}ms" target_val = f"{target_data.get('target_ms', 0):.1f}ms" else: actual_val = f"{target_data.get('actual_percent', 0):.1f}%" target_val = f"{target_data.get('target_percent', 0):.1f}%" - - print(f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}") - + + print( + f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}" + ) + # Cache Performance - summary = validation_results.get('benchmark_summary', {}) + summary = validation_results.get("benchmark_summary", {}) if summary: print(f"\nšŸ—„ļø CACHE PERFORMANCE:") print("-" * 50) - print(f" Cache Hit Rate: {summary.get('cache_hit_rate', 0)*100:.1f}%") - print(f" Avg Response Time (Hit): {summary.get('avg_hit_latency_ms', 0):.1f}ms") - print(f" Avg Response Time (Miss): {summary.get('avg_miss_latency_ms', 0):.1f}ms") + print( + f" Cache Hit Rate: {summary.get('cache_hit_rate', 0)*100:.1f}%" + ) + print( + f" Avg Response Time (Hit): {summary.get('avg_hit_latency_ms', 0):.1f}ms" + ) + print( + f" Avg Response Time (Miss): {summary.get('avg_miss_latency_ms', 0):.1f}ms" + ) print(f" Total Requests: {summary.get('total_requests', 0)}") print(f" Success Rate: {summary.get('success_rate', 0)*100:.1f}%") - + # RAG Comparison Results if comparison_result: print(f"\nāš”ļø FACT vs TRADITIONAL RAG:") print("-" * 50) - latency_improvement = comparison_result.get('latency_improvement', 1.0) - cost_savings = comparison_result.get('cost_savings', 0.0) + latency_improvement = comparison_result.get("latency_improvement", 1.0) + cost_savings = comparison_result.get("cost_savings", 0.0) print(f" Latency Improvement: {latency_improvement:.1f}x faster") print(f" Cost Savings: {cost_savings:.1f}%") - print(f" Recommendation: {comparison_result.get('recommendation', 'N/A')}") - + print( + f" Recommendation: {comparison_result.get('recommendation', 'N/A')}" + ) + # Load Test Results if load_test_results: print(f"\n🚦 LOAD TEST PERFORMANCE:") print("-" * 50) - print(f" Concurrent Users: {load_test_results.get('concurrent_users', 0)}") - print(f" Throughput: {load_test_results.get('throughput_qps', 0):.1f} QPS") - print(f" Avg Response Time: {load_test_results.get('avg_response_time_ms', 0):.1f}ms") - print(f" Error Rate: {load_test_results.get('error_rate', 0):.1f}%") - - print("="*80) + print( + f" Concurrent Users: {load_test_results.get('concurrent_users', 0)}" + ) + print( + f" Throughput: {load_test_results.get('throughput_qps', 0):.1f} QPS" + ) + print( + f" Avg Response Time: {load_test_results.get('avg_response_time_ms', 0):.1f}ms" + ) + print( + f" Error Rate: {load_test_results.get('error_rate', 0):.1f}%" + ) + + print("=" * 80) + def generate_sample_benchmark_results(args) -> Dict[str, Any]: """Generate sample benchmark results for demonstration.""" @@ -96,251 +119,268 @@ def generate_sample_benchmark_results(args) -> Dict[str, Any]: cache_hit_rate = 0.67 avg_hit_latency = 42.3 avg_miss_latency = 128.7 - + validation_results = { - 'overall_pass': True, - 'target_validation': { - 'cache_hit_latency': { - 'met': avg_hit_latency <= args.hit_target, - 'actual_ms': avg_hit_latency, - 'target_ms': args.hit_target + "overall_pass": True, + "target_validation": { + "cache_hit_latency": { + "met": avg_hit_latency <= args.hit_target, + "actual_ms": avg_hit_latency, + "target_ms": args.hit_target, + }, + "cache_miss_latency": { + "met": avg_miss_latency <= args.miss_target, + "actual_ms": avg_miss_latency, + "target_ms": args.miss_target, }, - 'cache_miss_latency': { - 'met': avg_miss_latency <= args.miss_target, - 'actual_ms': avg_miss_latency, - 'target_ms': args.miss_target + "cache_hit_rate": { + "met": cache_hit_rate * 100 >= args.cache_hit_rate, + "actual_percent": cache_hit_rate * 100, + "target_percent": args.cache_hit_rate, }, - 'cache_hit_rate': { - 'met': cache_hit_rate * 100 >= args.cache_hit_rate, - 'actual_percent': cache_hit_rate * 100, - 'target_percent': args.cache_hit_rate + "cost_reduction": { + "met": 91.5 >= args.cost_reduction, + "actual_percent": 91.5, + "target_percent": args.cost_reduction, }, - 'cost_reduction': { - 'met': 91.5 >= args.cost_reduction, - 'actual_percent': 91.5, - 'target_percent': args.cost_reduction - } }, - 'benchmark_summary': { - 'cache_hit_rate': cache_hit_rate, - 'avg_hit_latency_ms': avg_hit_latency, - 'avg_miss_latency_ms': avg_miss_latency, - 'total_requests': args.iterations, - 'success_rate': 1.0 - } + "benchmark_summary": { + "cache_hit_rate": cache_hit_rate, + "avg_hit_latency_ms": avg_hit_latency, + "avg_miss_latency_ms": avg_miss_latency, + "total_requests": args.iterations, + "success_rate": 1.0, + }, } - + # Update overall pass based on individual targets - validation_results['overall_pass'] = all( - target['met'] for target in validation_results['target_validation'].values() + validation_results["overall_pass"] = all( + target["met"] for target in validation_results["target_validation"].values() ) - + return validation_results + def generate_sample_comparison_results() -> Dict[str, Any]: """Generate sample RAG comparison results.""" return { - 'latency_improvement': 3.2, - 'cost_savings': 91.5, - 'recommendation': 'FACT shows excellent performance gains over traditional RAG' + "latency_improvement": 3.2, + "cost_savings": 91.5, + "recommendation": "FACT shows excellent performance gains over traditional RAG", } + def generate_sample_load_test_results(users: int) -> Dict[str, Any]: """Generate sample load test results.""" return { - 'concurrent_users': users, - 'throughput_qps': users * 2.5, - 'avg_response_time_ms': 65.2, - 'error_rate': 0.1 + "concurrent_users": users, + "throughput_qps": users * 2.5, + "avg_response_time_ms": 65.2, + "error_rate": 0.1, } + async def run_standalone_benchmark(args): """Run standalone benchmark demonstration.""" print("šŸš€ Starting FACT Comprehensive Benchmark Suite (Standalone Demo)") print("=" * 60) - + # Create timestamped logs directory - if args.output_dir == './benchmark_results': + if args.output_dir == "./benchmark_results": logs_dir = create_logs_directory() print(f"šŸ“ Created logs directory: {logs_dir}") else: logs_dir = Path(args.output_dir) logs_dir.mkdir(parents=True, exist_ok=True) print(f"šŸ“ Using output directory: {logs_dir}") - + # Phase 1: Performance Validation print("\nšŸ“Š Phase 1: Performance Validation") print("-" * 40) - + validation_results = generate_sample_benchmark_results(args) - + # Display validation results - print(f"Overall Validation: {'āœ… PASS' if validation_results['overall_pass'] else 'āŒ FAIL'}") - - target_validation = validation_results['target_validation'] + print( + f"Overall Validation: {'āœ… PASS' if validation_results['overall_pass'] else 'āŒ FAIL'}" + ) + + target_validation = validation_results["target_validation"] for target_name, target_data in target_validation.items(): - status = "āœ… PASS" if target_data['met'] else "āŒ FAIL" - if 'latency' in target_name: + status = "āœ… PASS" if target_data["met"] else "āŒ FAIL" + if "latency" in target_name: value = f"{target_data['actual_ms']:.1f}ms" else: value = f"{target_data['actual_percent']:.1f}%" print(f" {target_name}: {status} ({value})") - + # Phase 2: RAG Comparison (if enabled) comparison_result = None if args.include_rag_comparison: print("\nāš”ļø Phase 2: RAG Comparison Analysis") print("-" * 40) - + comparison_result = generate_sample_comparison_results() - + print(f"Latency Improvement: {comparison_result['latency_improvement']:.1f}x") print(f"Cost Savings: {comparison_result['cost_savings']:.1f}%") print(f"Recommendation: {comparison_result['recommendation']}") - + # Phase 3: Profiling Analysis (if enabled) if args.include_profiling: print("\nšŸ” Phase 3: Performance Profiling") print("-" * 40) - + print("Execution Time: 1250.3ms") print("Bottlenecks Found: 2") print("āŒ Critical Bottlenecks:") print(" - Database Query: Slow index lookup detected") print(" - Network: High latency in external API calls") - + # Phase 4: Load Testing (if enabled) load_test_results = None if args.include_load_test: print("\n🚦 Phase 4: Load Testing") print("-" * 40) - + load_test_results = generate_sample_load_test_results(args.load_test_users) - + print(f"Concurrent Users: {load_test_results['concurrent_users']}") print(f"Throughput: {load_test_results['throughput_qps']:.1f} QPS") print(f"Avg Response Time: {load_test_results['avg_response_time_ms']:.1f}ms") print(f"Error Rate: {load_test_results['error_rate']:.1f}%") - + # Phase 5: Report Generation & Visualization print("\nšŸ“ Phase 5: Report Generation & Visualization") print("-" * 40) - + # Generate timestamp for consistent naming timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + # Create comprehensive report data report_data = { - 'metadata': { - 'timestamp': timestamp, - 'benchmark_version': '1.0.0', - 'args': vars(args) + "metadata": { + "timestamp": timestamp, + "benchmark_version": "1.0.0", + "args": vars(args), }, - 'validation_results': validation_results, - 'comparison_result': comparison_result, - 'load_test_results': load_test_results, - 'performance_grade': 'A+' if validation_results['overall_pass'] else 'B', - 'recommendations': [ - 'Cache performance is excellent - maintain current configuration', - 'Consider increasing cache size for even better hit rates', - 'Monitor performance under higher concurrent load', - 'Implement database query optimization for better response times' - ] + "validation_results": validation_results, + "comparison_result": comparison_result, + "load_test_results": load_test_results, + "performance_grade": "A+" if validation_results["overall_pass"] else "B", + "recommendations": [ + "Cache performance is excellent - maintain current configuration", + "Consider increasing cache size for even better hit rates", + "Monitor performance under higher concurrent load", + "Implement database query optimization for better response times", + ], } - + # Save comprehensive JSON report json_report_path = logs_dir / "reports" / f"benchmark_report_{timestamp}.json" - with open(json_report_path, 'w') as f: + with open(json_report_path, "w") as f: json.dump(report_data, f, indent=2, default=str) - + # Save text summary text_report_path = logs_dir / "reports" / f"benchmark_summary_{timestamp}.txt" - with open(text_report_path, 'w') as f: + with open(text_report_path, "w") as f: f.write("FACT Benchmark Summary\n") f.write("=" * 50 + "\n\n") f.write(f"Timestamp: {timestamp}\n") f.write(f"Performance Grade: {report_data['performance_grade']}\n") f.write(f"Overall Pass: {validation_results['overall_pass']}\n\n") - + f.write("Performance Targets:\n") for target_name, target_data in target_validation.items(): - status = "PASS" if target_data['met'] else "FAIL" + status = "PASS" if target_data["met"] else "FAIL" f.write(f" {target_name}: {status}\n") - + f.write("\nRecommendations:\n") - for i, rec in enumerate(report_data['recommendations'], 1): + for i, rec in enumerate(report_data["recommendations"], 1): f.write(f" {i}. {rec}\n") - + # Save raw data for further analysis raw_data_path = logs_dir / "raw_data" / f"raw_results_{timestamp}.json" - with open(raw_data_path, 'w') as f: + with open(raw_data_path, "w") as f: json.dump(report_data, f, indent=2, default=str) - + # Generate sample visualization data charts_dir = logs_dir / "charts" - + print("šŸ“Š Generating performance visualizations...") - + # Performance overview chart performance_chart = { - 'chart_type': 'performance_overview', - 'title': 'FACT Performance Overview', - 'data': { - 'cache_hit_rate': validation_results['benchmark_summary']['cache_hit_rate'], - 'avg_hit_latency': validation_results['benchmark_summary']['avg_hit_latency_ms'], - 'avg_miss_latency': validation_results['benchmark_summary']['avg_miss_latency_ms'], - 'success_rate': validation_results['benchmark_summary']['success_rate'] - } + "chart_type": "performance_overview", + "title": "FACT Performance Overview", + "data": { + "cache_hit_rate": validation_results["benchmark_summary"]["cache_hit_rate"], + "avg_hit_latency": validation_results["benchmark_summary"][ + "avg_hit_latency_ms" + ], + "avg_miss_latency": validation_results["benchmark_summary"][ + "avg_miss_latency_ms" + ], + "success_rate": validation_results["benchmark_summary"]["success_rate"], + }, } - + performance_chart_path = charts_dir / f"performance_overview_{timestamp}.json" - with open(performance_chart_path, 'w') as f: + with open(performance_chart_path, "w") as f: json.dump(performance_chart, f, indent=2) - + # Latency comparison chart (if comparison data available) if comparison_result: latency_chart = { - 'chart_type': 'latency_comparison', - 'title': 'FACT vs Traditional RAG Latency', - 'data': { - 'fact_latency': validation_results['benchmark_summary']['avg_hit_latency_ms'], - 'rag_latency': validation_results['benchmark_summary']['avg_hit_latency_ms'] * comparison_result['latency_improvement'], - 'improvement_factor': comparison_result['latency_improvement'] - } + "chart_type": "latency_comparison", + "title": "FACT vs Traditional RAG Latency", + "data": { + "fact_latency": validation_results["benchmark_summary"][ + "avg_hit_latency_ms" + ], + "rag_latency": validation_results["benchmark_summary"][ + "avg_hit_latency_ms" + ] + * comparison_result["latency_improvement"], + "improvement_factor": comparison_result["latency_improvement"], + }, } - + latency_chart_path = charts_dir / f"latency_comparison_{timestamp}.json" - with open(latency_chart_path, 'w') as f: + with open(latency_chart_path, "w") as f: json.dump(latency_chart, f, indent=2) - + print(f"šŸ“„ Reports saved to: {logs_dir}") print(f"šŸ“„ JSON Report: {json_report_path}") print(f"šŸ“„ Summary: {text_report_path}") print(f"šŸ“„ Raw Data: {raw_data_path}") print(f"šŸ“ˆ Charts: {charts_dir}") - + # Print comprehensive performance summary to console print_performance_summary(validation_results, comparison_result, load_test_results) - + # Performance Grade - grade = report_data['performance_grade'] + grade = report_data["performance_grade"] print(f"\nšŸ† Performance Grade: {grade}") - + # Print key recommendations print(f"\nšŸ”§ KEY RECOMMENDATIONS:") print("-" * 50) - for i, rec in enumerate(report_data['recommendations'][:5], 1): + for i, rec in enumerate(report_data["recommendations"][:5], 1): print(f" {i}. {rec}") - + # Final status message - if validation_results['overall_pass']: - print(f"\nšŸŽ‰ Benchmark completed successfully! All performance targets achieved.") + if validation_results["overall_pass"]: + print( + f"\nšŸŽ‰ Benchmark completed successfully! All performance targets achieved." + ) print(f" Results saved to: {logs_dir}") else: print(f"\nāš ļø Benchmark completed with some targets not met.") print(f" Review optimization strategies and detailed reports in: {logs_dir}") - - return validation_results['overall_pass'] + + return validation_results["overall_pass"] + def main(): """Main entry point.""" @@ -357,35 +397,69 @@ def main(): # Custom performance targets python scripts/run_benchmarks_standalone.py --hit-target 40 --miss-target 120 --cost-reduction 85 - """ + """, ) - + # Benchmark configuration - parser.add_argument('--iterations', type=int, default=10, help='Number of benchmark iterations') - parser.add_argument('--concurrent-users', type=int, default=1, help='Number of concurrent users') - + parser.add_argument( + "--iterations", type=int, default=10, help="Number of benchmark iterations" + ) + parser.add_argument( + "--concurrent-users", type=int, default=1, help="Number of concurrent users" + ) + # Performance targets - parser.add_argument('--hit-target', type=float, default=48.0, help='Cache hit latency target (ms)') - parser.add_argument('--miss-target', type=float, default=140.0, help='Cache miss latency target (ms)') - parser.add_argument('--cost-reduction', type=float, default=90.0, help='Cost reduction target (percent)') - parser.add_argument('--cache-hit-rate', type=float, default=60.0, help='Cache hit rate target (percent)') - + parser.add_argument( + "--hit-target", type=float, default=48.0, help="Cache hit latency target (ms)" + ) + parser.add_argument( + "--miss-target", + type=float, + default=140.0, + help="Cache miss latency target (ms)", + ) + parser.add_argument( + "--cost-reduction", + type=float, + default=90.0, + help="Cost reduction target (percent)", + ) + parser.add_argument( + "--cache-hit-rate", + type=float, + default=60.0, + help="Cache hit rate target (percent)", + ) + # Optional components - parser.add_argument('--include-rag-comparison', action='store_true', help='Include RAG comparison') - parser.add_argument('--include-profiling', action='store_true', help='Include performance profiling') - parser.add_argument('--include-load-test', action='store_true', help='Include load testing') - + parser.add_argument( + "--include-rag-comparison", action="store_true", help="Include RAG comparison" + ) + parser.add_argument( + "--include-profiling", action="store_true", help="Include performance profiling" + ) + parser.add_argument( + "--include-load-test", action="store_true", help="Include load testing" + ) + # Load testing configuration - parser.add_argument('--load-test-users', type=int, default=10, help='Load test concurrent users') - + parser.add_argument( + "--load-test-users", type=int, default=10, help="Load test concurrent users" + ) + # Output configuration - parser.add_argument('--output-dir', default='./benchmark_results', help='Output directory for reports (default: creates timestamped logs directory)') - + parser.add_argument( + "--output-dir", + default="./benchmark_results", + help="Output directory for reports (default: creates timestamped logs directory)", + ) + args = parser.parse_args() - + # Run standalone benchmark demo success = asyncio.run(run_standalone_benchmark(args)) sys.exit(0 if success else 1) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/run_optimized_benchmarks.py b/scripts/run_optimized_benchmarks.py index 1651e75..b8c24be 100644 --- a/scripts/run_optimized_benchmarks.py +++ b/scripts/run_optimized_benchmarks.py @@ -22,7 +22,10 @@ from cache.manager import CacheManager, get_cache_manager from cache.warming import get_cache_warmer, warm_cache_startup from cache.metrics import get_metrics_collector -from monitoring.performance_optimizer import get_performance_optimizer, start_performance_optimization +from monitoring.performance_optimizer import ( + get_performance_optimizer, + start_performance_optimization, +) from core.config import get_config @@ -31,13 +34,13 @@ def create_optimized_logs_directory(base_dir: str = "./logs") -> Path: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") logs_dir = Path(base_dir) / f"optimized_benchmark_{timestamp}" logs_dir.mkdir(parents=True, exist_ok=True) - + # Create subdirectories for organized output (logs_dir / "reports").mkdir(exist_ok=True) (logs_dir / "charts").mkdir(exist_ok=True) (logs_dir / "raw_data").mkdir(exist_ok=True) (logs_dir / "optimization_logs").mkdir(exist_ok=True) - + return logs_dir @@ -59,17 +62,17 @@ async def run_optimized_benchmark_suite(args): print("šŸš€ Starting FACT Optimized Benchmark Suite") print(" Enhanced with intelligent cache warming, latency optimization,") print(" and real-time performance monitoring\n") - + display_benchmark_targets() - + # Create timestamped logs directory logs_dir = create_optimized_logs_directory() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + try: # Initialize optimized configuration print("\nšŸ”§ Initializing optimized FACT system...") - + # Enhanced cache configuration for optimization cache_config = { "prefix": "fact_optimized_benchmark", @@ -77,33 +80,37 @@ async def run_optimized_benchmark_suite(args): "max_size": "15MB", # Increased cache size "ttl_seconds": 7200, # 2 hours "hit_target_ms": 48.0, # Updated target - "miss_target_ms": 140.0 + "miss_target_ms": 140.0, } - + # Initialize cache manager with optimizations cache_manager = CacheManager(cache_config) - + # Initialize performance optimizer optimizer = get_performance_optimizer(cache_manager) - + # Start optimization monitoring print("šŸŽÆ Starting performance optimization monitoring...") await start_performance_optimization(cache_manager) - + # Intelligent cache warming print("šŸ”„ Performing intelligent cache warming...") cache_warmer = get_cache_warmer(cache_manager) - + # Pre-warm with optimized queries - warmup_result = await cache_warmer.warm_cache_intelligently(max_queries=args.warmup_queries) - print(f" āœ“ Warmed {warmup_result.queries_successful}/{warmup_result.queries_attempted} queries") + warmup_result = await cache_warmer.warm_cache_intelligently( + max_queries=args.warmup_queries + ) + print( + f" āœ“ Warmed {warmup_result.queries_successful}/{warmup_result.queries_attempted} queries" + ) print(f" āœ“ Cached {warmup_result.total_tokens_cached} tokens") print(f" āœ“ Warming completed in {warmup_result.total_time_ms:.1f}ms") - + # Wait for optimization to take effect print("\nā³ Allowing optimization system to analyze and adjust...") await asyncio.sleep(5) - + # Initialize benchmark framework with optimized settings benchmark_config = BenchmarkConfig( iterations=args.iterations, @@ -113,12 +120,12 @@ async def run_optimized_benchmark_suite(args): target_cache_hit_rate=0.60, target_hit_latency_ms=48.0, target_miss_latency_ms=140.0, - target_cost_reduction_hit=0.90 + target_cost_reduction_hit=0.90, ) - + framework = BenchmarkFramework(benchmark_config) runner = BenchmarkRunner(framework) - + # Enhanced test queries for comprehensive evaluation enhanced_queries = [ "What was the Q1-2025 revenue?", @@ -135,30 +142,32 @@ async def run_optimized_benchmark_suite(args): "Show customer satisfaction scores", "Analyze market performance by region", "What is current profit margin?", - "Show predictive revenue forecasts" + "Show predictive revenue forecasts", ] - + # Run performance validation print("\nšŸ“ˆ Running optimized performance validation...") print(f" • Testing {len(enhanced_queries)} diverse queries") - print(f" • {args.iterations} iterations with {args.concurrent_users} concurrent users") + print( + f" • {args.iterations} iterations with {args.concurrent_users} concurrent users" + ) print(f" • Real-time optimization monitoring active") - + # Execute benchmarks with enhanced monitoring start_time = time.time() - + # Run enhanced benchmark suite summary = await framework.run_benchmark_suite(enhanced_queries, cache_manager) - + execution_time = time.time() - start_time - + # Get optimization status opt_status = optimizer.get_optimization_status() - + # Collect comprehensive metrics final_metrics = cache_manager.get_metrics() performance_stats = cache_manager.get_performance_stats() - + # Generate validation results validation_results = { "timestamp": time.time(), @@ -174,105 +183,125 @@ async def run_optimized_benchmark_suite(args): "avg_response_time_ms": summary.avg_response_time_ms, "cost_reduction_percentage": summary.cost_reduction_percentage, "error_rate": summary.error_rate, - "throughput_qps": summary.throughput_qps + "throughput_qps": summary.throughput_qps, }, "target_validation": { "cache_hit_latency": { "target_ms": 48.0, "actual_ms": summary.avg_hit_latency_ms, "met": summary.avg_hit_latency_ms <= 48.0, - "status": "PASS" if summary.avg_hit_latency_ms <= 48.0 else "FAIL" + "status": "PASS" if summary.avg_hit_latency_ms <= 48.0 else "FAIL", }, "cache_miss_latency": { "target_ms": 140.0, "actual_ms": summary.avg_miss_latency_ms, "met": summary.avg_miss_latency_ms <= 140.0, - "status": "PASS" if summary.avg_miss_latency_ms <= 140.0 else "FAIL" + "status": ( + "PASS" if summary.avg_miss_latency_ms <= 140.0 else "FAIL" + ), }, "cost_reduction": { "target_percent": 90.0, "actual_percent": summary.cost_reduction_percentage, "met": summary.cost_reduction_percentage >= 90.0, - "status": "PASS" if summary.cost_reduction_percentage >= 90.0 else "FAIL" + "status": ( + "PASS" if summary.cost_reduction_percentage >= 90.0 else "FAIL" + ), }, "cache_hit_rate": { "target_percent": 60.0, "actual_percent": summary.cache_hit_rate * 100, "met": (summary.cache_hit_rate * 100) >= 60.0, - "status": "PASS" if (summary.cache_hit_rate * 100) >= 60.0 else "FAIL" - } + "status": ( + "PASS" if (summary.cache_hit_rate * 100) >= 60.0 else "FAIL" + ), + }, }, "optimization_status": opt_status, "cache_metrics": { "total_entries": final_metrics.total_entries, - "memory_utilization_percent": (final_metrics.total_size / cache_manager.max_size_bytes * 100), + "memory_utilization_percent": ( + final_metrics.total_size / cache_manager.max_size_bytes * 100 + ), "hit_rate_percent": final_metrics.hit_rate, - "token_efficiency": final_metrics.token_efficiency + "token_efficiency": final_metrics.token_efficiency, }, "performance_improvements": { "warming_result": { "queries_warmed": warmup_result.queries_successful, "tokens_cached": warmup_result.total_tokens_cached, - "warming_time_ms": warmup_result.total_time_ms + "warming_time_ms": warmup_result.total_time_ms, }, - "optimization_actions": len(opt_status.get('recent_actions', [])), - "strategy": opt_status.get('current_strategy', 'unknown') - } + "optimization_actions": len(opt_status.get("recent_actions", [])), + "strategy": opt_status.get("current_strategy", "unknown"), + }, } - + # Calculate overall pass status targets = validation_results["target_validation"] - validation_results["overall_pass"] = all([ - targets["cache_hit_latency"]["met"], - targets["cache_miss_latency"]["met"], - targets["cost_reduction"]["met"], - targets["cache_hit_rate"]["met"] - ]) - + validation_results["overall_pass"] = all( + [ + targets["cache_hit_latency"]["met"], + targets["cache_miss_latency"]["met"], + targets["cost_reduction"]["met"], + targets["cache_hit_rate"]["met"], + ] + ) + # Save comprehensive results await save_benchmark_results(validation_results, logs_dir, timestamp) - + # Display results display_optimization_results(validation_results, logs_dir) - + # Stop optimization monitoring from monitoring.performance_optimizer import stop_performance_optimization + await stop_performance_optimization() - + return validation_results["overall_pass"] - + except Exception as e: print(f"\nāŒ Benchmark execution failed: {str(e)}") import traceback + traceback.print_exc() return False async def save_benchmark_results(results: dict, logs_dir: Path, timestamp: str): """Save comprehensive benchmark results.""" - + # Save JSON report - json_report_path = logs_dir / "reports" / f"optimized_benchmark_report_{timestamp}.json" - with open(json_report_path, 'w') as f: + json_report_path = ( + logs_dir / "reports" / f"optimized_benchmark_report_{timestamp}.json" + ) + with open(json_report_path, "w") as f: json.dump(results, f, indent=2, default=str) - + # Save text summary - text_report_path = logs_dir / "reports" / f"optimized_benchmark_summary_{timestamp}.txt" - with open(text_report_path, 'w') as f: + text_report_path = ( + logs_dir / "reports" / f"optimized_benchmark_summary_{timestamp}.txt" + ) + with open(text_report_path, "w") as f: f.write("FACT Optimized Benchmark Results\n") f.write("=" * 50 + "\n\n") - + f.write(f"Execution Time: {results['execution_time_seconds']:.2f} seconds\n") f.write(f"Timestamp: {datetime.fromtimestamp(results['timestamp'])}\n\n") - + f.write("Performance Targets:\n") for metric, data in results["target_validation"].items(): status = "āœ… PASS" if data["met"] else "āŒ FAIL" - f.write(f" {metric}: {status} - Target: {data.get('target_ms', data.get('target_percent', 'N/A'))}, " - f"Actual: {data.get('actual_ms', data.get('actual_percent', 'N/A'))}\n") - - f.write(f"\nOverall Result: {'āœ… PASS' if results['overall_pass'] else 'āŒ FAIL'}\n\n") - + f.write( + f" {metric}: {status} - Target: {data.get('target_ms', data.get('target_percent', 'N/A'))}, " + f"Actual: {data.get('actual_ms', data.get('actual_percent', 'N/A'))}\n" + ) + + f.write( + f"\nOverall Result: {'āœ… PASS' if results['overall_pass'] else 'āŒ FAIL'}\n\n" + ) + summary = results["benchmark_summary"] f.write("Benchmark Summary:\n") f.write(f" Total Queries: {summary['total_queries']}\n") @@ -281,41 +310,52 @@ async def save_benchmark_results(results: dict, logs_dir: Path, timestamp: str): f.write(f" Average Miss Latency: {summary['avg_miss_latency_ms']:.1f}ms\n") f.write(f" Cost Reduction: {summary['cost_reduction_percentage']:.1f}%\n") f.write(f" Throughput: {summary['throughput_qps']:.1f} QPS\n") - + # Save raw data raw_data_path = logs_dir / "raw_data" / f"benchmark_data_{timestamp}.json" - with open(raw_data_path, 'w') as f: - json.dump({ - 'benchmark_summary': results['benchmark_summary'], - 'cache_metrics': results['cache_metrics'], - 'optimization_status': results['optimization_status'] - }, f, indent=2, default=str) + with open(raw_data_path, "w") as f: + json.dump( + { + "benchmark_summary": results["benchmark_summary"], + "cache_metrics": results["cache_metrics"], + "optimization_status": results["optimization_status"], + }, + f, + indent=2, + default=str, + ) def display_optimization_results(results: dict, logs_dir: Path): """Display comprehensive optimization results.""" - - print("\n" + "="*80) + + print("\n" + "=" * 80) print("šŸŽÆ FACT OPTIMIZATION BENCHMARK RESULTS") - print("="*80) - + print("=" * 80) + # Performance targets validation print("\nšŸ“Š Performance Target Validation:") targets = results["target_validation"] - + for metric_name, data in targets.items(): status_icon = "āœ…" if data["met"] else "āŒ" metric_display = metric_name.replace("_", " ").title() - + if "latency" in metric_name: - print(f" {status_icon} {metric_display:<25} {data['actual_ms']:>8.1f}ms Target: ≤{data['target_ms']}ms") + print( + f" {status_icon} {metric_display:<25} {data['actual_ms']:>8.1f}ms Target: ≤{data['target_ms']}ms" + ) else: - print(f" {status_icon} {metric_display:<25} {data['actual_percent']:>8.1f}% Target: ≄{data['target_percent']}%") - + print( + f" {status_icon} {metric_display:<25} {data['actual_percent']:>8.1f}% Target: ≄{data['target_percent']}%" + ) + # Overall result - overall_status = "šŸŽ‰ SUCCESS" if results["overall_pass"] else "āš ļø NEEDS IMPROVEMENT" + overall_status = ( + "šŸŽ‰ SUCCESS" if results["overall_pass"] else "āš ļø NEEDS IMPROVEMENT" + ) print(f"\nšŸ† Overall Result: {overall_status}") - + # Performance summary summary = results["benchmark_summary"] print(f"\nšŸ“ˆ Performance Summary:") @@ -324,30 +364,40 @@ def display_optimization_results(results: dict, logs_dir: Path): print(f" • Average Response Time: {summary['avg_response_time_ms']:.1f}ms") print(f" • System Throughput: {summary['throughput_qps']:.1f} QPS") print(f" • Error Rate: {summary['error_rate']:.1f}%") - + # Optimization impact improvements = results["performance_improvements"] print(f"\n⚔ Optimization Impact:") - print(f" • Cache Entries Warmed: {improvements['warming_result']['queries_warmed']}") - print(f" • Tokens Pre-cached: {improvements['warming_result']['tokens_cached']:,}") + print( + f" • Cache Entries Warmed: {improvements['warming_result']['queries_warmed']}" + ) + print( + f" • Tokens Pre-cached: {improvements['warming_result']['tokens_cached']:,}" + ) print(f" • Optimization Actions: {improvements['optimization_actions']}") print(f" • Strategy: {improvements['strategy']}") - + # Cost efficiency print(f"\nšŸ’° Cost Efficiency:") print(f" • Cost Reduction Achieved: {summary['cost_reduction_percentage']:.1f}%") - print(f" • Cache Memory Utilization: {results['cache_metrics']['memory_utilization_percent']:.1f}%") - print(f" • Token Efficiency: {results['cache_metrics']['token_efficiency']:.1f} tokens/KB") - + print( + f" • Cache Memory Utilization: {results['cache_metrics']['memory_utilization_percent']:.1f}%" + ) + print( + f" • Token Efficiency: {results['cache_metrics']['token_efficiency']:.1f} tokens/KB" + ) + # Results location print(f"\nšŸ“ Detailed Results Saved To:") print(f" {logs_dir}") print(f" ā”œā”€ā”€ reports/optimized_benchmark_report_*.json") print(f" ā”œā”€ā”€ reports/optimized_benchmark_summary_*.txt") print(f" └── raw_data/benchmark_data_*.json") - + if results["overall_pass"]: - print(f"\nšŸŽ‰ Optimization successful! FACT system is performing within all targets.") + print( + f"\nšŸŽ‰ Optimization successful! FACT system is performing within all targets." + ) print(f" The enhanced caching system with intelligent warming and real-time") print(f" optimization has achieved the required performance benchmarks.") else: @@ -378,29 +428,45 @@ def main(): # Load testing with optimization python scripts/run_optimized_benchmarks.py --concurrent-users 10 --iterations 15 - """ + """, ) - + # Benchmark configuration - parser.add_argument('--iterations', type=int, default=15, - help='Number of benchmark iterations (default: 15)') - parser.add_argument('--warmup', type=int, default=3, - help='Number of warmup iterations (default: 3)') - parser.add_argument('--warmup-queries', type=int, default=40, - help='Number of queries to warm before benchmarking (default: 40)') - parser.add_argument('--concurrent-users', type=int, default=5, - help='Number of concurrent users for load testing (default: 5)') - + parser.add_argument( + "--iterations", + type=int, + default=15, + help="Number of benchmark iterations (default: 15)", + ) + parser.add_argument( + "--warmup", type=int, default=3, help="Number of warmup iterations (default: 3)" + ) + parser.add_argument( + "--warmup-queries", + type=int, + default=40, + help="Number of queries to warm before benchmarking (default: 40)", + ) + parser.add_argument( + "--concurrent-users", + type=int, + default=5, + help="Number of concurrent users for load testing (default: 5)", + ) + # Output configuration - parser.add_argument('--output-dir', default='./logs', - help='Output directory for results (default: ./logs)') - + parser.add_argument( + "--output-dir", + default="./logs", + help="Output directory for results (default: ./logs)", + ) + args = parser.parse_args() - + # Run optimized benchmark success = asyncio.run(run_optimized_benchmark_suite(args)) sys.exit(0 if success else 1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/setup.py b/scripts/setup.py index 5b1d197..54e0256 100644 --- a/scripts/setup.py +++ b/scripts/setup.py @@ -27,30 +27,30 @@ async def setup_database(): """Set up the database with schema and sample data.""" print("šŸ—„ļø Setting up database...") - + try: config = get_config() - + # Create data directory if it doesn't exist data_dir = Path(config.database_path).parent data_dir.mkdir(parents=True, exist_ok=True) print(f"āœ… Created data directory: {data_dir}") - + # Initialize database db_manager = create_database_manager(config.database_path) await db_manager.initialize_database() - + # Get database info db_info = await db_manager.get_database_info() print(f"āœ… Database initialized: {config.database_path}") print(f" • File size: {db_info['file_size_bytes']} bytes") print(f" • Tables: {db_info['total_tables']}") - - for table_name, table_info in db_info['tables'].items(): + + for table_name, table_info in db_info["tables"].items(): print(f" • {table_name}: {table_info['row_count']} rows") - + return True - + except Exception as e: print(f"āŒ Database setup failed: {e}") return False @@ -59,16 +59,16 @@ async def setup_database(): def setup_directories(): """Create necessary directories.""" print("šŸ“ Setting up directories...") - + directories = [ "data", - "logs", + "logs", "output", "docs/api", "docs/deployment", - "docs/development" + "docs/development", ] - + for directory in directories: Path(directory).mkdir(parents=True, exist_ok=True) print(f"āœ… Created directory: {directory}") @@ -77,34 +77,36 @@ def setup_directories(): def check_environment(): """Check environment configuration.""" print("āš™ļø Checking environment configuration...") - + env_file = Path(".env") env_example = Path(".env.example") - + if not env_file.exists(): if env_example.exists(): - print("āŒ .env file not found. Please copy .env.example to .env and configure it.") + print( + "āŒ .env file not found. Please copy .env.example to .env and configure it." + ) print(" cp .env.example .env") return False else: print("āŒ Neither .env nor .env.example found.") return False - + print("āœ… .env file found") - + # Check for required environment variables required_vars = ["ANTHROPIC_API_KEY", "ARCADE_API_KEY"] missing_vars = [] - + for var in required_vars: if not os.getenv(var): missing_vars.append(var) - + if missing_vars: print(f"āŒ Missing required environment variables: {', '.join(missing_vars)}") print(" Please configure these in your .env file") return False - + print("āœ… Required environment variables configured") return True @@ -112,16 +114,17 @@ def check_environment(): def install_dependencies(): """Check if dependencies are installed.""" print("šŸ“¦ Checking dependencies...") - + try: # Try importing key dependencies import anthropic import litellm import aiosqlite import structlog + print("āœ… Core dependencies installed") return True - + except ImportError as e: print(f"āŒ Missing dependencies: {e}") print(" Please install dependencies: pip install -r requirements.txt") @@ -132,27 +135,27 @@ async def main(): """Main setup routine.""" print("šŸš€ FACT System Setup") print("=" * 50) - + success = True - + # Check dependencies if not install_dependencies(): success = False - + # Set up directories setup_directories() - + # Check environment if not check_environment(): success = False - + # Set up database (only if environment is configured) if success: if not await setup_database(): success = False - + print("\n" + "=" * 50) - + if success: print("āœ… FACT System setup completed successfully!") print("\nNext steps:") @@ -165,4 +168,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/scripts/setup_env.py b/scripts/setup_env.py index b4773e9..3ea30ce 100755 --- a/scripts/setup_env.py +++ b/scripts/setup_env.py @@ -19,23 +19,23 @@ from typing import Dict, Optional, List # Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) class EnvironmentSetup: """Interactive environment setup for FACT system.""" - + API_KEY_PATTERNS = { - 'ANTHROPIC_API_KEY': r'^sk-ant-api03-[A-Za-z0-9_-]+$', - 'ARCADE_API_KEY': r'^arc_[A-Za-z0-9_-]+$', - 'OPENAI_API_KEY': r'^sk-proj-[A-Za-z0-9_-]+$' + "ANTHROPIC_API_KEY": r"^sk-ant-api03-[A-Za-z0-9_-]+$", + "ARCADE_API_KEY": r"^arc_[A-Za-z0-9_-]+$", + "OPENAI_API_KEY": r"^sk-proj-[A-Za-z0-9_-]+$", } - + def __init__(self, force: bool = False, minimal: bool = False): self.force = force self.minimal = minimal self.config: Dict[str, str] = {} - + def print_header(self): """Print setup script header.""" print("šŸš€ FACT Environment Configuration Setup") @@ -43,81 +43,94 @@ def print_header(self): print("This script will help you configure your FACT system environment.") print("You'll need API keys from Anthropic and Arcade AI to proceed.") print() - + def check_existing_env(self) -> bool: """Check if .env file already exists.""" - env_path = Path('.env') - + env_path = Path(".env") + if env_path.exists() and not self.force: print("āš ļø A .env file already exists!") print(f"šŸ“„ Path: {env_path.absolute()}") print() - - choice = input("Do you want to:\n" - " [u] Update existing file\n" - " [o] Overwrite completely\n" - " [c] Cancel setup\n" - "Choice (u/o/c): ").lower().strip() - - if choice == 'c': + + choice = ( + input( + "Do you want to:\n" + " [u] Update existing file\n" + " [o] Overwrite completely\n" + " [c] Cancel setup\n" + "Choice (u/o/c): " + ) + .lower() + .strip() + ) + + if choice == "c": print("āŒ Setup cancelled.") return False - elif choice == 'o': + elif choice == "o": self.force = True - elif choice == 'u': + elif choice == "u": # Load existing configuration self.load_existing_config() else: print("āŒ Invalid choice. Setup cancelled.") return False - + return True - + def load_existing_config(self): """Load existing .env configuration.""" try: - with open('.env', 'r') as f: + with open(".env", "r") as f: for line in f: line = line.strip() - if line and not line.startswith('#') and '=' in line: - key, value = line.split('=', 1) + if line and not line.startswith("#") and "=" in line: + key, value = line.split("=", 1) self.config[key] = value print(f"āœ… Loaded {len(self.config)} existing configuration items") except Exception as e: print(f"āš ļø Could not load existing config: {e}") - - def get_api_key(self, service: str, key_name: str, pattern: str, - description: str, url: str, required: bool = True) -> Optional[str]: + + def get_api_key( + self, + service: str, + key_name: str, + pattern: str, + description: str, + url: str, + required: bool = True, + ) -> Optional[str]: """Get API key from user with validation.""" - + # Check if key already exists - existing_key = self.config.get(key_name, '') + existing_key = self.config.get(key_name, "") if existing_key and existing_key != f"your_{key_name.lower()}_here": print(f"šŸ“‹ {service} API Key: Already configured") keep = input(f" Keep existing key? [Y/n]: ").lower().strip() - if keep != 'n': + if keep != "n": return existing_key - + print(f"\nšŸ”‘ {service} API Key Configuration") print(f" Description: {description}") print(f" Get your key: {url}") print(f" Format: {pattern}") - + if not required: skip = input(f" Skip {service} configuration? [y/N]: ").lower().strip() - if skip == 'y': + if skip == "y": return None - + while True: key = input(f" Enter your {service} API key: ").strip() - + if not key: if required: print(" āŒ This key is required. Please enter a valid key.") continue else: return None - + # Validate format if re.match(self.API_KEY_PATTERNS[key_name], key): print(f" āœ… Valid {service} API key format") @@ -125,20 +138,20 @@ def get_api_key(self, service: str, key_name: str, pattern: str, else: print(f" āŒ Invalid key format. Expected: {pattern}") print(f" šŸ’” Make sure you copied the full key correctly") - + retry = input(" Try again? [Y/n]: ").lower().strip() - if retry == 'n': + if retry == "n": if required: print(" āŒ This key is required to continue.") continue else: return None - + def configure_required_keys(self) -> bool: """Configure required API keys.""" print("\nšŸ“‹ REQUIRED API KEYS") print("-" * 30) - + # Anthropic API Key anthropic_key = self.get_api_key( service="Anthropic Claude", @@ -146,15 +159,15 @@ def configure_required_keys(self) -> bool: pattern="sk-ant-api03-*", description="Claude API for LLM capabilities", url="https://console.anthropic.com/", - required=True + required=True, ) - + if not anthropic_key: print("āŒ Anthropic API key is required. Setup cannot continue.") return False - - self.config['ANTHROPIC_API_KEY'] = anthropic_key - + + self.config["ANTHROPIC_API_KEY"] = anthropic_key + # Arcade API Key arcade_key = self.get_api_key( service="Arcade AI", @@ -162,25 +175,25 @@ def configure_required_keys(self) -> bool: pattern="arc_*", description="Arcade AI for tool execution", url="https://arcade-ai.com/dashboard", - required=True + required=True, ) - + if not arcade_key: print("āŒ Arcade AI API key is required. Setup cannot continue.") return False - - self.config['ARCADE_API_KEY'] = arcade_key - + + self.config["ARCADE_API_KEY"] = arcade_key + return True - + def configure_optional_keys(self) -> bool: """Configure optional API keys.""" if self.minimal: return True - + print("\nšŸ“‹ OPTIONAL API KEYS") print("-" * 30) - + # OpenAI API Key openai_key = self.get_api_key( service="OpenAI", @@ -188,162 +201,166 @@ def configure_optional_keys(self) -> bool: pattern="sk-proj-*", description="OpenAI API for extended LLM capabilities", url="https://platform.openai.com/api-keys", - required=False + required=False, ) - + if openai_key: - self.config['OPENAI_API_KEY'] = openai_key - + self.config["OPENAI_API_KEY"] = openai_key + return True - + def configure_system_settings(self): """Configure system settings.""" if self.minimal: # Set minimal required settings defaults = { - 'ARCADE_BASE_URL': 'https://api.arcade-ai.com', - 'DATABASE_PATH': 'data/fact_demo.db', - 'CLAUDE_MODEL': 'claude-3-5-sonnet-20241022', - 'LOG_LEVEL': 'INFO' + "ARCADE_BASE_URL": "https://api.arcade-ai.com", + "DATABASE_PATH": "data/fact_demo.db", + "CLAUDE_MODEL": "claude-3-5-sonnet-20241022", + "LOG_LEVEL": "INFO", } - + for key, value in defaults.items(): if key not in self.config: self.config[key] = value - + return - + print("\nšŸ“‹ SYSTEM CONFIGURATION") print("-" * 30) - + # Claude Model Selection models = [ - 'claude-3-5-sonnet-20241022', - 'claude-3-haiku-20240307', - 'claude-3-opus-20240229' + "claude-3-5-sonnet-20241022", + "claude-3-haiku-20240307", + "claude-3-opus-20240229", ] - - current_model = self.config.get('CLAUDE_MODEL', models[0]) + + current_model = self.config.get("CLAUDE_MODEL", models[0]) print(f"šŸ¤– Claude Model (current: {current_model})") print(" Available models:") for i, model in enumerate(models, 1): marker = " (recommended)" if model == models[0] else "" print(f" {i}. {model}{marker}") - - choice = input(f" Select model [1-{len(models)}] or press Enter for current: ").strip() + + choice = input( + f" Select model [1-{len(models)}] or press Enter for current: " + ).strip() if choice.isdigit() and 1 <= int(choice) <= len(models): - self.config['CLAUDE_MODEL'] = models[int(choice) - 1] + self.config["CLAUDE_MODEL"] = models[int(choice) - 1] elif not choice: - self.config['CLAUDE_MODEL'] = current_model - + self.config["CLAUDE_MODEL"] = current_model + # Log Level - log_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR'] - current_log = self.config.get('LOG_LEVEL', 'INFO') + log_levels = ["DEBUG", "INFO", "WARNING", "ERROR"] + current_log = self.config.get("LOG_LEVEL", "INFO") print(f"\nšŸ“Š Log Level (current: {current_log})") print(" Available levels:") for i, level in enumerate(log_levels, 1): - marker = " (recommended)" if level == 'INFO' else "" + marker = " (recommended)" if level == "INFO" else "" print(f" {i}. {level}{marker}") - - choice = input(f" Select level [1-{len(log_levels)}] or press Enter for current: ").strip() + + choice = input( + f" Select level [1-{len(log_levels)}] or press Enter for current: " + ).strip() if choice.isdigit() and 1 <= int(choice) <= len(log_levels): - self.config['LOG_LEVEL'] = log_levels[int(choice) - 1] + self.config["LOG_LEVEL"] = log_levels[int(choice) - 1] elif not choice: - self.config['LOG_LEVEL'] = current_log - + self.config["LOG_LEVEL"] = current_log + # Set other defaults defaults = { - 'ARCADE_BASE_URL': 'https://api.arcade-ai.com', - 'DATABASE_PATH': 'data/fact_demo.db', - 'SYSTEM_PROMPT': 'You are a deterministic finance assistant. When uncertain, request data via tools.', - 'MAX_RETRIES': '3', - 'REQUEST_TIMEOUT': '30' + "ARCADE_BASE_URL": "https://api.arcade-ai.com", + "DATABASE_PATH": "data/fact_demo.db", + "SYSTEM_PROMPT": "You are a deterministic finance assistant. When uncertain, request data via tools.", + "MAX_RETRIES": "3", + "REQUEST_TIMEOUT": "30", } - + for key, value in defaults.items(): if key not in self.config: self.config[key] = value - + def configure_security_settings(self): """Configure security settings.""" if self.minimal: # Set secure defaults for minimal setup security_defaults = { - 'STRICT_MODE': 'true', - 'DEBUG_MODE': 'false', - 'ENFORCE_HTTPS': 'true', - 'RATE_LIMITING_ENABLED': 'true' + "STRICT_MODE": "true", + "DEBUG_MODE": "false", + "ENFORCE_HTTPS": "true", + "RATE_LIMITING_ENABLED": "true", } - + for key, value in security_defaults.items(): if key not in self.config: self.config[key] = value - + return - + print("\nšŸ”’ SECURITY CONFIGURATION") print("-" * 30) - + # Environment type env_types = { - 'production': { - 'STRICT_MODE': 'true', - 'DEBUG_MODE': 'false', - 'ENFORCE_HTTPS': 'true', - 'RATE_LIMITING_ENABLED': 'true', - 'LOG_SECURITY_EVENTS': 'true' + "production": { + "STRICT_MODE": "true", + "DEBUG_MODE": "false", + "ENFORCE_HTTPS": "true", + "RATE_LIMITING_ENABLED": "true", + "LOG_SECURITY_EVENTS": "true", + }, + "development": { + "STRICT_MODE": "false", + "DEBUG_MODE": "true", + "ENFORCE_HTTPS": "false", + "RATE_LIMITING_ENABLED": "false", + "LOG_SECURITY_EVENTS": "false", }, - 'development': { - 'STRICT_MODE': 'false', - 'DEBUG_MODE': 'true', - 'ENFORCE_HTTPS': 'false', - 'RATE_LIMITING_ENABLED': 'false', - 'LOG_SECURITY_EVENTS': 'false' - } } - + print("šŸ—ļø Environment Type:") print(" 1. Production (strict security)") print(" 2. Development (relaxed security)") - + choice = input(" Select environment [1-2]: ").strip() - - if choice == '1': - env_config = env_types['production'] + + if choice == "1": + env_config = env_types["production"] print(" āœ… Production security settings applied") - elif choice == '2': - env_config = env_types['development'] + elif choice == "2": + env_config = env_types["development"] print(" āœ… Development security settings applied") else: - env_config = env_types['production'] + env_config = env_types["production"] print(" āœ… Default (production) security settings applied") - + # Apply security settings for key, value in env_config.items(): if key not in self.config: self.config[key] = value - + def write_env_file(self) -> bool: """Write configuration to .env file.""" try: # Create data directory if it doesn't exist - data_dir = Path('data') + data_dir = Path("data") data_dir.mkdir(exist_ok=True) - + # Generate .env content env_content = self.generate_env_content() - + # Write to file - with open('.env', 'w') as f: + with open(".env", "w") as f: f.write(env_content) - + print(f"āœ… Configuration saved to .env") return True - + except Exception as e: print(f"āŒ Failed to write .env file: {e}") return False - + def generate_env_content(self) -> str: """Generate .env file content with proper formatting.""" content = [ @@ -354,102 +371,118 @@ def generate_env_content(self) -> str: "# =============================================================================", "# REQUIRED API KEYS", "# =============================================================================", - "" + "", ] - + # Required API keys - required_keys = ['ANTHROPIC_API_KEY', 'ARCADE_API_KEY'] + required_keys = ["ANTHROPIC_API_KEY", "ARCADE_API_KEY"] for key in required_keys: if key in self.config: content.append(f"{key}={self.config[key]}") - - content.extend([ - "", - "# =============================================================================", - "# OPTIONAL API KEYS", - "# =============================================================================", - "" - ]) - + + content.extend( + [ + "", + "# =============================================================================", + "# OPTIONAL API KEYS", + "# =============================================================================", + "", + ] + ) + # Optional API keys - optional_keys = ['OPENAI_API_KEY', 'ENCRYPTION_KEY', 'CACHE_ENCRYPTION_KEY'] + optional_keys = ["OPENAI_API_KEY", "ENCRYPTION_KEY", "CACHE_ENCRYPTION_KEY"] for key in optional_keys: if key in self.config: content.append(f"{key}={self.config[key]}") else: content.append(f"# {key}=your_{key.lower()}_here") - - content.extend([ - "", - "# =============================================================================", - "# SYSTEM CONFIGURATION", - "# =============================================================================", - "" - ]) - + + content.extend( + [ + "", + "# =============================================================================", + "# SYSTEM CONFIGURATION", + "# =============================================================================", + "", + ] + ) + # System configuration system_keys = [ - 'ARCADE_BASE_URL', 'DATABASE_PATH', 'CLAUDE_MODEL', 'SYSTEM_PROMPT', - 'MAX_RETRIES', 'REQUEST_TIMEOUT', 'LOG_LEVEL' + "ARCADE_BASE_URL", + "DATABASE_PATH", + "CLAUDE_MODEL", + "SYSTEM_PROMPT", + "MAX_RETRIES", + "REQUEST_TIMEOUT", + "LOG_LEVEL", ] - + for key in system_keys: if key in self.config: content.append(f"{key}={self.config[key]}") - - content.extend([ - "", - "# =============================================================================", - "# SECURITY CONFIGURATION", - "# =============================================================================", - "" - ]) - + + content.extend( + [ + "", + "# =============================================================================", + "# SECURITY CONFIGURATION", + "# =============================================================================", + "", + ] + ) + # Security configuration security_keys = [ - 'STRICT_MODE', 'DEBUG_MODE', 'ENFORCE_HTTPS', 'RATE_LIMITING_ENABLED', - 'LOG_SECURITY_EVENTS' + "STRICT_MODE", + "DEBUG_MODE", + "ENFORCE_HTTPS", + "RATE_LIMITING_ENABLED", + "LOG_SECURITY_EVENTS", ] - + for key in security_keys: if key in self.config: content.append(f"{key}={self.config[key]}") - - content.extend([ - "", - "# =============================================================================", - "# CACHE CONFIGURATION (Defaults will be used if not specified)", - "# =============================================================================", - "", - "# CACHE_PREFIX=fact_v1", - "# CACHE_MIN_TOKENS=50", - "# CACHE_MAX_SIZE=100MB", - "# CACHE_TTL_SECONDS=3600", - "", - "# For more configuration options, see:", - "# - docs/environment-configuration-guide.md", - "# - .env.template", - "" - ]) - - return '\n'.join(content) - + + content.extend( + [ + "", + "# =============================================================================", + "# CACHE CONFIGURATION (Defaults will be used if not specified)", + "# =============================================================================", + "", + "# CACHE_PREFIX=fact_v1", + "# CACHE_MIN_TOKENS=50", + "# CACHE_MAX_SIZE=100MB", + "# CACHE_TTL_SECONDS=3600", + "", + "# For more configuration options, see:", + "# - docs/environment-configuration-guide.md", + "# - .env.template", + "", + ] + ) + + return "\n".join(content) + def run_validation(self) -> bool: """Run configuration validation.""" print("\nšŸ” VALIDATING CONFIGURATION") print("-" * 30) - + try: # Import and run validation from pathlib import Path import subprocess - + result = subprocess.run( - [sys.executable, 'scripts/validate_env.py', '--verbose'], + [sys.executable, "scripts/validate_env.py", "--verbose"], capture_output=True, - text=True + text=True, ) - + if result.returncode == 0: print("āœ… Configuration validation passed!") return True @@ -458,13 +491,13 @@ def run_validation(self) -> bool: print(result.stdout) print(result.stderr) return False - + except Exception as e: print(f"āš ļø Could not run validation: {e}") print("šŸ’” You can manually validate later with:") print(" python scripts/validate_env.py --verbose") return True - + def print_next_steps(self): """Print next steps for the user.""" print("\nšŸŽ‰ SETUP COMPLETE!") @@ -493,39 +526,39 @@ def print_next_steps(self): print("• Never commit .env files to version control") print("• Rotate your API keys regularly") print("• Use different keys for different environments") - + def run_setup(self) -> bool: """Run the complete setup process.""" self.print_header() - + # Check existing environment if not self.check_existing_env(): return False - + # Configure required API keys if not self.configure_required_keys(): return False - + # Configure optional API keys if not self.configure_optional_keys(): return False - + # Configure system settings self.configure_system_settings() - + # Configure security settings self.configure_security_settings() - + # Write configuration file if not self.write_env_file(): return False - + # Validate configuration validation_success = self.run_validation() - + # Print next steps self.print_next_steps() - + return validation_success @@ -542,30 +575,32 @@ def main(): This script will guide you through configuring your FACT environment with the required API keys and optimal settings. - """ + """, ) - + parser.add_argument( - '--force', '-f', - action='store_true', - help='Overwrite existing .env file without prompting' + "--force", + "-f", + action="store_true", + help="Overwrite existing .env file without prompting", ) - + parser.add_argument( - '--minimal', '-m', - action='store_true', - help='Set up minimal configuration with defaults' + "--minimal", + "-m", + action="store_true", + help="Set up minimal configuration with defaults", ) - + args = parser.parse_args() - + # Run setup setup = EnvironmentSetup(force=args.force, minimal=args.minimal) success = setup.run_setup() - + # Exit with appropriate code sys.exit(0 if success else 1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/test_benchmark_runner.py b/scripts/test_benchmark_runner.py index 39cd020..70091cb 100644 --- a/scripts/test_benchmark_runner.py +++ b/scripts/test_benchmark_runner.py @@ -8,13 +8,14 @@ import subprocess from pathlib import Path + def test_benchmark_runner(): """Test the benchmark runner script.""" script_path = Path(__file__).parent / "run_benchmarks.py" - + print("🧪 Testing FACT Benchmark Runner") print("=" * 50) - + # Test 1: Check if script can be imported print("Test 1: Import validation...") try: @@ -22,30 +23,35 @@ def test_benchmark_runner(): src_path = str(Path(__file__).parent.parent / "src") if src_path not in sys.path: sys.path.insert(0, src_path) - + # Test basic imports that don't require relative imports print(" Testing basic module structure...") - + # Check if benchmarking module exists import benchmarking + print(" āœ… Benchmarking module found") - + # Check if cache module exists import cache + print(" āœ… Cache module found") - + print("āœ… Core modules accessible") except ImportError as e: print(f"āŒ Import failed: {e}") return False - + # Test 2: Check command line help print("\nTest 2: Command line interface...") try: - result = subprocess.run([ - sys.executable, str(script_path), "--help" - ], capture_output=True, text=True, timeout=10) - + result = subprocess.run( + [sys.executable, str(script_path), "--help"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: print("āœ… CLI help works correctly") else: @@ -54,21 +60,22 @@ def test_benchmark_runner(): except Exception as e: print(f"āŒ CLI test failed: {e}") return False - + # Test 3: Directory creation functionality print("\nTest 3: Directory creation...") try: # Import our enhanced functions sys.path.insert(0, str(Path(__file__).parent)) import run_benchmarks - + # Test directory creation test_dir = run_benchmarks.create_logs_directory("test_logs") if test_dir.exists(): print(f"āœ… Directory created: {test_dir}") - + # Clean up import shutil + shutil.rmtree(test_dir.parent) print("āœ… Cleanup successful") else: @@ -77,15 +84,16 @@ def test_benchmark_runner(): except Exception as e: print(f"āŒ Directory test failed: {e}") return False - + print("\nšŸŽ‰ All tests passed! Benchmark runner is ready to use.") print("\nTo run benchmarks:") print(f" python {script_path}") print(f" python {script_path} --include-rag-comparison") print(f" python {script_path} --include-profiling --include-load-test") - + return True + if __name__ == "__main__": success = test_benchmark_runner() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/scripts/test_cache_fix.py b/scripts/test_cache_fix.py index 0c7ceda..aa91d8a 100644 --- a/scripts/test_cache_fix.py +++ b/scripts/test_cache_fix.py @@ -20,36 +20,36 @@ def test_cache_validation_fix(): """Test that cache validation now uses configured minimum tokens.""" - + print("=== FACT Cache Validation Fix Verification ===\n") - + # Test 1: Default configuration print("1. Testing default configuration (500 tokens)...") default_config = get_default_cache_config() manager_default = CacheManager(default_config) print(f" Default min_tokens: {manager_default.min_tokens}") - - # Test 2: Custom configuration + + # Test 2: Custom configuration print("\n2. Testing custom configuration (100 tokens)...") custom_config = default_config.copy() - custom_config['min_tokens'] = 100 + custom_config["min_tokens"] = 100 manager_custom = CacheManager(custom_config) print(f" Custom min_tokens: {manager_custom.min_tokens}") - + # Test 3: Environment override print("\n3. Testing environment override...") - os.environ['CACHE_MIN_TOKENS'] = '75' + os.environ["CACHE_MIN_TOKENS"] = "75" try: env_config = load_cache_config_from_env() manager_env = CacheManager(env_config.to_dict()) print(f" Environment min_tokens: {manager_env.min_tokens}") except Exception as e: print(f" Environment test failed: {e}") - + # Test 4: Validation consistency print("\n4. Testing validation consistency...") test_content = "This is a test response. " * 15 # ~375 tokens - + try: # Should fail with default manager (requires 500 tokens) hash_key = manager_default.generate_hash("test_query_1") @@ -57,7 +57,7 @@ def test_cache_validation_fix(): print(" āŒ ERROR: Default manager should have rejected content") except Exception: print(" āœ… Default manager correctly rejected content (<500 tokens)") - + try: # Should succeed with custom manager (requires 100 tokens) hash_key = manager_custom.generate_hash("test_query_2") @@ -65,7 +65,7 @@ def test_cache_validation_fix(): print(f" āœ… Custom manager accepted content ({entry.token_count} tokens)") except Exception as e: print(f" āŒ Custom manager failed: {e}") - + print("\n=== Fix Summary ===") print("āœ… Hard-coded 500 token minimum removed from CacheEntry._validate()") print("āœ… CacheEntry now accepts configurable min_tokens parameter") @@ -75,4 +75,4 @@ def test_cache_validation_fix(): if __name__ == "__main__": - test_cache_validation_fix() \ No newline at end of file + test_cache_validation_fix() diff --git a/scripts/test_cache_resilience.py b/scripts/test_cache_resilience.py index 534f5bd..0681045 100644 --- a/scripts/test_cache_resilience.py +++ b/scripts/test_cache_resilience.py @@ -20,8 +20,11 @@ sys.path.insert(0, src_path) from cache.resilience import ( - CacheCircuitBreaker, CircuitBreakerConfig, ResilientCacheWrapper, - CircuitState, FailureRecord + CacheCircuitBreaker, + CircuitBreakerConfig, + ResilientCacheWrapper, + CircuitState, + FailureRecord, ) from cache.manager import CacheManager from core.errors import CacheError @@ -32,7 +35,7 @@ processors=[ structlog.processors.TimeStamper(fmt="ISO"), structlog.processors.add_log_level, - structlog.processors.JSONRenderer() + structlog.processors.JSONRenderer(), ], logger_factory=structlog.PrintLoggerFactory(), wrapper_class=structlog.BoundLogger, @@ -44,11 +47,11 @@ class FailingCacheManager: """Mock cache manager that can be configured to fail.""" - + def __init__(self, failure_rate: float = 0.0, failure_after: int = 0): """ Initialize failing cache manager. - + Args: failure_rate: Probability of failure (0.0 = never fail, 1.0 = always fail) failure_after: Fail after this many successful operations @@ -57,69 +60,79 @@ def __init__(self, failure_rate: float = 0.0, failure_after: int = 0): self.failure_after = failure_after self.operation_count = 0 self.cache = {} - + def set_failure_rate(self, rate: float): """Set new failure rate.""" self.failure_rate = rate logger.info("Cache failure rate changed", rate=rate) - + def set_failure_after(self, count: int): """Set to fail after specific number of operations.""" self.failure_after = count self.operation_count = 0 logger.info("Cache set to fail after operations", count=count) - + def _should_fail(self) -> bool: """Determine if this operation should fail.""" self.operation_count += 1 - + if self.failure_after > 0 and self.operation_count > self.failure_after: return True - + import random + return random.random() < self.failure_rate - + async def store(self, query_hash: str, content: str): """Store with potential failure.""" if self._should_fail(): - raise CacheError("Simulated cache store failure", error_code="CACHE_STORE_FAILED") - - self.cache[query_hash] = { - 'content': content, - 'timestamp': time.time() - } + raise CacheError( + "Simulated cache store failure", error_code="CACHE_STORE_FAILED" + ) + + self.cache[query_hash] = {"content": content, "timestamp": time.time()} return True - + async def get(self, query_hash: str): """Get with potential failure.""" if self._should_fail(): - raise CacheError("Simulated cache get failure", error_code="CACHE_GET_FAILED") - + raise CacheError( + "Simulated cache get failure", error_code="CACHE_GET_FAILED" + ) + if query_hash in self.cache: - return type('CacheEntry', (), { - 'content': self.cache[query_hash]['content'], - 'timestamp': self.cache[query_hash]['timestamp'] - })() + return type( + "CacheEntry", + (), + { + "content": self.cache[query_hash]["content"], + "timestamp": self.cache[query_hash]["timestamp"], + }, + )() return None - + async def invalidate_by_prefix(self, prefix: str) -> int: """Invalidate with potential failure.""" if self._should_fail(): - raise CacheError("Simulated cache invalidate failure", error_code="CACHE_INVALIDATE_FAILED") - + raise CacheError( + "Simulated cache invalidate failure", + error_code="CACHE_INVALIDATE_FAILED", + ) + # Simulate invalidation return len(self.cache) - + def generate_hash(self, query: str) -> str: """Generate hash (no failure simulation for local operation).""" import hashlib + return hashlib.sha256(query.encode()).hexdigest() async def test_circuit_breaker_states(): """Test circuit breaker state transitions.""" logger.info("=== Testing Circuit Breaker State Transitions ===") - + # Configure circuit breaker for fast testing config = CircuitBreakerConfig( failure_threshold=3, # Open after 3 failures @@ -127,31 +140,33 @@ async def test_circuit_breaker_states(): timeout_seconds=2.0, # Fast timeout for testing rolling_window_seconds=60.0, gradual_recovery=True, - recovery_factor=0.5 + recovery_factor=0.5, ) - + circuit_breaker = CacheCircuitBreaker(config) failing_cache = FailingCacheManager(failure_rate=0.0) # Start with no failures resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker) - + # Test normal operation (CLOSED state) logger.info("Testing CLOSED state - normal operations") assert circuit_breaker.get_state() == CircuitState.CLOSED - + for i in range(3): result = await resilient_cache.store(f"test_key_{i}", f"test_content_{i}") assert result == True - + metrics = circuit_breaker.get_metrics() - logger.info("CLOSED state metrics", - state=metrics.state.value, - success_count=metrics.success_count, - failure_count=metrics.failure_count) - + logger.info( + "CLOSED state metrics", + state=metrics.state.value, + success_count=metrics.success_count, + failure_count=metrics.failure_count, + ) + # Introduce failures to trigger OPEN state logger.info("Introducing failures to trigger OPEN state") failing_cache.set_failure_rate(1.0) # 100% failure rate - + failure_count = 0 for i in range(5): try: @@ -161,144 +176,148 @@ async def test_circuit_breaker_states(): if "CIRCUIT_BREAKER_OPEN" in str(e.error_code): logger.info("Circuit breaker opened successfully") break - + assert circuit_breaker.get_state() == CircuitState.OPEN - + # Test that operations use graceful degradation when OPEN logger.info("Testing OPEN state - operations should use graceful degradation") - + # With graceful degradation enabled, operations should return fallback values result = await resilient_cache.get("test_key") assert result is None, "Expected fallback response (None) from graceful degradation" - + # Test store operation also returns fallback result = await resilient_cache.store("test_key", "test_content") assert result == True, "Expected fallback response (True) from graceful degradation" - + # Disable graceful degradation to test actual circuit breaker exceptions resilient_cache.enable_graceful_degradation = False - + try: await resilient_cache.get("test_key") - assert False, "Expected circuit breaker to raise exception when graceful degradation disabled" + assert ( + False + ), "Expected circuit breaker to raise exception when graceful degradation disabled" except CacheError as e: assert "CIRCUIT_BREAKER_OPEN" in str(e.error_code) - logger.info("Circuit breaker correctly raised exception when graceful degradation disabled") - + logger.info( + "Circuit breaker correctly raised exception when graceful degradation disabled" + ) + # Re-enable graceful degradation for remaining tests resilient_cache.enable_graceful_degradation = True - + # Wait for timeout and test HALF_OPEN transition logger.info("Waiting for timeout to test HALF_OPEN transition") await asyncio.sleep(2.5) # Wait longer than timeout - + # Fix the cache and try operation to trigger HALF_OPEN failing_cache.set_failure_rate(0.0) # Fix the cache - + # This should transition to HALF_OPEN and succeed result = await resilient_cache.store("recovery_test", "recovery_content") assert result == True - + # Circuit should be HALF_OPEN or CLOSED now state = circuit_breaker.get_state() logger.info("State after recovery attempt", state=state.value) - + # Continue with successful operations to close circuit for i in range(3): await resilient_cache.store(f"recovery_key_{i}", f"recovery_content_{i}") - + final_state = circuit_breaker.get_state() logger.info("Final state after recovery", state=final_state.value) - + metrics = circuit_breaker.get_metrics() - logger.info("Final metrics", - state=metrics.state.value, - success_count=metrics.success_count, - failure_count=metrics.failure_count, - failure_rate=metrics.failure_rate, - state_changes=metrics.state_changes) - + logger.info( + "Final metrics", + state=metrics.state.value, + success_count=metrics.success_count, + failure_count=metrics.failure_count, + failure_rate=metrics.failure_rate, + state_changes=metrics.state_changes, + ) + logger.info("āœ… Circuit breaker state transitions test passed") async def test_graceful_degradation(): """Test graceful degradation when cache fails.""" logger.info("=== Testing Graceful Degradation ===") - + # Configure circuit breaker config = CircuitBreakerConfig( failure_threshold=2, success_threshold=2, timeout_seconds=1.0, - gradual_recovery=True + gradual_recovery=True, ) - + circuit_breaker = CacheCircuitBreaker(config) failing_cache = FailingCacheManager(failure_rate=1.0) # Always fail resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker) - + # Enable graceful degradation resilient_cache.enable_graceful_degradation = True - + # First, trigger circuit breaker to open logger.info("Triggering circuit breaker to open state") - + for i in range(3): # Exceed failure threshold try: await resilient_cache.store(f"fail_key_{i}", "content") except CacheError: pass # Expected failures - + # Verify circuit is open assert circuit_breaker.get_state() == CircuitState.OPEN logger.info("Circuit breaker is now OPEN") - + logger.info("Testing cache operations with graceful degradation") - + # Test store operation - should return graceful response result = await resilient_cache.store("test_key", "test_content") logger.info("Store operation result", result=result) assert result == True # Graceful fallback - + # Test get operation - should return None (cache miss) result = await resilient_cache.get("test_key") logger.info("Get operation result", result=result) assert result is None - + # Test invalidate operation - should return 0 result = await resilient_cache.invalidate_by_prefix("test_") logger.info("Invalidate operation result", result=result) assert result == 0 - + # Check circuit breaker state state = circuit_breaker.get_state() logger.info("Circuit breaker state", state=state.value) - + logger.info("āœ… Graceful degradation test passed") async def test_performance_under_failures(): """Test performance characteristics under various failure scenarios.""" logger.info("=== Testing Performance Under Failures ===") - + config = CircuitBreakerConfig( - failure_threshold=5, - success_threshold=3, - timeout_seconds=1.0 + failure_threshold=5, success_threshold=3, timeout_seconds=1.0 ) - + circuit_breaker = CacheCircuitBreaker(config) failing_cache = FailingCacheManager(failure_rate=0.3) # 30% failure rate resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker) - + # Test with mixed success/failure pattern operation_count = 50 success_count = 0 failure_count = 0 total_time = 0 - + logger.info("Running performance test with 30% failure rate") - + for i in range(operation_count): start_time = time.time() try: @@ -306,47 +325,47 @@ async def test_performance_under_failures(): success_count += 1 except CacheError: failure_count += 1 - + end_time = time.time() - total_time += (end_time - start_time) - + total_time += end_time - start_time + avg_latency = (total_time / operation_count) * 1000 # Convert to ms - + metrics = circuit_breaker.get_metrics() - - logger.info("Performance test results", - total_operations=operation_count, - successes=success_count, - failures=failure_count, - avg_latency_ms=avg_latency, - circuit_state=metrics.state.value, - circuit_failure_rate=metrics.failure_rate, - state_changes=metrics.state_changes) - + + logger.info( + "Performance test results", + total_operations=operation_count, + successes=success_count, + failures=failure_count, + avg_latency_ms=avg_latency, + circuit_state=metrics.state.value, + circuit_failure_rate=metrics.failure_rate, + state_changes=metrics.state_changes, + ) + logger.info("āœ… Performance test completed") async def test_circuit_breaker_metrics(): """Test circuit breaker metrics collection.""" logger.info("=== Testing Circuit Breaker Metrics ===") - + config = CircuitBreakerConfig( - failure_threshold=3, - success_threshold=2, - timeout_seconds=1.0 + failure_threshold=3, success_threshold=2, timeout_seconds=1.0 ) - + circuit_breaker = CacheCircuitBreaker(config) failing_cache = FailingCacheManager() resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker) - + # Generate some operations logger.info("Generating operations for metrics") - + # Successful operations for i in range(5): await resilient_cache.store(f"metrics_key_{i}", f"metrics_content_{i}") - + # Failed operations failing_cache.set_failure_rate(1.0) for i in range(3): @@ -354,45 +373,48 @@ async def test_circuit_breaker_metrics(): await resilient_cache.store(f"fail_key_{i}", f"fail_content_{i}") except CacheError: pass - + # Get comprehensive metrics metrics = circuit_breaker.get_metrics() cache_metrics = resilient_cache.get_metrics() - - logger.info("Circuit breaker metrics", - state=metrics.state.value, - total_operations=metrics.total_operations, - success_count=metrics.success_count, - failure_count=metrics.failure_count, - failure_rate=metrics.failure_rate, - state_changes=metrics.state_changes, - recent_failures_count=len(metrics.recent_failures)) - - logger.info("Combined metrics structure", - cache_metrics_keys=list(cache_metrics.keys())) - + + logger.info( + "Circuit breaker metrics", + state=metrics.state.value, + total_operations=metrics.total_operations, + success_count=metrics.success_count, + failure_count=metrics.failure_count, + failure_rate=metrics.failure_rate, + state_changes=metrics.state_changes, + recent_failures_count=len(metrics.recent_failures), + ) + + logger.info( + "Combined metrics structure", cache_metrics_keys=list(cache_metrics.keys()) + ) + logger.info("āœ… Metrics test completed") async def test_recovery_scenarios(): """Test various recovery scenarios.""" logger.info("=== Testing Recovery Scenarios ===") - + config = CircuitBreakerConfig( failure_threshold=2, success_threshold=3, timeout_seconds=1.0, gradual_recovery=True, - recovery_factor=0.5 + recovery_factor=0.5, ) - + circuit_breaker = CacheCircuitBreaker(config) failing_cache = FailingCacheManager() resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker) - + # Test scenario 1: Complete failure and recovery logger.info("Scenario 1: Complete failure and recovery") - + # Cause circuit to open failing_cache.set_failure_rate(1.0) for i in range(3): @@ -400,16 +422,16 @@ async def test_recovery_scenarios(): await resilient_cache.store(f"fail_{i}", "content") except CacheError: pass - + assert circuit_breaker.get_state() == CircuitState.OPEN logger.info("Circuit opened successfully") - + # Wait for timeout await asyncio.sleep(1.5) - + # Fix cache and recover failing_cache.set_failure_rate(0.0) - + # Gradual recovery recovery_attempts = 10 success_count = 0 @@ -420,32 +442,34 @@ async def test_recovery_scenarios(): except CacheError as e: if "THROTTLING" in str(e.error_code): logger.debug("Request throttled during recovery") - + final_state = circuit_breaker.get_state() - logger.info("Recovery scenario completed", - final_state=final_state.value, - recovery_success_rate=success_count / recovery_attempts) - + logger.info( + "Recovery scenario completed", + final_state=final_state.value, + recovery_success_rate=success_count / recovery_attempts, + ) + logger.info("āœ… Recovery scenarios test completed") async def run_all_tests(): """Run all circuit breaker and resilience tests.""" logger.info("🧪 Starting Cache Resilience Tests") - + try: await test_circuit_breaker_states() await test_graceful_degradation() await test_performance_under_failures() await test_circuit_breaker_metrics() await test_recovery_scenarios() - + logger.info("šŸŽ‰ All cache resilience tests passed!") - + except Exception as e: logger.error("āŒ Test failed", error=str(e), exc_info=True) raise if __name__ == "__main__": - asyncio.run(run_all_tests()) \ No newline at end of file + asyncio.run(run_all_tests()) diff --git a/scripts/test_fact_cache_integration.py b/scripts/test_fact_cache_integration.py index 397eff7..75e1d8a 100644 --- a/scripts/test_fact_cache_integration.py +++ b/scripts/test_fact_cache_integration.py @@ -35,7 +35,9 @@ import structlog except ImportError as e: print(f"āŒ Import error: {e}") - print("Make sure you're running this script from the project root and all dependencies are installed.") + print( + "Make sure you're running this script from the project root and all dependencies are installed." + ) sys.exit(1) # Configure logging @@ -43,7 +45,7 @@ processors=[ structlog.processors.TimeStamper(fmt="ISO"), structlog.processors.add_log_level, - structlog.processors.JSONRenderer() + structlog.processors.JSONRenderer(), ], logger_factory=structlog.PrintLoggerFactory(), wrapper_class=structlog.BoundLogger, @@ -56,14 +58,14 @@ def setup_test_environment() -> str: """ Set up test environment variables and create a temporary database. - + Returns: Path to temporary database file """ # Create temporary database - temp_db = tempfile.NamedTemporaryFile(suffix='.db', delete=False) + temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) temp_db.close() - + # Set required environment variables for testing test_env = { "DATABASE_PATH": temp_db.name, @@ -76,27 +78,32 @@ def setup_test_environment() -> str: "LOG_LEVEL": "INFO", "MAX_RETRIES": "2", "REQUEST_TIMEOUT": "30", - "SKIP_API_VALIDATION": "true" # Skip API validation for testing + "SKIP_API_VALIDATION": "true", # Skip API validation for testing } - + # Check if API keys are set, if not set test values if not os.getenv("ANTHROPIC_API_KEY"): test_env["ANTHROPIC_API_KEY"] = "test-key-for-integration-testing" - logger.warning("Using test API key for Anthropic - some features may be limited") - + logger.warning( + "Using test API key for Anthropic - some features may be limited" + ) + if not os.getenv("ARCADE_API_KEY"): test_env["ARCADE_API_KEY"] = "test-key-for-integration-testing" logger.warning("Using test API key for Arcade - some features may be limited") - + # Apply test environment for key, value in test_env.items(): os.environ[key] = value - - logger.info("Test environment configured", - database_path=temp_db.name, - api_keys_configured=bool(os.getenv("ANTHROPIC_API_KEY")) and bool(os.getenv("ARCADE_API_KEY")), - skip_api_validation=True) - + + logger.info( + "Test environment configured", + database_path=temp_db.name, + api_keys_configured=bool(os.getenv("ANTHROPIC_API_KEY")) + and bool(os.getenv("ARCADE_API_KEY")), + skip_api_validation=True, + ) + return temp_db.name @@ -113,103 +120,115 @@ def cleanup_test_environment(temp_db_path: str) -> None: async def test_fact_driver_initialization(): """Test FACT driver initialization with real components.""" logger.info("=== Testing FACT Driver Initialization ===") - + try: # Create configuration config = Config() logger.info("Configuration created successfully") - + # Initialize FACT driver driver = FACTDriver(config) logger.info("FACT driver instance created") - + # Test initialization await driver.initialize() logger.info("āœ… FACT driver initialized successfully") - + # Verify components are available assert driver.config is not None, "Configuration should be available" - assert hasattr(driver, 'cache_circuit_breaker'), "Circuit breaker should be available" - + assert hasattr( + driver, "cache_circuit_breaker" + ), "Circuit breaker should be available" + # Test metrics collection metrics = driver.get_metrics() logger.info("Initial metrics collected", metrics=metrics) - + assert "initialized" in metrics, "Metrics should include initialization status" assert metrics["initialized"] == True, "Driver should report as initialized" - + # Clean shutdown await driver.shutdown() logger.info("āœ… FACT driver initialization test passed") - + except ConfigurationError as e: logger.error("āŒ Configuration error during initialization", error=str(e)) logger.info("šŸ’” This is expected if API keys are not properly configured") raise except Exception as e: - logger.error("āŒ FACT driver initialization test failed", error=str(e), exc_info=True) + logger.error( + "āŒ FACT driver initialization test failed", error=str(e), exc_info=True + ) raise async def test_cache_resilience_features(): """Test cache resilience features without external API calls.""" logger.info("=== Testing Cache Resilience Features ===") - + try: config = Config() driver = FACTDriver(config) - + await driver.initialize() logger.info("Driver initialized for cache resilience testing") - + # Test 1: Circuit breaker manipulation - if hasattr(driver, 'cache_circuit_breaker') and driver.cache_circuit_breaker: + if hasattr(driver, "cache_circuit_breaker") and driver.cache_circuit_breaker: logger.info("Testing circuit breaker state changes") - + # Test initial state initial_state = driver.cache_circuit_breaker.get_state() logger.info("Initial circuit breaker state", state=initial_state.value) - + # Force open state driver.cache_circuit_breaker.force_open() open_state = driver.cache_circuit_breaker.get_state() - assert open_state == CircuitState.OPEN, "Circuit breaker should be in OPEN state" + assert ( + open_state == CircuitState.OPEN + ), "Circuit breaker should be in OPEN state" logger.info("āœ… Circuit breaker forced to OPEN state") - + # Force closed state driver.cache_circuit_breaker.force_closed() closed_state = driver.cache_circuit_breaker.get_state() - assert closed_state == CircuitState.CLOSED, "Circuit breaker should be in CLOSED state" + assert ( + closed_state == CircuitState.CLOSED + ), "Circuit breaker should be in CLOSED state" logger.info("āœ… Circuit breaker forced to CLOSED state") - + else: - logger.warning("Circuit breaker not available - cache resilience may be limited") - + logger.warning( + "Circuit breaker not available - cache resilience may be limited" + ) + # Test 2: Cache degradation mode logger.info("Testing cache degradation mode") - + # Simulate cache degraded mode - if hasattr(driver, '_cache_degraded'): - original_degraded = getattr(driver, '_cache_degraded', False) - + if hasattr(driver, "_cache_degraded"): + original_degraded = getattr(driver, "_cache_degraded", False) + # Set degraded mode driver._cache_degraded = True logger.info("Cache set to degraded mode") - + # Check metrics reflect degraded state metrics = driver.get_metrics() if "cache_degraded" in metrics: - assert metrics["cache_degraded"] == True, "Metrics should show cache degraded" + assert ( + metrics["cache_degraded"] == True + ), "Metrics should show cache degraded" logger.info("āœ… Metrics correctly show cache degraded state") - + # Restore original state driver._cache_degraded = original_degraded - + # Test 3: System metrics during different states logger.info("Testing metrics collection") - + final_metrics = driver.get_metrics() - + # Verify essential metrics are present essential_metrics = ["initialized", "total_queries"] for metric in essential_metrics: @@ -217,16 +236,17 @@ async def test_cache_resilience_features(): logger.info(f"āœ… Metric '{metric}' available: {final_metrics[metric]}") else: logger.warning(f"āš ļø Metric '{metric}' not available") - + # Log all cache and circuit breaker related metrics - cache_metrics = {k: v for k, v in final_metrics.items() - if k.startswith(('cache', 'circuit'))} + cache_metrics = { + k: v for k, v in final_metrics.items() if k.startswith(("cache", "circuit")) + } if cache_metrics: logger.info("Cache and circuit breaker metrics", metrics=cache_metrics) - + await driver.shutdown() logger.info("āœ… Cache resilience features test passed") - + except Exception as e: logger.error("āŒ Cache resilience test failed", error=str(e), exc_info=True) raise @@ -235,21 +255,21 @@ async def test_cache_resilience_features(): async def test_database_integration(): """Test database integration and connection handling.""" logger.info("=== Testing Database Integration ===") - + try: config = Config() driver = FACTDriver(config) - + # Test database initialization await driver.initialize() logger.info("Driver initialized for database testing") - + # Verify database file exists db_path = config.database_path if db_path != ":memory:": assert os.path.exists(db_path), f"Database file should exist at {db_path}" logger.info("āœ… Database file created", path=db_path) - + # Test basic database connectivity try: conn = sqlite3.connect(db_path) @@ -257,22 +277,26 @@ async def test_database_integration(): cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") tables = cursor.fetchall() conn.close() - - logger.info("āœ… Database connectivity verified", table_count=len(tables)) - + + logger.info( + "āœ… Database connectivity verified", table_count=len(tables) + ) + except Exception as e: logger.warning("Database connectivity test failed", error=str(e)) else: logger.info("āœ… Using in-memory database") - + # Test metrics include database information metrics = driver.get_metrics() - logger.info("Database-related metrics collected", - has_db_metrics=any(k.startswith('db') for k in metrics.keys())) - + logger.info( + "Database-related metrics collected", + has_db_metrics=any(k.startswith("db") for k in metrics.keys()), + ) + await driver.shutdown() logger.info("āœ… Database integration test passed") - + except Exception as e: logger.error("āŒ Database integration test failed", error=str(e), exc_info=True) raise @@ -281,90 +305,104 @@ async def test_database_integration(): async def test_performance_monitoring(): """Test performance monitoring capabilities.""" logger.info("=== Testing Performance Monitoring ===") - + try: config = Config() driver = FACTDriver(config) - + await driver.initialize() logger.info("Driver initialized for performance testing") - + # Test timing measurements start_time = time.time() - + # Simulate some operations by getting metrics multiple times for i in range(5): metrics = driver.get_metrics() await asyncio.sleep(0.01) # Small delay to simulate work - + elapsed_time = time.time() - start_time - logger.info("Performance test completed", - operations=5, - total_time_ms=elapsed_time * 1000, - avg_time_ms=(elapsed_time / 5) * 1000) - + logger.info( + "Performance test completed", + operations=5, + total_time_ms=elapsed_time * 1000, + avg_time_ms=(elapsed_time / 5) * 1000, + ) + # Verify metrics are collected efficiently final_metrics = driver.get_metrics() - performance_keys = [k for k in final_metrics.keys() - if 'time' in k.lower() or 'latency' in k.lower() or 'performance' in k.lower()] - + performance_keys = [ + k + for k in final_metrics.keys() + if "time" in k.lower() + or "latency" in k.lower() + or "performance" in k.lower() + ] + if performance_keys: logger.info("āœ… Performance metrics available", keys=performance_keys) else: logger.info("ā„¹ļø No explicit performance metrics found") - + await driver.shutdown() logger.info("āœ… Performance monitoring test passed") - + except Exception as e: - logger.error("āŒ Performance monitoring test failed", error=str(e), exc_info=True) + logger.error( + "āŒ Performance monitoring test failed", error=str(e), exc_info=True + ) raise async def test_error_handling(): """Test error handling and recovery mechanisms.""" logger.info("=== Testing Error Handling ===") - + try: # Test with invalid configuration logger.info("Testing configuration error handling") - + # Save original environment original_db_path = os.getenv("DATABASE_PATH") - + # Set invalid database path to test error handling os.environ["DATABASE_PATH"] = "/invalid/path/that/does/not/exist.db" - + try: config = Config() driver = FACTDriver(config) await driver.initialize() - + # If we get here, the system handled the invalid path gracefully logger.info("āœ… System handled invalid database path gracefully") await driver.shutdown() - + except Exception as e: # This is expected for truly invalid paths - logger.info("āœ… System correctly raised error for invalid database path", error=str(e)) - + logger.info( + "āœ… System correctly raised error for invalid database path", + error=str(e), + ) + finally: # Restore original database path if original_db_path: os.environ["DATABASE_PATH"] = original_db_path - + # Test normal operation after error recovery logger.info("Testing recovery after error") config = Config() driver = FACTDriver(config) await driver.initialize() - + metrics = driver.get_metrics() - assert metrics["initialized"] == True, "Driver should initialize successfully after error recovery" - + assert ( + metrics["initialized"] == True + ), "Driver should initialize successfully after error recovery" + await driver.shutdown() logger.info("āœ… Error handling and recovery test passed") - + except Exception as e: logger.error("āŒ Error handling test failed", error=str(e), exc_info=True) raise @@ -373,33 +411,35 @@ async def test_error_handling(): async def run_all_integration_tests(): """Run all real integration tests without mocking.""" logger.info("🧪 Starting FACT Real Integration Tests (No Mocking)") - + temp_db_path = None - + try: # Set up test environment temp_db_path = setup_test_environment() logger.info("Test environment set up successfully") - + # Run all test suites await test_fact_driver_initialization() await test_cache_resilience_features() await test_database_integration() await test_performance_monitoring() await test_error_handling() - + logger.info("šŸŽ‰ All FACT real integration tests passed!") - + except ConfigurationError as e: logger.error("āŒ Configuration error in integration tests", error=str(e)) - logger.info("šŸ’” Ensure API keys are properly configured in environment variables") + logger.info( + "šŸ’” Ensure API keys are properly configured in environment variables" + ) logger.info("šŸ’” For testing, you can use placeholder values like 'test-key'") raise - + except Exception as e: logger.error("āŒ Integration tests failed", error=str(e), exc_info=True) raise - + finally: # Clean up test environment if temp_db_path: @@ -414,4 +454,4 @@ async def run_all_integration_tests(): sys.exit(1) except Exception as e: logger.error("Integration tests failed to run", error=str(e)) - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/scripts/test_response_padding.py b/scripts/test_response_padding.py index 712e17d..0332665 100644 --- a/scripts/test_response_padding.py +++ b/scripts/test_response_padding.py @@ -18,7 +18,7 @@ pad_response_for_caching, enhance_sql_tool_response, validate_enhanced_response, - _estimate_tokens + _estimate_tokens, ) from cache.manager import CacheManager from cache.config import get_default_cache_config @@ -26,9 +26,9 @@ def test_sql_response_padding(): """Test SQL response padding functionality.""" - + print("=== FACT Response Padding Test ===\n") - + # Simulate typical SQL tool responses that are too small for caching test_responses = { "simple_select": """ @@ -42,7 +42,6 @@ def test_sql_response_padding(): Results: 8 rows returned Execution time: 0.023 seconds """.strip(), - "insert_operation": """ INSERT INTO products (name, price, category_id, description) VALUES @@ -53,7 +52,6 @@ def test_sql_response_padding(): Query executed successfully. 3 rows inserted. """.strip(), - "table_creation": """ CREATE TABLE analytics_events ( id SERIAL PRIMARY KEY, @@ -66,106 +64,106 @@ def test_sql_response_padding(): ); Table created successfully. -""".strip() +""".strip(), } - + print("1. Testing original response token counts:") for name, response in test_responses.items(): tokens = _estimate_tokens(response) print(f" {name}: {tokens} tokens") - + print(f"\n2. Testing response padding (target: 500 tokens):") enhanced_responses = {} - + for name, response in test_responses.items(): print(f"\n Processing {name}...") - + try: # Test basic padding enhanced = pad_response_for_caching( - content=response, - content_type="sql", - target_tokens=500 + content=response, content_type="sql", target_tokens=500 ) - + enhanced_responses[name] = enhanced - + # Validate enhancement validation = validate_enhanced_response(response, enhanced, 500) - + print(f" āœ… Original: {validation['original_tokens']} tokens") print(f" āœ… Enhanced: {validation['enhanced_tokens']} tokens") print(f" āœ… Meets requirement: {validation['meets_requirement']}") print(f" āœ… Original preserved: {validation['original_preserved']}") - + except Exception as e: print(f" āŒ Failed to enhance {name}: {e}") - + print(f"\n3. Testing SQL-specific enhancement function:") - + # Test with query context context = { "query_type": "SELECT", "tables_accessed": ["users", "profiles"], "execution_time_ms": 23, - "rows_returned": 8 + "rows_returned": 8, } - + enhanced_with_context = enhance_sql_tool_response( sql_response=test_responses["simple_select"], query_context=context, - min_tokens=500 + min_tokens=500, ) - + context_validation = validate_enhanced_response( - test_responses["simple_select"], - enhanced_with_context, - 500 + test_responses["simple_select"], enhanced_with_context, 500 + ) + + print( + f" āœ… Enhanced with context: {context_validation['enhanced_tokens']} tokens" ) - - print(f" āœ… Enhanced with context: {context_validation['enhanced_tokens']} tokens") print(f" āœ… Context preserved: {'query_type' in enhanced_with_context}") - + print(f"\n4. Testing cache integration:") - + # Test that enhanced responses can be cached config = get_default_cache_config() cache_manager = CacheManager(config) - + successful_caches = 0 - + for name, enhanced_response in enhanced_responses.items(): try: query_hash = cache_manager.generate_hash(f"test_query_{name}") entry = cache_manager.store(query_hash, enhanced_response) - + print(f" āœ… Cached {name}: {entry.token_count} tokens") successful_caches += 1 - + # Verify retrieval retrieved = cache_manager.get(query_hash) if retrieved and retrieved.content == enhanced_response: print(f" āœ… Retrieved {name} successfully") else: print(f" āŒ Failed to retrieve {name}") - + except Exception as e: print(f" āŒ Failed to cache {name}: {e}") - + print(f"\n5. Performance and quality metrics:") print(f" āœ… Responses enhanced: {len(enhanced_responses)}/{len(test_responses)}") print(f" āœ… Successfully cached: {successful_caches}/{len(enhanced_responses)}") - + # Calculate average enhancement ratio total_ratio = 0 for name, response in test_responses.items(): if name in enhanced_responses: - validation = validate_enhanced_response(response, enhanced_responses[name], 500) - total_ratio += validation['enhancement_ratio'] - + validation = validate_enhanced_response( + response, enhanced_responses[name], 500 + ) + total_ratio += validation["enhancement_ratio"] + avg_ratio = total_ratio / len(enhanced_responses) if enhanced_responses else 0 print(f" āœ… Average enhancement ratio: {avg_ratio:.2f}x") - + print(f"\n=== Test Summary ===") print("āœ… SQL response padding successfully implemented") print("āœ… Small responses (320-368 tokens) enhanced to meet 500+ token requirement") @@ -173,55 +171,57 @@ def test_sql_response_padding(): print("āœ… Enhanced responses successfully cached and retrieved") print("āœ… Context-aware enhancement provides additional value") print("āœ… Performance metrics within acceptable ranges") - + return True def test_edge_cases(): """Test edge cases and error handling.""" - + print("\n=== Edge Case Testing ===\n") - + # Test empty content try: pad_response_for_caching("", "sql", 500) print("āŒ Should have failed on empty content") except ValueError: print("āœ… Correctly rejected empty content") - + # Test invalid target tokens try: pad_response_for_caching("test", "sql", 50) print("āŒ Should have failed on low target tokens") except ValueError: print("āœ… Correctly rejected invalid target tokens") - + # Test content that already meets requirements - long_content = "This is a long response with detailed information. " * 120 # ~600+ tokens + long_content = ( + "This is a long response with detailed information. " * 120 + ) # ~600+ tokens result = pad_response_for_caching(long_content, "sql", 500) - + if result == long_content: print("āœ… Correctly returned unmodified content when requirements already met") else: print("āŒ Unnecessarily modified content that already met requirements") - + # Test different content types json_content = '{"result": "success", "data": [1, 2, 3]}' json_enhanced = pad_response_for_caching(json_content, "json", 500) - + if _estimate_tokens(json_enhanced) >= 500: print("āœ… JSON content type padding works correctly") else: print("āŒ JSON content type padding failed") - + print("āœ… Edge case testing completed") def demonstrate_caching_improvement(): """Demonstrate the caching improvement for SQL responses.""" - + print("\n=== Caching Improvement Demonstration ===\n") - + # Simulate the original problem: SQL response too small for caching original_sql_response = """ SELECT COUNT(*) as total_users, @@ -233,52 +233,52 @@ def demonstrate_caching_improvement(): Result: total_users=1247, avg_age=32.4, last_signup=2024-01-15 Execution time: 0.045 seconds """.strip() - + print("Before enhancement:") original_tokens = _estimate_tokens(original_sql_response) print(f" Token count: {original_tokens}") print(f" Meets caching requirement (500+): {original_tokens >= 500}") - + # Try to cache original response config = get_default_cache_config() cache_manager = CacheManager(config) - + try: query_hash = cache_manager.generate_hash("demo_query") cache_manager.store(query_hash, original_sql_response) print(" āŒ Original response should not have been cacheable") except Exception: print(" āœ… Original response correctly rejected for caching") - + print("\nAfter enhancement:") enhanced_response = enhance_sql_tool_response( sql_response=original_sql_response, query_context={ "operation": "aggregation", "performance": "optimized", - "tables": ["users"] + "tables": ["users"], }, - min_tokens=500 + min_tokens=500, ) - + enhanced_tokens = _estimate_tokens(enhanced_response) print(f" Token count: {enhanced_tokens}") print(f" Meets caching requirement (500+): {enhanced_tokens >= 500}") - + try: query_hash = cache_manager.generate_hash("demo_query_enhanced") entry = cache_manager.store(query_hash, enhanced_response) print(f" āœ… Enhanced response successfully cached") - + # Verify retrieval retrieved = cache_manager.get(query_hash) if retrieved: print(f" āœ… Enhanced response successfully retrieved from cache") print(f" āœ… Cache hit preserves enhanced content and context") - + except Exception as e: print(f" āŒ Failed to cache enhanced response: {e}") - + print(f"\nImprovement summary:") print(f" Token increase: {enhanced_tokens - original_tokens} tokens") print(f" Enhancement ratio: {enhanced_tokens / original_tokens:.2f}x") @@ -288,15 +288,15 @@ def demonstrate_caching_improvement(): if __name__ == "__main__": print("Testing FACT system response padding utilities...\n") - + try: test_sql_response_padding() test_edge_cases() demonstrate_caching_improvement() - + print(f"\nšŸŽ‰ All tests completed successfully!") print(f"Response padding is ready for production use.") - + except Exception as e: print(f"\nāŒ Test failed with error: {e}") - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/scripts/validate_complete_system.py b/scripts/validate_complete_system.py index 6f55e52..9cdff32 100755 --- a/scripts/validate_complete_system.py +++ b/scripts/validate_complete_system.py @@ -25,7 +25,7 @@ from datetime import datetime # Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from core.config import Config, ConfigurationError from core.driver import get_driver @@ -39,8 +39,15 @@ class ValidationResult: """Represents the result of a validation check.""" - - def __init__(self, component: str, test: str, success: bool, message: str, details: Dict[str, Any] = None): + + def __init__( + self, + component: str, + test: str, + success: bool, + message: str, + details: Dict[str, Any] = None, + ): self.component = component self.test = test self.success = success @@ -51,49 +58,55 @@ def __init__(self, component: str, test: str, success: bool, message: str, detai class SystemValidator: """Comprehensive FACT system validator.""" - + def __init__(self, verbose: bool = False, fix_issues: bool = False): self.verbose = verbose self.fix_issues = fix_issues self.results: List[ValidationResult] = [] self.setup_logging() - + def setup_logging(self): """Set up logging configuration.""" level = logging.DEBUG if self.verbose else logging.INFO logging.basicConfig( - level=level, - format='%(asctime)s - %(levelname)s - %(message)s' + level=level, format="%(asctime)s - %(levelname)s - %(message)s" ) self.logger = logging.getLogger(__name__) - - def add_result(self, component: str, test: str, success: bool, message: str, details: Dict[str, Any] = None): + + def add_result( + self, + component: str, + test: str, + success: bool, + message: str, + details: Dict[str, Any] = None, + ): """Add a validation result.""" result = ValidationResult(component, test, success, message, details) self.results.append(result) - + status = "āœ…" if success else "āŒ" print(f"{status} {component}: {test} - {message}") - + if self.verbose and details: for key, value in details.items(): print(f" {key}: {value}") - + def validate_environment_configuration(self) -> bool: """Validate environment configuration.""" print("\nšŸ”§ Validating Environment Configuration...") all_passed = True - + try: # Test .env file existence - env_path = Path('.env') + env_path = Path(".env") if env_path.exists(): self.add_result( - "Environment", - ".env file", - True, + "Environment", + ".env file", + True, f"Found at {env_path.absolute()}", - {"file_size": env_path.stat().st_size} + {"file_size": env_path.stat().st_size}, ) else: self.add_result("Environment", ".env file", False, "Missing .env file") @@ -101,256 +114,337 @@ def validate_environment_configuration(self) -> bool: print(" šŸ”§ Creating .env file with defaults...") self.create_default_env_file() all_passed = False - + # Test configuration loading config = Config() self.add_result( - "Environment", - "Configuration loading", - True, + "Environment", + "Configuration loading", + True, "Configuration loaded successfully", - config.to_dict() + config.to_dict(), ) - + # Validate API keys - if config.anthropic_api_key and not config.anthropic_api_key.startswith('your_'): - self.add_result("Environment", "Anthropic API key", True, "Valid API key configured") + if config.anthropic_api_key and not config.anthropic_api_key.startswith( + "your_" + ): + self.add_result( + "Environment", "Anthropic API key", True, "Valid API key configured" + ) else: - self.add_result("Environment", "Anthropic API key", False, "API key not configured or placeholder") + self.add_result( + "Environment", + "Anthropic API key", + False, + "API key not configured or placeholder", + ) all_passed = False - - if config.arcade_api_key and not config.arcade_api_key.startswith('your_'): - self.add_result("Environment", "Arcade API key", True, "Valid API key configured") + + if config.arcade_api_key and not config.arcade_api_key.startswith("your_"): + self.add_result( + "Environment", "Arcade API key", True, "Valid API key configured" + ) else: - self.add_result("Environment", "Arcade API key", False, "API key not configured or placeholder") + self.add_result( + "Environment", + "Arcade API key", + False, + "API key not configured or placeholder", + ) all_passed = False - + return all_passed - + except ConfigurationError as e: self.add_result("Environment", "Configuration loading", False, str(e)) return False except Exception as e: - self.add_result("Environment", "Configuration loading", False, f"Unexpected error: {e}") + self.add_result( + "Environment", "Configuration loading", False, f"Unexpected error: {e}" + ) return False - + async def validate_database_integration(self) -> bool: """Validate database integration.""" print("\nšŸ—„ļø Validating Database Integration...") all_passed = True - + try: config = Config() db_manager = DatabaseManager(config.database_path) - + # Test database initialization await db_manager.initialize_database() - self.add_result("Database", "Initialization", True, "Database initialized successfully") - + self.add_result( + "Database", "Initialization", True, "Database initialized successfully" + ) + # Test database info retrieval db_info = await db_manager.get_database_info() self.add_result( - "Database", - "Schema validation", - True, + "Database", + "Schema validation", + True, f"Found {db_info['total_tables']} tables", - db_info + db_info, ) - + # Validate required tables required_tables = ["companies", "financial_data", "benchmarks"] for table in required_tables: - if table in db_info['tables']: - row_count = db_info['tables'][table]['row_count'] + if table in db_info["tables"]: + row_count = db_info["tables"][table]["row_count"] self.add_result( - "Database", - f"Table {table}", - True, + "Database", + f"Table {table}", + True, f"{row_count} rows", - {"row_count": row_count} + {"row_count": row_count}, ) else: - self.add_result("Database", f"Table {table}", False, "Table missing") + self.add_result( + "Database", f"Table {table}", False, "Table missing" + ) all_passed = False - + return all_passed - + except Exception as e: self.add_result("Database", "Integration", False, f"Database error: {e}") return False - + async def validate_api_connectivity(self) -> bool: """Validate API connectivity.""" print("\n🌐 Validating API Connectivity...") all_passed = True - + try: config = Config() - + # Skip API tests if keys are not configured - if (not config.anthropic_api_key or config.anthropic_api_key.startswith('your_') or - not config.arcade_api_key or config.arcade_api_key.startswith('your_')): - self.add_result("API", "Connectivity test", False, "API keys not configured - skipping connectivity tests") + if ( + not config.anthropic_api_key + or config.anthropic_api_key.startswith("your_") + or not config.arcade_api_key + or config.arcade_api_key.startswith("your_") + ): + self.add_result( + "API", + "Connectivity test", + False, + "API keys not configured - skipping connectivity tests", + ) return False - + # Test driver initialization (includes API connectivity) try: driver = await get_driver() - self.add_result("API", "Driver initialization", True, "Driver initialized successfully") - + self.add_result( + "API", + "Driver initialization", + True, + "Driver initialized successfully", + ) + # Test basic functionality metrics = driver.get_metrics() self.add_result( - "API", - "System metrics", - True, + "API", + "System metrics", + True, f"System operational - {len(driver.tool_registry.list_tools())} tools available", - metrics + metrics, ) - + await driver.shutdown() return True - + except Exception as e: - self.add_result("API", "Driver initialization", False, f"Driver error: {e}") + self.add_result( + "API", "Driver initialization", False, f"Driver error: {e}" + ) return False - + except Exception as e: self.add_result("API", "Connectivity", False, f"API error: {e}") return False - + async def validate_cache_system(self) -> bool: """Validate cache system.""" print("\nšŸ’¾ Validating Cache System...") all_passed = True - + try: config = Config() cache_config = config.cache_config - + self.add_result( - "Cache", - "Configuration", - True, + "Cache", + "Configuration", + True, "Cache configuration loaded", - cache_config + cache_config, ) - + # Test cache validator validator = CacheValidator() - self.add_result("Cache", "Validator initialization", True, "Cache validator initialized") - + self.add_result( + "Cache", "Validator initialization", True, "Cache validator initialized" + ) + # Note: Actual cache performance testing requires live system # This validates the cache system is properly configured - + return all_passed - + except Exception as e: self.add_result("Cache", "System", False, f"Cache error: {e}") return False - + def validate_security_layer(self) -> bool: """Validate security layer.""" print("\nšŸ›”ļø Validating Security Layer...") all_passed = True - + try: # Test authorization manager initialization auth_manager = AuthorizationManager() - self.add_result("Security", "Authorization manager", True, "Authorization manager initialized") - + self.add_result( + "Security", + "Authorization manager", + True, + "Authorization manager initialized", + ) + # Test input sanitizer initialization input_sanitizer = InputSanitizer() - self.add_result("Security", "Input sanitizer", True, "Input sanitizer initialized") - + self.add_result( + "Security", "Input sanitizer", True, "Input sanitizer initialized" + ) + # Test input sanitization test_inputs = [ ("Normal query", "What is TechCorp's revenue?", True), ("XSS attempt", "", False), ("SQL injection", "'; DROP TABLE users; --", False), - ("Valid financial query", "Show me Q1 2025 financial data", True) + ("Valid financial query", "Show me Q1 2025 financial data", True), ] - + for test_name, test_input, should_pass in test_inputs: try: result = input_sanitizer.sanitize_string(test_input) if should_pass: - self.add_result("Security", f"Input validation: {test_name}", True, "Input accepted") + self.add_result( + "Security", + f"Input validation: {test_name}", + True, + "Input accepted", + ) else: - self.add_result("Security", f"Input validation: {test_name}", False, "Malicious input should be blocked") + self.add_result( + "Security", + f"Input validation: {test_name}", + False, + "Malicious input should be blocked", + ) all_passed = False except Exception: if not should_pass: - self.add_result("Security", f"Input validation: {test_name}", True, "Malicious input blocked") + self.add_result( + "Security", + f"Input validation: {test_name}", + True, + "Malicious input blocked", + ) else: - self.add_result("Security", f"Input validation: {test_name}", False, "Valid input rejected") + self.add_result( + "Security", + f"Input validation: {test_name}", + False, + "Valid input rejected", + ) all_passed = False - + return all_passed - + except Exception as e: self.add_result("Security", "Layer", False, f"Security error: {e}") return False - + async def validate_benchmark_framework(self) -> bool: """Validate benchmark framework.""" print("\nšŸ“Š Validating Benchmark Framework...") all_passed = True - + try: # Test framework initialization framework = BenchmarkFramework() - self.add_result("Benchmark", "Framework initialization", True, "Framework initialized") - + self.add_result( + "Benchmark", "Framework initialization", True, "Framework initialized" + ) + # Test configuration config_details = { "iterations": framework.config.iterations, "warmup_iterations": framework.config.warmup_iterations, "concurrent_users": framework.config.concurrent_users, - "timeout_seconds": framework.config.timeout_seconds + "timeout_seconds": framework.config.timeout_seconds, } - + self.add_result( - "Benchmark", - "Configuration", - True, + "Benchmark", + "Configuration", + True, "Benchmark configuration loaded", - config_details + config_details, ) - + # Test empty benchmark suite (should handle gracefully) summary = await framework.run_benchmark_suite([]) if summary.total_queries == 0: - self.add_result("Benchmark", "Empty suite handling", True, "Empty benchmark suite handled correctly") + self.add_result( + "Benchmark", + "Empty suite handling", + True, + "Empty benchmark suite handled correctly", + ) else: - self.add_result("Benchmark", "Empty suite handling", False, "Empty benchmark suite not handled correctly") + self.add_result( + "Benchmark", + "Empty suite handling", + False, + "Empty benchmark suite not handled correctly", + ) all_passed = False - + return all_passed - + except Exception as e: self.add_result("Benchmark", "Framework", False, f"Benchmark error: {e}") return False - + def validate_monitoring_system(self) -> bool: """Validate monitoring system.""" print("\nšŸ“ˆ Validating Monitoring System...") all_passed = True - + try: # Test metrics collector metrics_collector = MetricsCollector() - self.add_result("Monitoring", "Metrics collector", True, "Metrics collector initialized") - + self.add_result( + "Monitoring", "Metrics collector", True, "Metrics collector initialized" + ) + return all_passed - + except Exception as e: self.add_result("Monitoring", "System", False, f"Monitoring error: {e}") return False - + def create_default_env_file(self): """Create a default .env file.""" - env_content = '''# FACT System Configuration + env_content = """# FACT System Configuration # Update the API keys below with your actual credentials # Required API Keys - MUST be configured for system to work @@ -374,80 +468,84 @@ def create_default_env_file(self): CACHE_HIT_TARGET_MS=30 CACHE_MISS_TARGET_MS=120 MAX_CONCURRENT_QUERIES=50 -''' - - with open('.env', 'w') as f: +""" + + with open(".env", "w") as f: f.write(env_content) print(" šŸ“„ Created .env file with default configuration") - + def print_summary(self): """Print validation summary.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("šŸŽÆ FACT SYSTEM VALIDATION SUMMARY") - print("="*60) - + print("=" * 60) + # Count results by component component_results = {} for result in self.results: if result.component not in component_results: component_results[result.component] = {"passed": 0, "failed": 0} - + if result.success: component_results[result.component]["passed"] += 1 else: component_results[result.component]["failed"] += 1 - + # Print component summary total_passed = 0 total_failed = 0 - + for component, counts in component_results.items(): passed = counts["passed"] failed = counts["failed"] total = passed + failed - + status = "āœ…" if failed == 0 else "āŒ" print(f"{status} {component}: {passed}/{total} tests passed") - + total_passed += passed total_failed += failed - + print("-" * 60) - + overall_total = total_passed + total_failed success_rate = (total_passed / overall_total * 100) if overall_total > 0 else 0 - + if total_failed == 0: - print(f"šŸŽ‰ ALL SYSTEMS OPERATIONAL: {total_passed}/{overall_total} tests passed ({success_rate:.1f}%)") + print( + f"šŸŽ‰ ALL SYSTEMS OPERATIONAL: {total_passed}/{overall_total} tests passed ({success_rate:.1f}%)" + ) print("\nāœ… Your FACT system is fully integrated and ready for use!") print("\nNext steps:") print("1. Start the CLI: python main.py cli") print("2. Run benchmarks: python scripts/run_benchmarks.py") print("3. Try sample queries in the interactive CLI") else: - print(f"āš ļø ISSUES FOUND: {total_passed}/{overall_total} tests passed ({success_rate:.1f}%)") + print( + f"āš ļø ISSUES FOUND: {total_passed}/{overall_total} tests passed ({success_rate:.1f}%)" + ) print(f"\nāŒ {total_failed} issue(s) need attention:") - + for result in self.results: if not result.success: print(f" • {result.component}: {result.test} - {result.message}") - + print("\nšŸ’” Recommendations:") if any("API key" in r.test for r in self.results if not r.success): print(" 1. Update API keys in .env file with actual credentials") print(" 2. Get Anthropic API key: https://console.anthropic.com/") print(" 3. Get Arcade API key: https://www.arcade-ai.com/") - + if any("Database" in r.component for r in self.results if not r.success): print(" 4. Run: python main.py init") - + print(" 5. Re-run validation: python scripts/validate_complete_system.py") - + async def run_all_validations(self) -> bool: """Run all validation checks.""" print("šŸš€ FACT System Complete Validation") - print("="*50) - + print("=" * 50) + validations = [ ("Environment Configuration", self.validate_environment_configuration()), ("Database Integration", self.validate_database_integration()), @@ -455,25 +553,25 @@ async def run_all_validations(self) -> bool: ("Cache System", self.validate_cache_system()), ("Security Layer", self.validate_security_layer()), ("Benchmark Framework", self.validate_benchmark_framework()), - ("Monitoring System", self.validate_monitoring_system()) + ("Monitoring System", self.validate_monitoring_system()), ] - + all_passed = True - + for name, validation in validations: try: if asyncio.iscoroutine(validation): result = await validation else: result = validation - + if not result: all_passed = False - + except Exception as e: self.add_result("System", name, False, f"Validation failed: {e}") all_passed = False - + self.print_summary() return all_passed @@ -482,29 +580,30 @@ async def main(): """Main validation routine.""" parser = argparse.ArgumentParser( description="FACT System Complete Validation", - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, ) - + parser.add_argument( - "--verbose", "-v", + "--verbose", + "-v", action="store_true", - help="Enable verbose output with detailed information" + help="Enable verbose output with detailed information", ) - + parser.add_argument( "--fix-issues", - action="store_true", - help="Attempt to automatically fix common issues" + action="store_true", + help="Attempt to automatically fix common issues", ) - + args = parser.parse_args() - + validator = SystemValidator(verbose=args.verbose, fix_issues=args.fix_issues) success = await validator.run_all_validations() - + # Exit with appropriate code sys.exit(0 if success else 1) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/scripts/validate_env.py b/scripts/validate_env.py index 90ac091..e620571 100755 --- a/scripts/validate_env.py +++ b/scripts/validate_env.py @@ -19,336 +19,401 @@ from typing import Dict, List, Tuple, Optional, Any # Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) try: from dotenv import load_dotenv except ImportError: - print("Warning: python-dotenv not installed. Install with: pip install python-dotenv") + print( + "Warning: python-dotenv not installed. Install with: pip install python-dotenv" + ) load_dotenv = None class ConfigurationValidator: """Validates FACT system configuration against requirements.""" - + # API Key validation patterns from specification API_KEY_PATTERNS = { - 'ANTHROPIC_API_KEY': r'^sk-ant-api03-[A-Za-z0-9_-]+$', - 'ARCADE_API_KEY': r'^arc_[A-Za-z0-9_-]+$', - 'OPENAI_API_KEY': r'^sk-proj-[A-Za-z0-9_-]+$' + "ANTHROPIC_API_KEY": r"^sk-ant-api03-[A-Za-z0-9_-]+$", + "ARCADE_API_KEY": r"^arc_[A-Za-z0-9_-]+$", + "OPENAI_API_KEY": r"^sk-proj-[A-Za-z0-9_-]+$", } - + # Required configuration keys - REQUIRED_KEYS = ['ANTHROPIC_API_KEY', 'ARCADE_API_KEY'] - + REQUIRED_KEYS = ["ANTHROPIC_API_KEY", "ARCADE_API_KEY"] + # Parameter validation rules PARAMETER_RULES = { - 'MAX_RETRIES': {'type': int, 'min': 1, 'max': 10, 'default': 3}, - 'REQUEST_TIMEOUT': {'type': int, 'min': 5, 'max': 300, 'default': 30}, - 'CACHE_TTL_SECONDS': {'type': int, 'min': 60, 'max': 86400, 'default': 3600}, - 'CACHE_MIN_TOKENS': {'type': int, 'min': 1, 'max': 1000, 'default': 50}, - 'VALIDATION_MAX_STRING_LENGTH': {'type': int, 'min': 1, 'max': 100000, 'default': 10000}, - 'LOG_LEVEL': {'type': str, 'choices': ['DEBUG', 'INFO', 'WARNING', 'ERROR'], 'default': 'INFO'}, - 'CLAUDE_MODEL': {'type': str, 'pattern': r'^claude-3-', 'default': 'claude-3-5-sonnet-20241022'}, - 'CACHE_PREFIX': {'type': str, 'pattern': r'^[A-Za-z0-9_]+$', 'min_length': 1, 'max_length': 50, 'default': 'fact_v1'} + "MAX_RETRIES": {"type": int, "min": 1, "max": 10, "default": 3}, + "REQUEST_TIMEOUT": {"type": int, "min": 5, "max": 300, "default": 30}, + "CACHE_TTL_SECONDS": {"type": int, "min": 60, "max": 86400, "default": 3600}, + "CACHE_MIN_TOKENS": {"type": int, "min": 1, "max": 1000, "default": 50}, + "VALIDATION_MAX_STRING_LENGTH": { + "type": int, + "min": 1, + "max": 100000, + "default": 10000, + }, + "LOG_LEVEL": { + "type": str, + "choices": ["DEBUG", "INFO", "WARNING", "ERROR"], + "default": "INFO", + }, + "CLAUDE_MODEL": { + "type": str, + "pattern": r"^claude-3-", + "default": "claude-3-5-sonnet-20241022", + }, + "CACHE_PREFIX": { + "type": str, + "pattern": r"^[A-Za-z0-9_]+$", + "min_length": 1, + "max_length": 50, + "default": "fact_v1", + }, } - + def __init__(self, verbose: bool = False): self.verbose = verbose self.errors: List[str] = [] self.warnings: List[str] = [] self.config: Dict[str, Any] = {} - + def load_environment(self) -> bool: """Load environment configuration from .env file and system environment.""" try: # Load .env file if it exists - env_path = Path('.env') + env_path = Path(".env") if env_path.exists(): if load_dotenv: load_dotenv(env_path) self.log_info(f"āœ… Loaded configuration from {env_path}") else: - self.log_warning("āš ļø python-dotenv not available, using system environment only") + self.log_warning( + "āš ļø python-dotenv not available, using system environment only" + ) else: - self.log_warning("āš ļø No .env file found, using system environment variables") - + self.log_warning( + "āš ļø No .env file found, using system environment variables" + ) + # Load all environment variables self.config = dict(os.environ) return True - + except Exception as e: self.errors.append(f"Failed to load environment: {e}") return False - + def validate_required_keys(self) -> bool: """Validate that all required API keys are present.""" missing_keys = [] invalid_keys = [] - + for key in self.REQUIRED_KEYS: - value = self.config.get(key, '').strip() - + value = self.config.get(key, "").strip() + if not value or value == f"your_{key.lower()}_here": missing_keys.append(key) continue - + # Validate format pattern = self.API_KEY_PATTERNS.get(key) if pattern and not re.match(pattern, value): invalid_keys.append(f"{key} (format: {pattern})") - + if missing_keys: - self.errors.append(f"Missing required configuration keys: {', '.join(missing_keys)}") - + self.errors.append( + f"Missing required configuration keys: {', '.join(missing_keys)}" + ) + if invalid_keys: self.errors.append(f"Invalid API key format: {', '.join(invalid_keys)}") - + return len(missing_keys) == 0 and len(invalid_keys) == 0 - + def validate_optional_keys(self) -> bool: """Validate optional API keys if present.""" invalid_keys = [] - + for key, pattern in self.API_KEY_PATTERNS.items(): if key in self.REQUIRED_KEYS: continue - - value = self.config.get(key, '').strip() + + value = self.config.get(key, "").strip() if value and value != f"your_{key.lower()}_here": if not re.match(pattern, value): invalid_keys.append(f"{key} (format: {pattern})") - + if invalid_keys: - self.errors.append(f"Invalid optional API key format: {', '.join(invalid_keys)}") - + self.errors.append( + f"Invalid optional API key format: {', '.join(invalid_keys)}" + ) + return len(invalid_keys) == 0 - + def validate_parameters(self) -> bool: """Validate configuration parameters against rules.""" parameter_errors = [] - + for param, rules in self.PARAMETER_RULES.items(): value = self.config.get(param) - + if value is None: if self.verbose: - self.log_info(f"šŸ“‹ {param}: Using default value ({rules.get('default')})") + self.log_info( + f"šŸ“‹ {param}: Using default value ({rules.get('default')})" + ) continue - + # Type validation - if rules['type'] == int: + if rules["type"] == int: try: int_value = int(value) - if 'min' in rules and int_value < rules['min']: - parameter_errors.append(f"{param}={value} below minimum {rules['min']}") - elif 'max' in rules and int_value > rules['max']: - parameter_errors.append(f"{param}={value} above maximum {rules['max']}") + if "min" in rules and int_value < rules["min"]: + parameter_errors.append( + f"{param}={value} below minimum {rules['min']}" + ) + elif "max" in rules and int_value > rules["max"]: + parameter_errors.append( + f"{param}={value} above maximum {rules['max']}" + ) except ValueError: parameter_errors.append(f"{param}={value} is not a valid integer") - - elif rules['type'] == str: - if 'choices' in rules and value not in rules['choices']: - parameter_errors.append(f"{param}={value} not in {rules['choices']}") - elif 'pattern' in rules and not re.match(rules['pattern'], value): - parameter_errors.append(f"{param}={value} doesn't match pattern {rules['pattern']}") - elif 'min_length' in rules and len(value) < rules['min_length']: - parameter_errors.append(f"{param}={value} too short (min: {rules['min_length']})") - elif 'max_length' in rules and len(value) > rules['max_length']: - parameter_errors.append(f"{param}={value} too long (max: {rules['max_length']})") - + + elif rules["type"] == str: + if "choices" in rules and value not in rules["choices"]: + parameter_errors.append( + f"{param}={value} not in {rules['choices']}" + ) + elif "pattern" in rules and not re.match(rules["pattern"], value): + parameter_errors.append( + f"{param}={value} doesn't match pattern {rules['pattern']}" + ) + elif "min_length" in rules and len(value) < rules["min_length"]: + parameter_errors.append( + f"{param}={value} too short (min: {rules['min_length']})" + ) + elif "max_length" in rules and len(value) > rules["max_length"]: + parameter_errors.append( + f"{param}={value} too long (max: {rules['max_length']})" + ) + if parameter_errors: self.errors.extend(parameter_errors) - + return len(parameter_errors) == 0 - + def validate_cache_size(self) -> bool: """Validate cache size parameter format.""" - cache_size = self.config.get('CACHE_MAX_SIZE', '').strip() - + cache_size = self.config.get("CACHE_MAX_SIZE", "").strip() + if not cache_size: if self.verbose: self.log_info("šŸ“‹ CACHE_MAX_SIZE: Using default value (100MB)") return True - + # Validate cache size format - cache_pattern = r'^(\d+)(K|M|G|T)?B$' + cache_pattern = r"^(\d+)(K|M|G|T)?B$" match = re.match(cache_pattern, cache_size.upper()) - + if not match: - self.errors.append(f"CACHE_MAX_SIZE={cache_size} invalid format (expected: [number][KMGT]B)") + self.errors.append( + f"CACHE_MAX_SIZE={cache_size} invalid format (expected: [number][KMGT]B)" + ) return False - + # Convert to bytes for range validation size_value = int(match.group(1)) - unit = match.group(2) or '' - - multipliers = {'': 1, 'K': 1024, 'M': 1024**2, 'G': 1024**3, 'T': 1024**4} + unit = match.group(2) or "" + + multipliers = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4} bytes_value = size_value * multipliers[unit] - + min_bytes = 1024**2 # 1MB max_bytes = 10 * 1024**3 # 10GB - + if bytes_value < min_bytes: self.errors.append(f"CACHE_MAX_SIZE={cache_size} below minimum 1MB") return False elif bytes_value > max_bytes: self.errors.append(f"CACHE_MAX_SIZE={cache_size} above maximum 10GB") return False - + return True - + def validate_file_paths(self) -> bool: """Validate file path configurations.""" path_errors = [] - + # Check database path - db_path = self.config.get('DATABASE_PATH', 'data/fact_demo.db') + db_path = self.config.get("DATABASE_PATH", "data/fact_demo.db") db_dir = Path(db_path).parent - + if not db_dir.exists(): try: db_dir.mkdir(parents=True, exist_ok=True) self.log_info(f"āœ… Created database directory: {db_dir}") except Exception as e: path_errors.append(f"Cannot create database directory {db_dir}: {e}") - + if path_errors: self.errors.extend(path_errors) - + return len(path_errors) == 0 - + def check_connectivity(self) -> bool: """Test connectivity to configured services (optional).""" - if not self.config.get('ANTHROPIC_API_KEY') or not self.config.get('ARCADE_API_KEY'): - self.warnings.append("āš ļø Skipping connectivity tests - API keys not configured") + if not self.config.get("ANTHROPIC_API_KEY") or not self.config.get( + "ARCADE_API_KEY" + ): + self.warnings.append( + "āš ļø Skipping connectivity tests - API keys not configured" + ) return True - + connectivity_results = [] - + # Test Anthropic API try: import requests - response = requests.get('https://api.anthropic.com', timeout=10) - connectivity_results.append(f"āœ… Anthropic API reachable (status: {response.status_code})") + + response = requests.get("https://api.anthropic.com", timeout=10) + connectivity_results.append( + f"āœ… Anthropic API reachable (status: {response.status_code})" + ) except Exception as e: connectivity_results.append(f"āŒ Anthropic API unreachable: {e}") - + # Test Arcade API - arcade_url = self.config.get('ARCADE_BASE_URL', 'https://api.arcade-ai.com') + arcade_url = self.config.get("ARCADE_BASE_URL", "https://api.arcade-ai.com") try: import requests + response = requests.get(arcade_url, timeout=10) - connectivity_results.append(f"āœ… Arcade API reachable (status: {response.status_code})") + connectivity_results.append( + f"āœ… Arcade API reachable (status: {response.status_code})" + ) except Exception as e: connectivity_results.append(f"āŒ Arcade API unreachable: {e}") - + if self.verbose: for result in connectivity_results: print(result) - + return True - + def generate_recommendations(self) -> List[str]: """Generate configuration recommendations.""" recommendations = [] - + # Check for optimal settings - if not self.config.get('STRICT_MODE') or self.config.get('STRICT_MODE').lower() != 'true': + if ( + not self.config.get("STRICT_MODE") + or self.config.get("STRICT_MODE").lower() != "true" + ): recommendations.append("Enable STRICT_MODE=true for production security") - - if not self.config.get('RATE_LIMITING_ENABLED') or self.config.get('RATE_LIMITING_ENABLED').lower() != 'true': + + if ( + not self.config.get("RATE_LIMITING_ENABLED") + or self.config.get("RATE_LIMITING_ENABLED").lower() != "true" + ): recommendations.append("Enable RATE_LIMITING_ENABLED=true for production") - - if self.config.get('DEBUG_MODE', '').lower() == 'true': + + if self.config.get("DEBUG_MODE", "").lower() == "true": recommendations.append("Disable DEBUG_MODE=false for production") - - if self.config.get('LOG_LEVEL', '').upper() == 'DEBUG': + + if self.config.get("LOG_LEVEL", "").upper() == "DEBUG": recommendations.append("Use LOG_LEVEL=INFO or higher for production") - + return recommendations - + def run_validation(self, check_connectivity: bool = False) -> bool: """Run complete validation suite.""" self.log_info("šŸ” FACT Environment Configuration Validation") self.log_info("=" * 50) - + # Step 1: Load environment if not self.load_environment(): return False - + # Step 2: Validate required keys self.log_info("\nšŸ“‹ Validating required API keys...") self.validate_required_keys() - + # Step 3: Validate optional keys self.log_info("šŸ“‹ Validating optional API keys...") self.validate_optional_keys() - + # Step 4: Validate parameters self.log_info("šŸ“‹ Validating configuration parameters...") self.validate_parameters() - + # Step 5: Validate cache size self.log_info("šŸ“‹ Validating cache configuration...") self.validate_cache_size() - + # Step 6: Validate file paths self.log_info("šŸ“‹ Validating file paths...") self.validate_file_paths() - + # Step 7: Connectivity check (optional) if check_connectivity: self.log_info("šŸ“‹ Testing service connectivity...") self.check_connectivity() - + # Step 8: Generate recommendations recommendations = self.generate_recommendations() - + # Report results self.print_results(recommendations) - + return len(self.errors) == 0 - + def print_results(self, recommendations: List[str]): """Print validation results.""" print(f"\n{'=' * 50}") print("šŸ“Š VALIDATION RESULTS") print(f"{'=' * 50}") - + if self.errors: print(f"\nāŒ ERRORS ({len(self.errors)}):") for error in self.errors: print(f" • {error}") - + if self.warnings: print(f"\nāš ļø WARNINGS ({len(self.warnings)}):") for warning in self.warnings: print(f" • {warning}") - + if recommendations: print(f"\nšŸ’” RECOMMENDATIONS ({len(recommendations)}):") for rec in recommendations: print(f" • {rec}") - + if not self.errors and not self.warnings: print("\nāœ… Configuration validation passed!") print("šŸŽ‰ Your FACT system is ready to use!") elif not self.errors: - print(f"\nāœ… Configuration validation passed with {len(self.warnings)} warnings") + print( + f"\nāœ… Configuration validation passed with {len(self.warnings)} warnings" + ) else: - print(f"\nāŒ Configuration validation failed with {len(self.errors)} errors") + print( + f"\nāŒ Configuration validation failed with {len(self.errors)} errors" + ) print("\nšŸ”§ RECOVERY SUGGESTIONS:") print(" 1. Update your .env file with correct API keys") print(" 2. Verify parameter values are within allowed ranges") print(" 3. Check file paths and permissions") print(" 4. Run: python scripts/validate_env.py --verbose") - + def log_info(self, message: str): """Log info message if verbose mode is enabled.""" if self.verbose: print(message) - + def log_warning(self, message: str): """Log warning message.""" self.warnings.append(message) @@ -370,30 +435,29 @@ def main(): Exit codes: 0 = Validation passed 1 = Validation failed - """ + """, ) - + parser.add_argument( - '--verbose', '-v', - action='store_true', - help='Enable verbose output' + "--verbose", "-v", action="store_true", help="Enable verbose output" ) - + parser.add_argument( - '--check-connectivity', '-c', - action='store_true', - help='Test connectivity to external services' + "--check-connectivity", + "-c", + action="store_true", + help="Test connectivity to external services", ) - + args = parser.parse_args() - + # Run validation validator = ConfigurationValidator(verbose=args.verbose) success = validator.run_validation(check_connectivity=args.check_connectivity) - + # Exit with appropriate code sys.exit(0 if success else 1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/__init__.py b/src/__init__.py index a19409b..2894815 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -5,4 +5,4 @@ """ __version__ = "1.0.0" -__author__ = "FACT Development Team" \ No newline at end of file +__author__ = "FACT Development Team" diff --git a/src/arcade/__init__.py b/src/arcade/__init__.py index c7d200f..95c1884 100644 --- a/src/arcade/__init__.py +++ b/src/arcade/__init__.py @@ -17,9 +17,9 @@ from arcade.errors import ArcadeError, ArcadeConnectionError, ArcadeExecutionError __all__ = [ - 'ArcadeClient', - 'ArcadeGateway', - 'ArcadeError', - 'ArcadeConnectionError', - 'ArcadeExecutionError' -] \ No newline at end of file + "ArcadeClient", + "ArcadeGateway", + "ArcadeError", + "ArcadeConnectionError", + "ArcadeExecutionError", +] diff --git a/src/arcade/client.py b/src/arcade/client.py index 677ec5d..a7bf426 100644 --- a/src/arcade/client.py +++ b/src/arcade/client.py @@ -14,6 +14,7 @@ # Import arcadepy when available try: import arcade + ARCADE_AVAILABLE = True except ImportError: ARCADE_AVAILABLE = False @@ -27,25 +28,26 @@ ArcadeExecutionError, ArcadeTimeoutError, ArcadeRegistrationError, - ArcadeSerializationError + ArcadeSerializationError, ) from ..core.errors import ToolExecutionError except ImportError: # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from arcade.errors import ( ArcadeConnectionError, ArcadeAuthenticationError, ArcadeExecutionError, ArcadeTimeoutError, ArcadeRegistrationError, - ArcadeSerializationError + ArcadeSerializationError, ) from core.errors import ToolExecutionError @@ -56,19 +58,21 @@ class ArcadeClient: """ Client for interacting with Arcade.dev tool hosting platform. - + Provides secure tool execution, registration, and management through the Arcade.dev gateway infrastructure. """ - - def __init__(self, - api_key: Optional[str] = None, - base_url: Optional[str] = None, - timeout: int = 30, - max_retries: int = 3): + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + timeout: int = 30, + max_retries: int = 3, + ): """ Initialize Arcade client. - + Args: api_key: Arcade.dev API key base_url: Base URL for Arcade.dev API @@ -76,25 +80,25 @@ def __init__(self, max_retries: Maximum retry attempts """ if not ARCADE_AVAILABLE: - raise ImportError("arcadepy library not available. Install with: pip install arcadepy") - + raise ImportError( + "arcadepy library not available. Install with: pip install arcadepy" + ) + self.api_key = api_key self.base_url = base_url or "https://api.arcade.dev" self.timeout = timeout self.max_retries = max_retries - + # Initialize arcade client self._client = None self._connected = False - - logger.info("ArcadeClient initialized", - base_url=self.base_url, - timeout=timeout) - + + logger.info("ArcadeClient initialized", base_url=self.base_url, timeout=timeout) + async def connect(self) -> None: """ Establish connection to Arcade.dev platform. - + Raises: ArcadeConnectionError: If connection fails ArcadeAuthenticationError: If authentication fails @@ -103,278 +107,290 @@ async def connect(self) -> None: # Initialize arcade client with API key if self.api_key: self._client = arcade.Client( - api_key=self.api_key, - base_url=self.base_url, - timeout=self.timeout + api_key=self.api_key, base_url=self.base_url, timeout=self.timeout ) else: self._client = arcade.Client( - base_url=self.base_url, - timeout=self.timeout + base_url=self.base_url, timeout=self.timeout ) - + # Test connection with a simple API call await self._test_connection() - + self._connected = True logger.info("Connected to Arcade.dev successfully") - + except Exception as e: logger.error("Failed to connect to Arcade.dev", error=str(e)) if "authentication" in str(e).lower() or "api key" in str(e).lower(): raise ArcadeAuthenticationError(f"Authentication failed: {str(e)}") else: raise ArcadeConnectionError(f"Connection failed: {str(e)}") - - async def execute_tool(self, - tool_name: str, - arguments: Dict[str, Any], - timeout: Optional[int] = None, - user_id: Optional[str] = None) -> Dict[str, Any]: + + async def execute_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + timeout: Optional[int] = None, + user_id: Optional[str] = None, + ) -> Dict[str, Any]: """ Execute a tool on Arcade.dev platform. - + Args: tool_name: Name of the tool to execute arguments: Tool arguments timeout: Execution timeout in seconds user_id: Optional user identifier for logging - + Returns: Tool execution result - + Raises: ArcadeExecutionError: If execution fails ArcadeTimeoutError: If execution times out """ if not self._connected or not self._client: await self.connect() - + execution_timeout = timeout or self.timeout start_time = time.time() - + try: - logger.info("Executing tool on Arcade.dev", - tool_name=tool_name, - user_id=user_id, - timeout=execution_timeout) - + logger.info( + "Executing tool on Arcade.dev", + tool_name=tool_name, + user_id=user_id, + timeout=execution_timeout, + ) + # Prepare execution request execution_request = { "tool": tool_name, "arguments": arguments, - "timeout": execution_timeout + "timeout": execution_timeout, } - + if user_id: execution_request["user_id"] = user_id - + # Execute tool with timeout result = await asyncio.wait_for( - self._execute_with_retry(execution_request), - timeout=execution_timeout + self._execute_with_retry(execution_request), timeout=execution_timeout ) - + execution_time = (time.time() - start_time) * 1000 - - logger.info("Tool executed successfully on Arcade.dev", - tool_name=tool_name, - execution_time_ms=execution_time, - user_id=user_id) - + + logger.info( + "Tool executed successfully on Arcade.dev", + tool_name=tool_name, + execution_time_ms=execution_time, + user_id=user_id, + ) + return self._process_execution_result(result, execution_time) - + except asyncio.TimeoutError: execution_time = (time.time() - start_time) * 1000 - logger.error("Tool execution timed out", - tool_name=tool_name, - timeout=execution_timeout, - execution_time_ms=execution_time) - raise ArcadeTimeoutError(f"Tool execution timed out after {execution_timeout} seconds") - + logger.error( + "Tool execution timed out", + tool_name=tool_name, + timeout=execution_timeout, + execution_time_ms=execution_time, + ) + raise ArcadeTimeoutError( + f"Tool execution timed out after {execution_timeout} seconds" + ) + except Exception as e: execution_time = (time.time() - start_time) * 1000 - logger.error("Tool execution failed", - tool_name=tool_name, - error=str(e), - execution_time_ms=execution_time) + logger.error( + "Tool execution failed", + tool_name=tool_name, + error=str(e), + execution_time_ms=execution_time, + ) raise ArcadeExecutionError(f"Tool execution failed: {str(e)}") - - async def register_tool(self, - tool_definition: Dict[str, Any], - source_code: Optional[str] = None) -> Dict[str, Any]: + + async def register_tool( + self, tool_definition: Dict[str, Any], source_code: Optional[str] = None + ) -> Dict[str, Any]: """ Register a tool with Arcade.dev platform. - + Args: tool_definition: Tool definition including name, description, parameters source_code: Optional source code for the tool - + Returns: Registration result - + Raises: ArcadeRegistrationError: If registration fails """ if not self._connected or not self._client: await self.connect() - + try: - logger.info("Registering tool with Arcade.dev", - tool_name=tool_definition.get("name")) - + logger.info( + "Registering tool with Arcade.dev", + tool_name=tool_definition.get("name"), + ) + # Prepare registration request registration_request = { "name": tool_definition["name"], "description": tool_definition["description"], - "parameters": tool_definition["parameters"] + "parameters": tool_definition["parameters"], } - + if source_code: registration_request["source_code"] = source_code - + # Add metadata registration_request["metadata"] = { "version": tool_definition.get("version", "1.0.0"), "requires_auth": tool_definition.get("requires_auth", False), "timeout_seconds": tool_definition.get("timeout_seconds", 30), - "created_at": time.time() + "created_at": time.time(), } - + # Register tool result = await self._client.tools.register(registration_request) - - logger.info("Tool registered successfully", - tool_name=tool_definition["name"], - tool_id=result.get("id")) - + + logger.info( + "Tool registered successfully", + tool_name=tool_definition["name"], + tool_id=result.get("id"), + ) + return result - + except Exception as e: - logger.error("Tool registration failed", - tool_name=tool_definition.get("name"), - error=str(e)) + logger.error( + "Tool registration failed", + tool_name=tool_definition.get("name"), + error=str(e), + ) raise ArcadeRegistrationError(f"Tool registration failed: {str(e)}") - + async def list_tools(self) -> List[Dict[str, Any]]: """ List all registered tools on Arcade.dev platform. - + Returns: List of registered tools - + Raises: ArcadeConnectionError: If request fails """ if not self._connected or not self._client: await self.connect() - + try: result = await self._client.tools.list() - - logger.debug("Retrieved tool list from Arcade.dev", - tool_count=len(result.get("tools", []))) - + + logger.debug( + "Retrieved tool list from Arcade.dev", + tool_count=len(result.get("tools", [])), + ) + return result.get("tools", []) - + except Exception as e: logger.error("Failed to list tools", error=str(e)) raise ArcadeConnectionError(f"Failed to list tools: {str(e)}") - + async def get_tool_info(self, tool_name: str) -> Dict[str, Any]: """ Get detailed information about a specific tool. - + Args: tool_name: Name of the tool - + Returns: Tool information - + Raises: ArcadeConnectionError: If request fails """ if not self._connected or not self._client: await self.connect() - + try: result = await self._client.tools.get(tool_name) - - logger.debug("Retrieved tool info from Arcade.dev", - tool_name=tool_name) - + + logger.debug("Retrieved tool info from Arcade.dev", tool_name=tool_name) + return result - + except Exception as e: - logger.error("Failed to get tool info", - tool_name=tool_name, - error=str(e)) + logger.error("Failed to get tool info", tool_name=tool_name, error=str(e)) raise ArcadeConnectionError(f"Failed to get tool info: {str(e)}") - + async def delete_tool(self, tool_name: str) -> bool: """ Delete a tool from Arcade.dev platform. - + Args: tool_name: Name of the tool to delete - + Returns: True if deletion was successful - + Raises: ArcadeConnectionError: If request fails """ if not self._connected or not self._client: await self.connect() - + try: await self._client.tools.delete(tool_name) - - logger.info("Tool deleted successfully", - tool_name=tool_name) - + + logger.info("Tool deleted successfully", tool_name=tool_name) + return True - + except Exception as e: - logger.error("Failed to delete tool", - tool_name=tool_name, - error=str(e)) + logger.error("Failed to delete tool", tool_name=tool_name, error=str(e)) raise ArcadeConnectionError(f"Failed to delete tool: {str(e)}") - - async def get_execution_logs(self, - execution_id: str, - limit: int = 100) -> List[Dict[str, Any]]: + + async def get_execution_logs( + self, execution_id: str, limit: int = 100 + ) -> List[Dict[str, Any]]: """ Get execution logs for a specific execution. - + Args: execution_id: Execution identifier limit: Maximum number of log entries - + Returns: List of log entries - + Raises: ArcadeConnectionError: If request fails """ if not self._connected or not self._client: await self.connect() - + try: result = await self._client.executions.get_logs(execution_id, limit=limit) - - logger.debug("Retrieved execution logs", - execution_id=execution_id, - log_count=len(result.get("logs", []))) - + + logger.debug( + "Retrieved execution logs", + execution_id=execution_id, + log_count=len(result.get("logs", [])), + ) + return result.get("logs", []) - + except Exception as e: - logger.error("Failed to get execution logs", - execution_id=execution_id, - error=str(e)) + logger.error( + "Failed to get execution logs", execution_id=execution_id, error=str(e) + ) raise ArcadeConnectionError(f"Failed to get execution logs: {str(e)}") - + async def close(self) -> None: """Close the Arcade client connection.""" if self._client: @@ -386,12 +402,12 @@ async def close(self) -> None: self._client = None self._connected = False logger.info("Arcade client connection closed") - + async def _test_connection(self) -> None: """Test connection to Arcade.dev platform.""" try: # Simple health check or authentication test - if hasattr(self._client, 'health'): + if hasattr(self._client, "health"): await self._client.health.check() else: # Fallback to listing tools as a connection test @@ -399,68 +415,80 @@ async def _test_connection(self) -> None: except Exception as e: logger.error("Connection test failed", error=str(e)) raise - - async def _execute_with_retry(self, execution_request: Dict[str, Any]) -> Dict[str, Any]: + + async def _execute_with_retry( + self, execution_request: Dict[str, Any] + ) -> Dict[str, Any]: """Execute tool with retry logic.""" last_error = None - + for attempt in range(self.max_retries + 1): try: return await self._client.tools.execute(execution_request) except Exception as e: last_error = e if attempt < self.max_retries: - wait_time = 2 ** attempt # Exponential backoff - logger.warning("Tool execution attempt failed, retrying", - attempt=attempt + 1, - max_retries=self.max_retries, - wait_time=wait_time, - error=str(e)) + wait_time = 2**attempt # Exponential backoff + logger.warning( + "Tool execution attempt failed, retrying", + attempt=attempt + 1, + max_retries=self.max_retries, + wait_time=wait_time, + error=str(e), + ) await asyncio.sleep(wait_time) else: - logger.error("All retry attempts exhausted", - attempts=self.max_retries + 1, - error=str(e)) - + logger.error( + "All retry attempts exhausted", + attempts=self.max_retries + 1, + error=str(e), + ) + raise last_error - - def _process_execution_result(self, result: Dict[str, Any], execution_time: float) -> Dict[str, Any]: + + def _process_execution_result( + self, result: Dict[str, Any], execution_time: float + ) -> Dict[str, Any]: """Process and validate execution result.""" try: # Ensure result has required fields processed_result = { "success": result.get("success", True), - "execution_time_ms": execution_time + "execution_time_ms": execution_time, } - + if result.get("success", True): processed_result["data"] = result.get("data", result) else: processed_result["error"] = result.get("error", "Unknown error") processed_result["success"] = False - + # Add metadata if available if "metadata" in result: processed_result["metadata"] = result["metadata"] - + return processed_result - + except Exception as e: logger.error("Failed to process execution result", error=str(e)) - raise ArcadeSerializationError(f"Failed to process execution result: {str(e)}") + raise ArcadeSerializationError( + f"Failed to process execution result: {str(e)}" + ) # Utility functions for Arcade integration -def create_arcade_client(api_key: Optional[str] = None, - base_url: Optional[str] = None) -> ArcadeClient: + +def create_arcade_client( + api_key: Optional[str] = None, base_url: Optional[str] = None +) -> ArcadeClient: """ Create and configure an Arcade client. - + Args: api_key: Arcade.dev API key base_url: Base URL for Arcade.dev API - + Returns: Configured ArcadeClient instance """ @@ -470,10 +498,10 @@ def create_arcade_client(api_key: Optional[str] = None, async def test_arcade_connection(client: ArcadeClient) -> bool: """ Test connection to Arcade.dev platform. - + Args: client: ArcadeClient instance - + Returns: True if connection is successful """ @@ -482,4 +510,4 @@ async def test_arcade_connection(client: ArcadeClient) -> bool: return True except Exception as e: logger.error("Arcade connection test failed", error=str(e)) - return False \ No newline at end of file + return False diff --git a/src/arcade/errors.py b/src/arcade/errors.py index 4639c0f..ae9ffee 100644 --- a/src/arcade/errors.py +++ b/src/arcade/errors.py @@ -6,6 +6,7 @@ """ from typing import Optional, Dict, Any + try: # Try relative imports first (when used as package) from ..core.errors import FACTError @@ -13,44 +14,52 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from core.errors import FACTError class ArcadeError(FACTError): """Base exception class for all Arcade.dev integration errors.""" + pass class ArcadeConnectionError(ArcadeError): """Raised when connection to Arcade.dev fails.""" + pass class ArcadeAuthenticationError(ArcadeError): """Raised when Arcade.dev authentication fails.""" + pass class ArcadeExecutionError(ArcadeError): """Raised when tool execution on Arcade.dev fails.""" + pass class ArcadeTimeoutError(ArcadeError): """Raised when Arcade.dev operations timeout.""" + pass class ArcadeRegistrationError(ArcadeError): """Raised when tool registration with Arcade.dev fails.""" + pass class ArcadeSerializationError(ArcadeError): """Raised when request/response serialization fails.""" - pass \ No newline at end of file + + pass diff --git a/src/arcade/gateway.py b/src/arcade/gateway.py index f3890e9..bda72d5 100644 --- a/src/arcade/gateway.py +++ b/src/arcade/gateway.py @@ -18,11 +18,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from arcade.client import ArcadeClient from arcade.errors import ArcadeError, ArcadeExecutionError from core.errors import ToolExecutionError @@ -34,18 +35,20 @@ class ArcadeGateway: """ Gateway for routing tool execution between local and Arcade.dev execution. - + Provides intelligent routing, fallback mechanisms, and execution orchestration for tools that can run locally or on Arcade.dev. """ - - def __init__(self, - arcade_client: Optional[ArcadeClient] = None, - enable_fallback: bool = True, - prefer_arcade: bool = False): + + def __init__( + self, + arcade_client: Optional[ArcadeClient] = None, + enable_fallback: bool = True, + prefer_arcade: bool = False, + ): """ Initialize Arcade gateway. - + Args: arcade_client: Optional Arcade client for remote execution enable_fallback: Whether to fallback to local execution on Arcade failure @@ -54,199 +57,216 @@ def __init__(self, self.arcade_client = arcade_client self.enable_fallback = enable_fallback self.prefer_arcade = prefer_arcade - - logger.info("ArcadeGateway initialized", - has_arcade_client=bool(arcade_client), - enable_fallback=enable_fallback, - prefer_arcade=prefer_arcade) - - async def execute_tool(self, - tool_name: str, - arguments: Dict[str, Any], - local_function: Optional[callable] = None, - user_id: Optional[str] = None, - timeout: Optional[int] = None) -> Dict[str, Any]: + + logger.info( + "ArcadeGateway initialized", + has_arcade_client=bool(arcade_client), + enable_fallback=enable_fallback, + prefer_arcade=prefer_arcade, + ) + + async def execute_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + local_function: Optional[callable] = None, + user_id: Optional[str] = None, + timeout: Optional[int] = None, + ) -> Dict[str, Any]: """ Execute a tool using the best available method. - + Args: tool_name: Name of the tool to execute arguments: Tool arguments local_function: Local function to execute if available user_id: Optional user identifier timeout: Execution timeout - + Returns: Tool execution result - + Raises: ToolExecutionError: If execution fails on all available methods """ execution_methods = self._determine_execution_order(tool_name, local_function) - + last_error = None - + for method in execution_methods: try: if method == "arcade": - logger.debug("Attempting Arcade execution", - tool_name=tool_name, - user_id=user_id) - + logger.debug( + "Attempting Arcade execution", + tool_name=tool_name, + user_id=user_id, + ) + result = await self._execute_via_arcade( tool_name, arguments, user_id, timeout ) - - logger.info("Tool executed successfully via Arcade", - tool_name=tool_name, - user_id=user_id) - + + logger.info( + "Tool executed successfully via Arcade", + tool_name=tool_name, + user_id=user_id, + ) + return result - + elif method == "local": - logger.debug("Attempting local execution", - tool_name=tool_name, - user_id=user_id) - + logger.debug( + "Attempting local execution", + tool_name=tool_name, + user_id=user_id, + ) + result = await self._execute_locally( tool_name, local_function, arguments ) - - logger.info("Tool executed successfully locally", - tool_name=tool_name, - user_id=user_id) - + + logger.info( + "Tool executed successfully locally", + tool_name=tool_name, + user_id=user_id, + ) + return result - + except Exception as e: last_error = e - logger.warning("Execution method failed", - method=method, - tool_name=tool_name, - error=str(e)) - + logger.warning( + "Execution method failed", + method=method, + tool_name=tool_name, + error=str(e), + ) + # If this is the last method and fallback is disabled, raise immediately if not self.enable_fallback and method == execution_methods[0]: raise - + # All methods failed error_msg = f"Tool execution failed on all available methods. Last error: {str(last_error)}" - logger.error("All execution methods failed", - tool_name=tool_name, - methods_tried=execution_methods, - last_error=str(last_error)) - + logger.error( + "All execution methods failed", + tool_name=tool_name, + methods_tried=execution_methods, + last_error=str(last_error), + ) + raise ToolExecutionError(error_msg) - - def _determine_execution_order(self, - tool_name: str, - local_function: Optional[callable]) -> list: + + def _determine_execution_order( + self, tool_name: str, local_function: Optional[callable] + ) -> list: """ Determine the order of execution methods to try. - + Args: tool_name: Name of the tool local_function: Local function if available - + Returns: List of execution methods in order of preference """ methods = [] - + # If we prefer Arcade and have a client, try Arcade first if self.prefer_arcade and self.arcade_client: methods.append("arcade") - + # Add local as fallback if available and fallback is enabled if local_function and self.enable_fallback: methods.append("local") - + # If we prefer local or don't have Arcade client else: # Try local first if available if local_function: methods.append("local") - + # Add Arcade as fallback if available and fallback is enabled if self.arcade_client and self.enable_fallback: methods.append("arcade") - + # If no methods are available, at least try what we have if not methods: if self.arcade_client: methods.append("arcade") elif local_function: methods.append("local") - + return methods - - async def _execute_via_arcade(self, - tool_name: str, - arguments: Dict[str, Any], - user_id: Optional[str], - timeout: Optional[int]) -> Dict[str, Any]: + + async def _execute_via_arcade( + self, + tool_name: str, + arguments: Dict[str, Any], + user_id: Optional[str], + timeout: Optional[int], + ) -> Dict[str, Any]: """ Execute tool via Arcade.dev platform. - + Args: tool_name: Name of the tool arguments: Tool arguments user_id: Optional user identifier timeout: Execution timeout - + Returns: Arcade execution result - + Raises: ArcadeExecutionError: If Arcade execution fails """ if not self.arcade_client: raise ArcadeExecutionError("No Arcade client available") - + try: result = await self.arcade_client.execute_tool( tool_name=tool_name, arguments=arguments, timeout=timeout, - user_id=user_id + user_id=user_id, ) - + # Ensure result has expected structure if not isinstance(result, dict): raise ArcadeExecutionError("Invalid result format from Arcade") - + # Add execution metadata result["execution_method"] = "arcade" result["platform"] = "arcade.dev" - + return result - + except ArcadeError: raise except Exception as e: raise ArcadeExecutionError(f"Arcade execution failed: {str(e)}") - - async def _execute_locally(self, - tool_name: str, - local_function: callable, - arguments: Dict[str, Any]) -> Dict[str, Any]: + + async def _execute_locally( + self, tool_name: str, local_function: callable, arguments: Dict[str, Any] + ) -> Dict[str, Any]: """ Execute tool locally using provided function. - + Args: tool_name: Name of the tool local_function: Local function to execute arguments: Tool arguments - + Returns: Local execution result - + Raises: ToolExecutionError: If local execution fails """ if not local_function: raise ToolExecutionError("No local function available") - + try: # Execute function (handle both sync and async functions) if asyncio.iscoroutinefunction(local_function): @@ -254,25 +274,27 @@ async def _execute_locally(self, else: # Run sync function in thread pool to avoid blocking loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, lambda: local_function(**arguments)) - + result = await loop.run_in_executor( + None, lambda: local_function(**arguments) + ) + # Ensure result is a dictionary if not isinstance(result, dict): result = {"result": result} - + # Add execution metadata result["execution_method"] = "local" result["platform"] = "local" - + return result - + except Exception as e: raise ToolExecutionError(f"Local execution failed: {str(e)}") - + async def health_check(self) -> Dict[str, Any]: """ Perform health check on available execution methods. - + Returns: Health check results """ @@ -280,9 +302,9 @@ async def health_check(self) -> Dict[str, Any]: "gateway_healthy": True, "arcade_available": False, "local_available": True, # Local execution is always available - "methods": [] + "methods": [], } - + # Check Arcade availability if self.arcade_client: try: @@ -290,25 +312,25 @@ async def health_check(self) -> Dict[str, Any]: await self.arcade_client.connect() health_status["arcade_available"] = True health_status["methods"].append("arcade") - + logger.debug("Arcade health check passed") - + except Exception as e: logger.warning("Arcade health check failed", error=str(e)) health_status["arcade_error"] = str(e) - + # Local execution is always available health_status["methods"].append("local") - + # Overall health health_status["gateway_healthy"] = len(health_status["methods"]) > 0 - + return health_status - + def get_execution_stats(self) -> Dict[str, Any]: """ Get execution statistics and configuration. - + Returns: Execution statistics """ @@ -316,26 +338,28 @@ def get_execution_stats(self) -> Dict[str, Any]: "arcade_client_configured": bool(self.arcade_client), "fallback_enabled": self.enable_fallback, "prefer_arcade": self.prefer_arcade, - "available_methods": ["arcade"] if self.arcade_client else [] + ["local"] + "available_methods": ["arcade"] if self.arcade_client else [] + ["local"], } -def create_arcade_gateway(arcade_client: Optional[ArcadeClient] = None, - enable_fallback: bool = True, - prefer_arcade: bool = False) -> ArcadeGateway: +def create_arcade_gateway( + arcade_client: Optional[ArcadeClient] = None, + enable_fallback: bool = True, + prefer_arcade: bool = False, +) -> ArcadeGateway: """ Create and configure an Arcade gateway. - + Args: arcade_client: Optional Arcade client enable_fallback: Whether to enable fallback between methods prefer_arcade: Whether to prefer Arcade over local execution - + Returns: Configured ArcadeGateway instance """ return ArcadeGateway( arcade_client=arcade_client, enable_fallback=enable_fallback, - prefer_arcade=prefer_arcade - ) \ No newline at end of file + prefer_arcade=prefer_arcade, + ) diff --git a/src/benchmarking/__init__.py b/src/benchmarking/__init__.py index 8146a0d..77152bf 100644 --- a/src/benchmarking/__init__.py +++ b/src/benchmarking/__init__.py @@ -10,7 +10,7 @@ BenchmarkRunner, BenchmarkConfig, BenchmarkResult, - BenchmarkSummary + BenchmarkSummary, ) from .comparisons import ( @@ -18,7 +18,7 @@ PerformanceComparison, ComparisonResult, ComparisonMetrics, - SystemType + SystemType, ) from .profiler import ( @@ -26,7 +26,7 @@ BottleneckAnalyzer, ProfileResult, BottleneckAnalysis, - ProfilePoint + ProfilePoint, ) from .monitoring import ( @@ -34,7 +34,7 @@ PerformanceTracker, MonitoringConfig, PerformanceAlert, - PerformanceTrend + PerformanceTrend, ) from .visualization import ( @@ -42,47 +42,43 @@ ReportGenerator, BenchmarkReport, ChartData, - ReportSection + ReportSection, ) __all__ = [ # Framework "BenchmarkFramework", - "BenchmarkRunner", + "BenchmarkRunner", "BenchmarkConfig", "BenchmarkResult", "BenchmarkSummary", - # Comparisons "RAGComparison", - "PerformanceComparison", + "PerformanceComparison", "ComparisonResult", "ComparisonMetrics", "SystemType", - # Profiling "SystemProfiler", "BottleneckAnalyzer", "ProfileResult", "BottleneckAnalysis", "ProfilePoint", - # Monitoring "ContinuousMonitor", "PerformanceTracker", - "MonitoringConfig", + "MonitoringConfig", "PerformanceAlert", "PerformanceTrend", - # Visualization "BenchmarkVisualizer", "ReportGenerator", "BenchmarkReport", "ChartData", - "ReportSection" + "ReportSection", ] # Version info __version__ = "1.0.0" __author__ = "FACT Team" -__description__ = "Comprehensive benchmarking system for FACT performance validation" \ No newline at end of file +__description__ = "Comprehensive benchmarking system for FACT performance validation" diff --git a/src/benchmarking/comparisons.py b/src/benchmarking/comparisons.py index 8281315..5db1189 100644 --- a/src/benchmarking/comparisons.py +++ b/src/benchmarking/comparisons.py @@ -21,12 +21,17 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - - from benchmarking.framework import BenchmarkResult, BenchmarkSummary, BenchmarkFramework + + from benchmarking.framework import ( + BenchmarkResult, + BenchmarkSummary, + BenchmarkFramework, + ) from cache.manager import CacheManager logger = structlog.get_logger(__name__) @@ -34,6 +39,7 @@ class SystemType(Enum): """System types for comparison.""" + FACT = "fact" TRADITIONAL_RAG = "traditional_rag" HYBRID = "hybrid" @@ -42,6 +48,7 @@ class SystemType(Enum): @dataclass class ComparisonMetrics: """Metrics for system comparison.""" + system_type: SystemType avg_latency_ms: float p95_latency_ms: float @@ -58,6 +65,7 @@ class ComparisonMetrics: @dataclass class ComparisonResult: """Result of system comparison.""" + fact_metrics: ComparisonMetrics rag_metrics: ComparisonMetrics improvement_factors: Dict[str, float] @@ -70,127 +78,146 @@ class ComparisonResult: class RAGComparison: """ Comprehensive comparison between FACT and traditional RAG systems. - + Simulates traditional RAG behavior and measures performance differences across latency, cost, and accuracy dimensions. """ - + def __init__(self, benchmark_framework: Optional[BenchmarkFramework] = None): """ Initialize RAG comparison. - + Args: benchmark_framework: Benchmark framework for measurements """ self.framework = benchmark_framework or BenchmarkFramework() - + # Traditional RAG simulation parameters self.rag_base_latency_ms = 400 # Typical RAG latency self.rag_vector_search_ms = 50 # Vector database search time self.rag_llm_processing_ms = 300 # LLM processing time self.rag_context_preparation_ms = 50 # Context preparation time - + # Token usage patterns self.rag_avg_input_tokens = 1500 # Large context with retrieved docs self.rag_avg_output_tokens = 300 # Detailed responses self.fact_hit_input_tokens = 100 # Minimal prompt for cache - self.fact_hit_output_tokens = 50 # Cached response + self.fact_hit_output_tokens = 50 # Cached response self.fact_miss_input_tokens = 200 # Tool-enhanced prompt - self.fact_miss_output_tokens = 150 # Response with tool results - + self.fact_miss_output_tokens = 150 # Response with tool results + logger.info("RAG comparison initialized") - - async def run_comparison_benchmark(self, - queries: List[str], - cache_manager: Optional[CacheManager] = None, - iterations: int = 10) -> ComparisonResult: + + async def run_comparison_benchmark( + self, + queries: List[str], + cache_manager: Optional[CacheManager] = None, + iterations: int = 10, + ) -> ComparisonResult: """ Run comprehensive comparison benchmark. - + Args: queries: Test queries for comparison cache_manager: Cache manager for FACT system iterations: Number of iterations per query - + Returns: Comparison results """ - logger.info("Starting RAG comparison benchmark", - queries_count=len(queries), - iterations=iterations) - + logger.info( + "Starting RAG comparison benchmark", + queries_count=len(queries), + iterations=iterations, + ) + # Run FACT benchmarks fact_summary = await self.framework.run_benchmark_suite(queries, cache_manager) fact_metrics = self._extract_comparison_metrics(fact_summary, SystemType.FACT) - + # Simulate traditional RAG performance rag_results = await self._simulate_rag_performance(queries, iterations) - rag_summary = self.framework._generate_summary(rag_results, fact_summary.execution_time_seconds) - rag_metrics = self._extract_comparison_metrics(rag_summary, SystemType.TRADITIONAL_RAG) - + rag_summary = self.framework._generate_summary( + rag_results, fact_summary.execution_time_seconds + ) + rag_metrics = self._extract_comparison_metrics( + rag_summary, SystemType.TRADITIONAL_RAG + ) + # Calculate improvement factors - improvement_factors = self._calculate_improvement_factors(fact_metrics, rag_metrics) - + improvement_factors = self._calculate_improvement_factors( + fact_metrics, rag_metrics + ) + # Calculate cost savings cost_savings = self._calculate_cost_savings(fact_metrics, rag_metrics) - + # Generate performance analysis - performance_analysis = self._analyze_performance_differences(fact_metrics, rag_metrics) - + performance_analysis = self._analyze_performance_differences( + fact_metrics, rag_metrics + ) + # Generate recommendation - recommendation = self._generate_recommendation(improvement_factors, cost_savings) - + recommendation = self._generate_recommendation( + improvement_factors, cost_savings + ) + result = ComparisonResult( fact_metrics=fact_metrics, rag_metrics=rag_metrics, improvement_factors=improvement_factors, cost_savings=cost_savings, performance_analysis=performance_analysis, - recommendation=recommendation + recommendation=recommendation, ) - - logger.info("RAG comparison completed", - latency_improvement=improvement_factors.get("latency", 0), - cost_reduction=cost_savings.get("percentage", 0)) - + + logger.info( + "RAG comparison completed", + latency_improvement=improvement_factors.get("latency", 0), + cost_reduction=cost_savings.get("percentage", 0), + ) + return result - - async def _simulate_rag_performance(self, - queries: List[str], - iterations: int) -> List[BenchmarkResult]: + + async def _simulate_rag_performance( + self, queries: List[str], iterations: int + ) -> List[BenchmarkResult]: """ Simulate traditional RAG system performance. - + Args: queries: Test queries iterations: Number of iterations - + Returns: Simulated RAG benchmark results """ results = [] - + for iteration in range(iterations): for query in queries: # Simulate RAG processing time base_latency = self.rag_base_latency_ms - + # Add variability based on query complexity - complexity_factor = min(len(query) / 100, 2.0) # Up to 2x for complex queries + complexity_factor = min( + len(query) / 100, 2.0 + ) # Up to 2x for complex queries variable_latency = base_latency * (0.8 + complexity_factor * 0.4) - + # Add random variation (±20%) import random + variation = random.uniform(0.8, 1.2) total_latency = variable_latency * variation - + # Simulate processing delay await asyncio.sleep(total_latency / 1000) # Convert to seconds - + # Calculate token costs for RAG token_count = self.rag_avg_input_tokens + self.rag_avg_output_tokens token_cost = self._calculate_rag_token_cost(token_count) - + result = BenchmarkResult( query=query, response_time_ms=total_latency, @@ -198,16 +225,16 @@ async def _simulate_rag_performance(self, cache_hit=False, # Traditional RAG doesn't use cache token_count=token_count, token_cost=token_cost, - metadata={"system_type": "traditional_rag"} + metadata={"system_type": "traditional_rag"}, ) - + results.append(result) - + return results - - def _extract_comparison_metrics(self, - summary: BenchmarkSummary, - system_type: SystemType) -> ComparisonMetrics: + + def _extract_comparison_metrics( + self, summary: BenchmarkSummary, system_type: SystemType + ) -> ComparisonMetrics: """Extract comparison metrics from benchmark summary.""" return ComparisonMetrics( system_type=system_type, @@ -220,45 +247,55 @@ def _extract_comparison_metrics(self, error_rate=summary.error_rate, throughput_qps=summary.throughput_qps, memory_usage_mb=0.0, # TODO: Implement memory tracking - cpu_usage_percent=0.0 # TODO: Implement CPU tracking + cpu_usage_percent=0.0, # TODO: Implement CPU tracking ) - - def _calculate_improvement_factors(self, - fact_metrics: ComparisonMetrics, - rag_metrics: ComparisonMetrics) -> Dict[str, float]: + + def _calculate_improvement_factors( + self, fact_metrics: ComparisonMetrics, rag_metrics: ComparisonMetrics + ) -> Dict[str, float]: """Calculate improvement factors for FACT vs RAG.""" improvements = {} - + # Latency improvements if rag_metrics.avg_latency_ms > 0: - improvements["latency"] = rag_metrics.avg_latency_ms / fact_metrics.avg_latency_ms - + improvements["latency"] = ( + rag_metrics.avg_latency_ms / fact_metrics.avg_latency_ms + ) + if rag_metrics.p95_latency_ms > 0: - improvements["p95_latency"] = rag_metrics.p95_latency_ms / fact_metrics.p95_latency_ms - + improvements["p95_latency"] = ( + rag_metrics.p95_latency_ms / fact_metrics.p95_latency_ms + ) + # Throughput improvements if fact_metrics.throughput_qps > 0 and rag_metrics.throughput_qps > 0: - improvements["throughput"] = fact_metrics.throughput_qps / rag_metrics.throughput_qps - + improvements["throughput"] = ( + fact_metrics.throughput_qps / rag_metrics.throughput_qps + ) + # Cost improvements if rag_metrics.avg_token_cost > 0: - improvements["cost_efficiency"] = rag_metrics.avg_token_cost / fact_metrics.avg_token_cost - + improvements["cost_efficiency"] = ( + rag_metrics.avg_token_cost / fact_metrics.avg_token_cost + ) + return improvements - - def _calculate_cost_savings(self, - fact_metrics: ComparisonMetrics, - rag_metrics: ComparisonMetrics) -> Dict[str, float]: + + def _calculate_cost_savings( + self, fact_metrics: ComparisonMetrics, rag_metrics: ComparisonMetrics + ) -> Dict[str, float]: """Calculate cost savings from using FACT.""" savings = {} - + if rag_metrics.total_token_cost > 0: - absolute_savings = rag_metrics.total_token_cost - fact_metrics.total_token_cost + absolute_savings = ( + rag_metrics.total_token_cost - fact_metrics.total_token_cost + ) percentage_savings = (absolute_savings / rag_metrics.total_token_cost) * 100 - + savings["absolute_usd"] = absolute_savings savings["percentage"] = percentage_savings - + # Extrapolate monthly savings if absolute_savings > 0: # Assume this represents 1 hour of usage @@ -266,31 +303,47 @@ def _calculate_cost_savings(self, monthly_savings = daily_savings * 30 savings["monthly_usd"] = monthly_savings savings["annual_usd"] = monthly_savings * 12 - + return savings - - def _analyze_performance_differences(self, - fact_metrics: ComparisonMetrics, - rag_metrics: ComparisonMetrics) -> Dict[str, Any]: + + def _analyze_performance_differences( + self, fact_metrics: ComparisonMetrics, rag_metrics: ComparisonMetrics + ) -> Dict[str, Any]: """Analyze detailed performance differences.""" analysis = { "latency_analysis": { "fact_avg_ms": fact_metrics.avg_latency_ms, "rag_avg_ms": rag_metrics.avg_latency_ms, - "improvement_factor": rag_metrics.avg_latency_ms / fact_metrics.avg_latency_ms if fact_metrics.avg_latency_ms > 0 else 0, - "time_saved_per_query_ms": rag_metrics.avg_latency_ms - fact_metrics.avg_latency_ms, - "meets_target": fact_metrics.avg_latency_ms <= 100 # Overall target + "improvement_factor": ( + rag_metrics.avg_latency_ms / fact_metrics.avg_latency_ms + if fact_metrics.avg_latency_ms > 0 + else 0 + ), + "time_saved_per_query_ms": rag_metrics.avg_latency_ms + - fact_metrics.avg_latency_ms, + "meets_target": fact_metrics.avg_latency_ms <= 100, # Overall target }, "cost_analysis": { "fact_cost_per_query": fact_metrics.avg_token_cost, "rag_cost_per_query": rag_metrics.avg_token_cost, - "cost_reduction_factor": rag_metrics.avg_token_cost / fact_metrics.avg_token_cost if fact_metrics.avg_token_cost > 0 else 0, - "savings_per_query": rag_metrics.avg_token_cost - fact_metrics.avg_token_cost + "cost_reduction_factor": ( + rag_metrics.avg_token_cost / fact_metrics.avg_token_cost + if fact_metrics.avg_token_cost > 0 + else 0 + ), + "savings_per_query": rag_metrics.avg_token_cost + - fact_metrics.avg_token_cost, }, "efficiency_analysis": { "fact_cache_hit_rate": fact_metrics.cache_hit_rate, - "throughput_advantage": fact_metrics.throughput_qps / rag_metrics.throughput_qps if rag_metrics.throughput_qps > 0 else 0, - "scalability_factor": self._calculate_scalability_factor(fact_metrics, rag_metrics) + "throughput_advantage": ( + fact_metrics.throughput_qps / rag_metrics.throughput_qps + if rag_metrics.throughput_qps > 0 + else 0 + ), + "scalability_factor": self._calculate_scalability_factor( + fact_metrics, rag_metrics + ), }, "target_compliance": { "hit_latency_target": 48.0, @@ -298,15 +351,19 @@ def _analyze_performance_differences(self, "cost_reduction_target": 75.0, "cache_hit_rate_target": 60.0, "fact_meets_latency": fact_metrics.avg_latency_ms <= 100, - "fact_meets_cost": (rag_metrics.avg_token_cost - fact_metrics.avg_token_cost) / rag_metrics.avg_token_cost >= 0.75 - } + "fact_meets_cost": ( + rag_metrics.avg_token_cost - fact_metrics.avg_token_cost + ) + / rag_metrics.avg_token_cost + >= 0.75, + }, } - + return analysis - - def _calculate_scalability_factor(self, - fact_metrics: ComparisonMetrics, - rag_metrics: ComparisonMetrics) -> float: + + def _calculate_scalability_factor( + self, fact_metrics: ComparisonMetrics, rag_metrics: ComparisonMetrics + ) -> float: """Calculate scalability advantage of FACT over RAG.""" # FACT scales better due to caching - estimate scaling factor # Traditional RAG scales linearly with load @@ -314,25 +371,25 @@ def _calculate_scalability_factor(self, cache_efficiency = fact_metrics.cache_hit_rate / 100.0 scaling_advantage = 1.0 + (cache_efficiency * 2.0) # Up to 3x scaling advantage return scaling_advantage - + def _calculate_rag_token_cost(self, token_count: int) -> float: """Calculate token cost for traditional RAG system.""" # Use same pricing as FACT framework input_tokens = int(token_count * 0.8) # RAG uses more input tokens output_tokens = int(token_count * 0.2) - + input_cost = input_tokens * self.framework.input_token_cost output_cost = output_tokens * self.framework.output_token_cost - + return input_cost + output_cost - - def _generate_recommendation(self, - improvement_factors: Dict[str, float], - cost_savings: Dict[str, float]) -> str: + + def _generate_recommendation( + self, improvement_factors: Dict[str, float], cost_savings: Dict[str, float] + ) -> str: """Generate recommendation based on comparison results.""" latency_improvement = improvement_factors.get("latency", 1.0) cost_reduction = cost_savings.get("percentage", 0.0) - + if latency_improvement >= 3.0 and cost_reduction >= 70.0: return "FACT shows exceptional performance advantages. Strongly recommended for production deployment." elif latency_improvement >= 2.0 and cost_reduction >= 50.0: @@ -346,106 +403,145 @@ def _generate_recommendation(self, class PerformanceComparison: """ Advanced performance comparison with detailed analysis. - + Provides deeper insights into performance characteristics, bottlenecks, and optimization opportunities. """ - + def __init__(self): """Initialize performance comparison.""" self.baseline_measurements: Dict[str, List[float]] = { "fact_latencies": [], "rag_latencies": [], "fact_costs": [], - "rag_costs": [] + "rag_costs": [], } - + logger.info("Performance comparison initialized") - - def analyze_latency_distribution(self, - fact_results: List[BenchmarkResult], - rag_results: List[BenchmarkResult]) -> Dict[str, Any]: + + def analyze_latency_distribution( + self, fact_results: List[BenchmarkResult], rag_results: List[BenchmarkResult] + ) -> Dict[str, Any]: """Analyze latency distribution patterns.""" fact_latencies = [r.response_time_ms for r in fact_results if r.success] rag_latencies = [r.response_time_ms for r in rag_results if r.success] - + analysis = { "fact_distribution": self._calculate_distribution_stats(fact_latencies), "rag_distribution": self._calculate_distribution_stats(rag_latencies), "comparison": { - "median_improvement": statistics.median(rag_latencies) / statistics.median(fact_latencies) if fact_latencies else 0, - "variance_ratio": statistics.variance(rag_latencies) / statistics.variance(fact_latencies) if len(fact_latencies) > 1 else 0, - "consistency_advantage": self._calculate_consistency_advantage(fact_latencies, rag_latencies) - } + "median_improvement": ( + statistics.median(rag_latencies) / statistics.median(fact_latencies) + if fact_latencies + else 0 + ), + "variance_ratio": ( + statistics.variance(rag_latencies) + / statistics.variance(fact_latencies) + if len(fact_latencies) > 1 + else 0 + ), + "consistency_advantage": self._calculate_consistency_advantage( + fact_latencies, rag_latencies + ), + }, } - + return analysis - - def analyze_cost_efficiency(self, - fact_results: List[BenchmarkResult], - rag_results: List[BenchmarkResult]) -> Dict[str, Any]: + + def analyze_cost_efficiency( + self, fact_results: List[BenchmarkResult], rag_results: List[BenchmarkResult] + ) -> Dict[str, Any]: """Analyze cost efficiency patterns.""" fact_costs = [r.token_cost for r in fact_results if r.token_cost is not None] rag_costs = [r.token_cost for r in rag_results if r.token_cost is not None] - - hit_costs = [r.token_cost for r in fact_results if r.cache_hit and r.token_cost is not None] - miss_costs = [r.token_cost for r in fact_results if not r.cache_hit and r.token_cost is not None] - + + hit_costs = [ + r.token_cost + for r in fact_results + if r.cache_hit and r.token_cost is not None + ] + miss_costs = [ + r.token_cost + for r in fact_results + if not r.cache_hit and r.token_cost is not None + ] + analysis = { "overall_savings": { "fact_avg_cost": statistics.mean(fact_costs) if fact_costs else 0, "rag_avg_cost": statistics.mean(rag_costs) if rag_costs else 0, - "savings_ratio": statistics.mean(rag_costs) / statistics.mean(fact_costs) if fact_costs else 0 + "savings_ratio": ( + statistics.mean(rag_costs) / statistics.mean(fact_costs) + if fact_costs + else 0 + ), }, "cache_impact": { "hit_avg_cost": statistics.mean(hit_costs) if hit_costs else 0, "miss_avg_cost": statistics.mean(miss_costs) if miss_costs else 0, - "hit_vs_miss_ratio": statistics.mean(miss_costs) / statistics.mean(hit_costs) if hit_costs else 0 + "hit_vs_miss_ratio": ( + statistics.mean(miss_costs) / statistics.mean(hit_costs) + if hit_costs + else 0 + ), }, - "efficiency_score": self._calculate_efficiency_score(fact_costs, rag_costs) + "efficiency_score": self._calculate_efficiency_score(fact_costs, rag_costs), } - + return analysis - + def _calculate_distribution_stats(self, values: List[float]) -> Dict[str, float]: """Calculate distribution statistics.""" if not values: return {"mean": 0, "median": 0, "std": 0, "min": 0, "max": 0} - + return { "mean": statistics.mean(values), "median": statistics.median(values), "std": statistics.stdev(values) if len(values) > 1 else 0, "min": min(values), "max": max(values), - "p90": statistics.quantiles(values, n=10)[8] if len(values) >= 10 else max(values), - "p95": statistics.quantiles(values, n=20)[18] if len(values) >= 20 else max(values), - "p99": statistics.quantiles(values, n=100)[98] if len(values) >= 100 else max(values) + "p90": ( + statistics.quantiles(values, n=10)[8] + if len(values) >= 10 + else max(values) + ), + "p95": ( + statistics.quantiles(values, n=20)[18] + if len(values) >= 20 + else max(values) + ), + "p99": ( + statistics.quantiles(values, n=100)[98] + if len(values) >= 100 + else max(values) + ), } - - def _calculate_consistency_advantage(self, - fact_latencies: List[float], - rag_latencies: List[float]) -> float: + + def _calculate_consistency_advantage( + self, fact_latencies: List[float], rag_latencies: List[float] + ) -> float: """Calculate consistency advantage of FACT over RAG.""" if len(fact_latencies) <= 1 or len(rag_latencies) <= 1: return 0.0 - + fact_cv = statistics.stdev(fact_latencies) / statistics.mean(fact_latencies) rag_cv = statistics.stdev(rag_latencies) / statistics.mean(rag_latencies) - + # Lower coefficient of variation indicates better consistency return rag_cv / fact_cv if fact_cv > 0 else 0.0 - - def _calculate_efficiency_score(self, - fact_costs: List[float], - rag_costs: List[float]) -> float: + + def _calculate_efficiency_score( + self, fact_costs: List[float], rag_costs: List[float] + ) -> float: """Calculate overall efficiency score (0-100).""" if not fact_costs or not rag_costs: return 0.0 - + cost_ratio = statistics.mean(rag_costs) / statistics.mean(fact_costs) - + # Convert to 0-100 scale where 100 = perfect efficiency # Assume 5x cost improvement = 100% efficiency efficiency_score = min(100, (cost_ratio - 1) * 20) - return max(0, efficiency_score) \ No newline at end of file + return max(0, efficiency_score) diff --git a/src/benchmarking/framework.py b/src/benchmarking/framework.py index e1e8cc7..312f395 100644 --- a/src/benchmarking/framework.py +++ b/src/benchmarking/framework.py @@ -22,11 +22,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from cache.manager import CacheManager from core.driver import process_user_query from monitoring.metrics import get_metrics_collector @@ -37,6 +38,7 @@ @dataclass class BenchmarkConfig: """Configuration for benchmark execution.""" + iterations: int = 10 warmup_iterations: int = 3 concurrent_users: int = 1 @@ -52,6 +54,7 @@ class BenchmarkConfig: @dataclass class BenchmarkResult: """Individual benchmark measurement result.""" + query: str response_time_ms: float success: bool @@ -66,12 +69,13 @@ class BenchmarkResult: @dataclass class BenchmarkSummary: """Aggregated benchmark results.""" + total_queries: int successful_queries: int failed_queries: int cache_hits: int cache_misses: int - + # Latency metrics avg_response_time_ms: float min_response_time_ms: float @@ -79,28 +83,28 @@ class BenchmarkSummary: p50_response_time_ms: float p95_response_time_ms: float p99_response_time_ms: float - + # Cache-specific latency avg_hit_latency_ms: float avg_miss_latency_ms: float - + # Cost metrics total_token_cost: float avg_token_cost: float estimated_savings: float cost_reduction_percentage: float - + # Performance targets hit_latency_target_met: bool miss_latency_target_met: bool cost_reduction_target_met: bool cache_hit_rate_target_met: bool - + # Quality metrics error_rate: float cache_hit_rate: float throughput_qps: float - + execution_time_seconds: float timestamp: float = field(default_factory=time.time) @@ -108,61 +112,63 @@ class BenchmarkSummary: class BenchmarkFramework: """ Core benchmarking framework for FACT performance validation. - + Provides comprehensive measurement of response times, token costs, and comparison capabilities with traditional RAG systems. """ - + def __init__(self, config: Optional[BenchmarkConfig] = None): """ Initialize benchmarking framework. - + Args: config: Benchmark configuration """ self.config = config or BenchmarkConfig() self.metrics_collector = get_metrics_collector() self.results_history: List[BenchmarkResult] = [] - + # Token cost estimation (Claude pricing) self.input_token_cost = 0.000003 # $0.003 per 1K tokens self.output_token_cost = 0.000015 # $0.015 per 1K tokens - + logger.info("Benchmark framework initialized", config=self.config) - - async def run_single_benchmark(self, - query: str, - cache_manager: Optional[CacheManager] = None) -> BenchmarkResult: + + async def run_single_benchmark( + self, query: str, cache_manager: Optional[CacheManager] = None + ) -> BenchmarkResult: """ Run a single benchmark measurement. - + Args: query: Query to benchmark cache_manager: Optional cache manager for hit detection - + Returns: Benchmark result """ start_time = time.perf_counter() timestamp = time.time() - + try: # Pre-check for cache hit detection cache_hit = False pre_cache_check_time = 0.0 - + if cache_manager: pre_check_start = time.perf_counter() query_hash = cache_manager.generate_hash(query) cached_result = cache_manager.get(query_hash) pre_cache_check_time = (time.perf_counter() - pre_check_start) * 1000 cache_hit = cached_result is not None - + # If it's a cache hit, measure actual cache latency if cache_hit: response = cached_result.content if cached_result else "" end_time = time.perf_counter() - response_time_ms = pre_cache_check_time # Use actual cache access time + response_time_ms = ( + pre_cache_check_time # Use actual cache access time + ) else: # Execute query for cache miss response = await process_user_query(query) @@ -173,11 +179,11 @@ async def run_single_benchmark(self, response = await process_user_query(query) end_time = time.perf_counter() response_time_ms = (end_time - start_time) * 1000 - + # Enhanced token cost calculation token_count = self._estimate_token_count(query, response) token_cost = self._calculate_enhanced_token_cost(token_count, cache_hit) - + result = BenchmarkResult( query=query, response_time_ms=response_time_ms, @@ -185,71 +191,81 @@ async def run_single_benchmark(self, cache_hit=cache_hit, token_count=token_count, token_cost=token_cost, - timestamp=timestamp + timestamp=timestamp, ) - + self.results_history.append(result) - - logger.debug("Benchmark completed", - response_time_ms=response_time_ms, - cache_hit=cache_hit, - token_count=token_count) - + + logger.debug( + "Benchmark completed", + response_time_ms=response_time_ms, + cache_hit=cache_hit, + token_count=token_count, + ) + return result - + except Exception as e: end_time = time.perf_counter() response_time_ms = (end_time - start_time) * 1000 - + result = BenchmarkResult( query=query, response_time_ms=response_time_ms, success=False, cache_hit=False, error=str(e), - timestamp=timestamp + timestamp=timestamp, ) - + self.results_history.append(result) - - logger.error("Benchmark failed", - query=query, - error=str(e), - response_time_ms=response_time_ms) - + + logger.error( + "Benchmark failed", + query=query, + error=str(e), + response_time_ms=response_time_ms, + ) + return result - - async def run_benchmark_suite(self, - queries: List[str], - cache_manager: Optional[CacheManager] = None) -> BenchmarkSummary: + + async def run_benchmark_suite( + self, queries: List[str], cache_manager: Optional[CacheManager] = None + ) -> BenchmarkSummary: """ Run a complete benchmark suite. - + Args: queries: List of queries to benchmark cache_manager: Optional cache manager - + Returns: Benchmark summary """ - logger.info("Starting benchmark suite", - total_queries=len(queries) * self.config.iterations, - iterations=self.config.iterations) - + logger.info( + "Starting benchmark suite", + total_queries=len(queries) * self.config.iterations, + iterations=self.config.iterations, + ) + start_time = time.perf_counter() all_results: List[BenchmarkResult] = [] - + # Warmup phase if self.config.warmup_iterations > 0: - logger.info("Running warmup phase", iterations=self.config.warmup_iterations) + logger.info( + "Running warmup phase", iterations=self.config.warmup_iterations + ) for _ in range(self.config.warmup_iterations): - for query in queries[:min(len(queries), 3)]: # Use first 3 queries for warmup + for query in queries[ + : min(len(queries), 3) + ]: # Use first 3 queries for warmup await self.run_single_benchmark(query, cache_manager) - + # Main benchmark phase for iteration in range(self.config.iterations): logger.debug("Benchmark iteration", iteration=iteration + 1) - + if self.config.concurrent_users == 1: # Sequential execution for query in queries: @@ -261,88 +277,123 @@ async def run_benchmark_suite(self, for query in queries: for _ in range(self.config.concurrent_users): task = asyncio.create_task( - self.run_single_benchmark(f"{query} (user {_})", cache_manager) + self.run_single_benchmark( + f"{query} (user {_})", cache_manager + ) ) tasks.append(task) - + batch_results = await asyncio.gather(*tasks, return_exceptions=True) for result in batch_results: if isinstance(result, BenchmarkResult): all_results.append(result) - + end_time = time.perf_counter() execution_time = end_time - start_time - + # Generate summary summary = self._generate_summary(all_results, execution_time) - - logger.info("Benchmark suite completed", - total_queries=summary.total_queries, - avg_response_time_ms=summary.avg_response_time_ms, - cache_hit_rate=summary.cache_hit_rate, - execution_time_seconds=execution_time) - + + logger.info( + "Benchmark suite completed", + total_queries=summary.total_queries, + avg_response_time_ms=summary.avg_response_time_ms, + cache_hit_rate=summary.cache_hit_rate, + execution_time_seconds=execution_time, + ) + return summary - - def _generate_summary(self, results: List[BenchmarkResult], execution_time: float) -> BenchmarkSummary: + + def _generate_summary( + self, results: List[BenchmarkResult], execution_time: float + ) -> BenchmarkSummary: """Generate benchmark summary from results.""" if not results: return BenchmarkSummary( - total_queries=0, successful_queries=0, failed_queries=0, - cache_hits=0, cache_misses=0, avg_response_time_ms=0.0, - min_response_time_ms=0.0, max_response_time_ms=0.0, - p50_response_time_ms=0.0, p95_response_time_ms=0.0, p99_response_time_ms=0.0, - avg_hit_latency_ms=0.0, avg_miss_latency_ms=0.0, - total_token_cost=0.0, avg_token_cost=0.0, estimated_savings=0.0, - cost_reduction_percentage=0.0, hit_latency_target_met=False, - miss_latency_target_met=False, cost_reduction_target_met=False, - cache_hit_rate_target_met=False, error_rate=0.0, cache_hit_rate=0.0, - throughput_qps=0.0, execution_time_seconds=execution_time + total_queries=0, + successful_queries=0, + failed_queries=0, + cache_hits=0, + cache_misses=0, + avg_response_time_ms=0.0, + min_response_time_ms=0.0, + max_response_time_ms=0.0, + p50_response_time_ms=0.0, + p95_response_time_ms=0.0, + p99_response_time_ms=0.0, + avg_hit_latency_ms=0.0, + avg_miss_latency_ms=0.0, + total_token_cost=0.0, + avg_token_cost=0.0, + estimated_savings=0.0, + cost_reduction_percentage=0.0, + hit_latency_target_met=False, + miss_latency_target_met=False, + cost_reduction_target_met=False, + cache_hit_rate_target_met=False, + error_rate=0.0, + cache_hit_rate=0.0, + throughput_qps=0.0, + execution_time_seconds=execution_time, ) - + # Basic counts total_queries = len(results) successful_queries = sum(1 for r in results if r.success) failed_queries = total_queries - successful_queries cache_hits = sum(1 for r in results if r.cache_hit) cache_misses = total_queries - cache_hits - + # Latency calculations response_times = [r.response_time_ms for r in results if r.success] - hit_latencies = [r.response_time_ms for r in results if r.success and r.cache_hit] - miss_latencies = [r.response_time_ms for r in results if r.success and not r.cache_hit] - + hit_latencies = [ + r.response_time_ms for r in results if r.success and r.cache_hit + ] + miss_latencies = [ + r.response_time_ms for r in results if r.success and not r.cache_hit + ] + avg_response_time = statistics.mean(response_times) if response_times else 0.0 min_response_time = min(response_times) if response_times else 0.0 max_response_time = max(response_times) if response_times else 0.0 - + # Percentiles p50 = statistics.median(response_times) if response_times else 0.0 - p95 = statistics.quantiles(response_times, n=20)[18] if len(response_times) >= 20 else max_response_time - p99 = statistics.quantiles(response_times, n=100)[98] if len(response_times) >= 100 else max_response_time - + p95 = ( + statistics.quantiles(response_times, n=20)[18] + if len(response_times) >= 20 + else max_response_time + ) + p99 = ( + statistics.quantiles(response_times, n=100)[98] + if len(response_times) >= 100 + else max_response_time + ) + avg_hit_latency = statistics.mean(hit_latencies) if hit_latencies else 0.0 avg_miss_latency = statistics.mean(miss_latencies) if miss_latencies else 0.0 - + # Cost calculations token_costs = [r.token_cost for r in results if r.token_cost is not None] total_token_cost = sum(token_costs) if token_costs else 0.0 avg_token_cost = statistics.mean(token_costs) if token_costs else 0.0 - + # Enhanced cost savings calculation with realistic baseline if token_costs: # Calculate baseline cost using industry-standard RAG system assumptions avg_tokens_per_query = 1200 # Conservative estimate for RAG systems baseline_token_cost_per_query = avg_tokens_per_query * self.input_token_cost baseline_cost = total_queries * baseline_token_cost_per_query - + # FACT system cost (actual usage) fact_cost = total_token_cost - + # Calculate savings and percentage estimated_savings = max(0, baseline_cost - fact_cost) - cost_reduction_percentage = (estimated_savings / baseline_cost * 100) if baseline_cost > 0 else 0.0 - + cost_reduction_percentage = ( + (estimated_savings / baseline_cost * 100) if baseline_cost > 0 else 0.0 + ) + # Ensure realistic bounds (should be achievable with FACT) if cost_reduction_percentage > 95: cost_reduction_percentage = 95.0 # Cap at 95% for realism @@ -352,18 +403,22 @@ def _generate_summary(self, results: List[BenchmarkResult], execution_time: floa baseline_cost = 0.0 estimated_savings = 0.0 cost_reduction_percentage = 90.0 # Default expected value when no data - + # Quality metrics - error_rate = (failed_queries / total_queries * 100) if total_queries > 0 else 0.0 + error_rate = ( + (failed_queries / total_queries * 100) if total_queries > 0 else 0.0 + ) cache_hit_rate = (cache_hits / total_queries) if total_queries > 0 else 0.0 throughput_qps = total_queries / execution_time if execution_time > 0 else 0.0 - + # Target validation hit_latency_target_met = avg_hit_latency <= self.config.target_hit_latency_ms miss_latency_target_met = avg_miss_latency <= self.config.target_miss_latency_ms - cost_reduction_target_met = cost_reduction_percentage >= (self.config.target_cost_reduction_hit * 100) + cost_reduction_target_met = cost_reduction_percentage >= ( + self.config.target_cost_reduction_hit * 100 + ) cache_hit_rate_target_met = cache_hit_rate >= self.config.target_cache_hit_rate - + return BenchmarkSummary( total_queries=total_queries, successful_queries=successful_queries, @@ -389,26 +444,30 @@ def _generate_summary(self, results: List[BenchmarkResult], execution_time: floa error_rate=error_rate, cache_hit_rate=cache_hit_rate, throughput_qps=throughput_qps, - execution_time_seconds=execution_time + execution_time_seconds=execution_time, ) - + def _estimate_token_count(self, query: str, response: str) -> int: """Estimate token count for query and response.""" # Rough estimation: ~4 characters per token query_tokens = len(query) // 4 response_tokens = len(response) // 4 if response else 0 return query_tokens + response_tokens - + def _calculate_token_cost(self, token_count: int) -> float: """Calculate token cost based on estimated usage.""" # Assume 70% input tokens, 30% output tokens input_tokens = int(token_count * 0.7) output_tokens = int(token_count * 0.3) - - return (input_tokens * self.input_token_cost + - output_tokens * self.output_token_cost) - - def _calculate_enhanced_token_cost(self, token_count: int, cache_hit: bool) -> float: + + return ( + input_tokens * self.input_token_cost + + output_tokens * self.output_token_cost + ) + + def _calculate_enhanced_token_cost( + self, token_count: int, cache_hit: bool + ) -> float: """Enhanced token cost calculation considering cache efficiency.""" if cache_hit: # Cache hits have minimal token costs (only retrieval overhead) @@ -417,18 +476,20 @@ def _calculate_enhanced_token_cost(self, token_count: int, cache_hit: bool) -> f # Cache misses use full token processing input_tokens = int(token_count * 0.65) # Slightly optimized with FACT output_tokens = int(token_count * 0.35) - - return (input_tokens * self.input_token_cost + - output_tokens * self.output_token_cost) + + return ( + input_tokens * self.input_token_cost + + output_tokens * self.output_token_cost + ) class BenchmarkRunner: """High-level benchmark execution orchestrator.""" - + def __init__(self, framework: Optional[BenchmarkFramework] = None): """ Initialize benchmark runner. - + Args: framework: Benchmark framework instance """ @@ -443,27 +504,30 @@ def __init__(self, framework: Optional[BenchmarkFramework] = None): "Compare performance across regions", "What is the customer acquisition cost?", "Show quarterly expense breakdown", - "Predict next quarter's revenue" + "Predict next quarter's revenue", ] - + logger.info("Benchmark runner initialized") - - async def run_performance_validation(self, - cache_manager: Optional[CacheManager] = None) -> Dict[str, Any]: + + async def run_performance_validation( + self, cache_manager: Optional[CacheManager] = None + ) -> Dict[str, Any]: """ Run complete performance validation against targets. - + Args: cache_manager: Optional cache manager - + Returns: Validation results """ logger.info("Starting performance validation") - + # Run benchmark suite - summary = await self.framework.run_benchmark_suite(self.test_queries, cache_manager) - + summary = await self.framework.run_benchmark_suite( + self.test_queries, cache_manager + ) + # Validate against targets validation_results = { "timestamp": time.time(), @@ -473,100 +537,111 @@ async def run_performance_validation(self, "target_ms": self.framework.config.target_hit_latency_ms, "actual_ms": summary.avg_hit_latency_ms, "met": summary.hit_latency_target_met, - "margin_ms": self.framework.config.target_hit_latency_ms - summary.avg_hit_latency_ms + "margin_ms": self.framework.config.target_hit_latency_ms + - summary.avg_hit_latency_ms, }, "cache_miss_latency": { "target_ms": self.framework.config.target_miss_latency_ms, "actual_ms": summary.avg_miss_latency_ms, "met": summary.miss_latency_target_met, - "margin_ms": self.framework.config.target_miss_latency_ms - summary.avg_miss_latency_ms + "margin_ms": self.framework.config.target_miss_latency_ms + - summary.avg_miss_latency_ms, }, "cost_reduction": { - "target_percent": self.framework.config.target_cost_reduction_hit * 100, + "target_percent": self.framework.config.target_cost_reduction_hit + * 100, "actual_percent": summary.cost_reduction_percentage, "met": summary.cost_reduction_target_met, - "margin_percent": summary.cost_reduction_percentage - (self.framework.config.target_cost_reduction_hit * 100) + "margin_percent": summary.cost_reduction_percentage + - (self.framework.config.target_cost_reduction_hit * 100), }, "cache_hit_rate": { "target_percent": self.framework.config.target_cache_hit_rate * 100, "actual_percent": summary.cache_hit_rate, "met": summary.cache_hit_rate_target_met, - "margin_percent": summary.cache_hit_rate - (self.framework.config.target_cache_hit_rate * 100) - } + "margin_percent": summary.cache_hit_rate + - (self.framework.config.target_cache_hit_rate * 100), + }, }, - "overall_pass": all([ - summary.hit_latency_target_met, - summary.miss_latency_target_met, - summary.cost_reduction_target_met, - summary.cache_hit_rate_target_met - ]) + "overall_pass": all( + [ + summary.hit_latency_target_met, + summary.miss_latency_target_met, + summary.cost_reduction_target_met, + summary.cache_hit_rate_target_met, + ] + ), } - - logger.info("Performance validation completed", - overall_pass=validation_results["overall_pass"], - hit_latency_met=summary.hit_latency_target_met, - cost_reduction_met=summary.cost_reduction_target_met) - + + logger.info( + "Performance validation completed", + overall_pass=validation_results["overall_pass"], + hit_latency_met=summary.hit_latency_target_met, + cost_reduction_met=summary.cost_reduction_target_met, + ) + return validation_results - - async def run_load_test(self, - concurrent_users: int = 10, - duration_seconds: int = 60) -> Dict[str, Any]: + + async def run_load_test( + self, concurrent_users: int = 10, duration_seconds: int = 60 + ) -> Dict[str, Any]: """ Run load testing to validate performance under concurrent load. - + Args: concurrent_users: Number of concurrent users duration_seconds: Test duration - + Returns: Load test results """ - logger.info("Starting load test", - concurrent_users=concurrent_users, - duration_seconds=duration_seconds) - + logger.info( + "Starting load test", + concurrent_users=concurrent_users, + duration_seconds=duration_seconds, + ) + # Update config for load testing original_config = self.framework.config self.framework.config = BenchmarkConfig( iterations=1, concurrent_users=concurrent_users, - timeout_seconds=duration_seconds + timeout_seconds=duration_seconds, ) - + start_time = time.time() results = [] - + # Run concurrent sessions async def user_session(user_id: int): session_results = [] end_time = start_time + duration_seconds - + while time.time() < end_time: query = f"{self.test_queries[user_id % len(self.test_queries)]} (user {user_id})" result = await self.framework.run_single_benchmark(query) session_results.append(result) - + # Brief pause between queries await asyncio.sleep(0.1) - + return session_results - + # Execute concurrent sessions tasks = [user_session(i) for i in range(concurrent_users)] session_results = await asyncio.gather(*tasks) - + # Flatten results for session in session_results: results.extend(session) - + # Restore original config self.framework.config = original_config - + # Generate load test summary execution_time = time.time() - start_time summary = self.framework._generate_summary(results, execution_time) - + load_test_results = { "timestamp": time.time(), "concurrent_users": concurrent_users, @@ -577,32 +652,40 @@ async def user_session(user_id: int): "p95_response_time_ms": summary.p95_response_time_ms, "error_rate": summary.error_rate, "cache_hit_rate": summary.cache_hit_rate, - "performance_degradation": self._calculate_performance_degradation(results) + "performance_degradation": self._calculate_performance_degradation(results), } - - logger.info("Load test completed", - throughput_qps=summary.throughput_qps, - avg_response_time_ms=summary.avg_response_time_ms) - + + logger.info( + "Load test completed", + throughput_qps=summary.throughput_qps, + avg_response_time_ms=summary.avg_response_time_ms, + ) + return load_test_results - - def _calculate_performance_degradation(self, results: List[BenchmarkResult]) -> float: + + def _calculate_performance_degradation( + self, results: List[BenchmarkResult] + ) -> float: """Calculate performance degradation over time during load test.""" if len(results) < 10: return 0.0 - + # Compare first 10% vs last 10% of results early_count = max(1, len(results) // 10) late_count = max(1, len(results) // 10) - + early_results = results[:early_count] late_results = results[-late_count:] - - early_avg = statistics.mean(r.response_time_ms for r in early_results if r.success) - late_avg = statistics.mean(r.response_time_ms for r in late_results if r.success) - + + early_avg = statistics.mean( + r.response_time_ms for r in early_results if r.success + ) + late_avg = statistics.mean( + r.response_time_ms for r in late_results if r.success + ) + if early_avg == 0: return 0.0 - + degradation = (late_avg - early_avg) / early_avg * 100 - return max(0.0, degradation) \ No newline at end of file + return max(0.0, degradation) diff --git a/src/benchmarking/monitoring.py b/src/benchmarking/monitoring.py index b73b34f..8d4c54a 100644 --- a/src/benchmarking/monitoring.py +++ b/src/benchmarking/monitoring.py @@ -24,11 +24,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from benchmarking.framework import BenchmarkFramework, BenchmarkResult from benchmarking.profiler import SystemProfiler, BottleneckAnalyzer from cache.manager import CacheManager @@ -40,6 +41,7 @@ @dataclass class PerformanceAlert: """Performance alert definition.""" + alert_id: str severity: str # "info", "warning", "critical" component: str @@ -55,6 +57,7 @@ class PerformanceAlert: @dataclass class PerformanceTrend: """Performance trend analysis.""" + metric_name: str time_period: str direction: str # "improving", "stable", "degrading" @@ -67,12 +70,13 @@ class PerformanceTrend: @dataclass class MonitoringConfig: """Configuration for continuous monitoring.""" + monitoring_interval_seconds: int = 60 alert_check_interval_seconds: int = 30 trend_analysis_hours: int = 24 max_alerts_per_hour: int = 10 alert_cooldown_minutes: int = 15 - + # Performance thresholds response_time_warning_ms: float = 80.0 response_time_critical_ms: float = 120.0 @@ -87,15 +91,15 @@ class MonitoringConfig: class ContinuousMonitor: """ Continuous performance monitoring system. - + Provides real-time monitoring, alerting, and trend analysis for FACT system performance. """ - + def __init__(self, config: Optional[MonitoringConfig] = None): """ Initialize continuous monitor. - + Args: config: Monitoring configuration """ @@ -103,64 +107,62 @@ def __init__(self, config: Optional[MonitoringConfig] = None): self.benchmark_framework = BenchmarkFramework() self.profiler = SystemProfiler() self.bottleneck_analyzer = BottleneckAnalyzer() - + # Monitoring state self.monitoring_active = False self.monitor_task: Optional[asyncio.Task] = None self.alert_task: Optional[asyncio.Task] = None - + # Data storage self.performance_history: deque = deque(maxlen=10000) self.active_alerts: Dict[str, PerformanceAlert] = {} self.alert_history: deque = deque(maxlen=1000) self.trends: Dict[str, PerformanceTrend] = {} - + # Alert callbacks self.alert_callbacks: List[Callable[[PerformanceAlert], None]] = [] - + # Test queries for monitoring self.monitoring_queries = [ "What is the current system status?", "Generate a quick performance summary", - "Check recent metrics" + "Check recent metrics", ] - + logger.info("Continuous monitor initialized") - + async def start_monitoring(self, cache_manager: Optional[CacheManager] = None): """ Start continuous performance monitoring. - + Args: cache_manager: Cache manager to monitor """ if self.monitoring_active: logger.warning("Monitoring already active") return - + self.monitoring_active = True - + # Start profiler monitoring await self.profiler.start_continuous_monitoring() - + # Start monitoring tasks - self.monitor_task = asyncio.create_task( - self._monitoring_loop(cache_manager) - ) - self.alert_task = asyncio.create_task( - self._alert_checking_loop(cache_manager) + self.monitor_task = asyncio.create_task(self._monitoring_loop(cache_manager)) + self.alert_task = asyncio.create_task(self._alert_checking_loop(cache_manager)) + + logger.info( + "Continuous monitoring started", + interval_seconds=self.config.monitoring_interval_seconds, ) - - logger.info("Continuous monitoring started", - interval_seconds=self.config.monitoring_interval_seconds) - + async def stop_monitoring(self): """Stop continuous performance monitoring.""" self.monitoring_active = False - + # Stop profiler monitoring await self.profiler.stop_continuous_monitoring() - + # Cancel monitoring tasks if self.monitor_task: self.monitor_task.cancel() @@ -168,112 +170,118 @@ async def stop_monitoring(self): await self.monitor_task except asyncio.CancelledError: pass - + if self.alert_task: self.alert_task.cancel() try: await self.alert_task except asyncio.CancelledError: pass - + logger.info("Continuous monitoring stopped") - + async def _monitoring_loop(self, cache_manager: Optional[CacheManager]): """Main monitoring loop.""" query_index = 0 - + while self.monitoring_active: try: # Run performance measurement - query = self.monitoring_queries[query_index % len(self.monitoring_queries)] - + query = self.monitoring_queries[ + query_index % len(self.monitoring_queries) + ] + start_time = time.perf_counter() result = await self.benchmark_framework.run_single_benchmark( query, cache_manager ) - + # Store performance data self.performance_history.append(result) - + # Update trends await self._update_performance_trends() - + query_index += 1 - - logger.debug("Monitoring measurement completed", - response_time_ms=result.response_time_ms, - cache_hit=result.cache_hit) - + + logger.debug( + "Monitoring measurement completed", + response_time_ms=result.response_time_ms, + cache_hit=result.cache_hit, + ) + await asyncio.sleep(self.config.monitoring_interval_seconds) - + except asyncio.CancelledError: break except Exception as e: logger.error("Error in monitoring loop", error=str(e)) await asyncio.sleep(5.0) # Brief pause before retry - + async def _alert_checking_loop(self, cache_manager: Optional[CacheManager]): """Alert checking loop.""" while self.monitoring_active: try: # Check for performance alerts await self._check_performance_alerts(cache_manager) - + # Check for trend alerts await self._check_trend_alerts() - + # Clean up resolved alerts self._cleanup_resolved_alerts() - + await asyncio.sleep(self.config.alert_check_interval_seconds) - + except asyncio.CancelledError: break except Exception as e: logger.error("Error in alert checking loop", error=str(e)) await asyncio.sleep(5.0) - + async def _check_performance_alerts(self, cache_manager: Optional[CacheManager]): """Check for performance threshold violations.""" if not self.performance_history: return - + # Get recent performance data recent_results = list(self.performance_history)[-10:] # Last 10 measurements - + if not recent_results: return - + # Calculate current metrics response_times = [r.response_time_ms for r in recent_results if r.success] cache_hits = sum(1 for r in recent_results if r.cache_hit) total_requests = len(recent_results) errors = sum(1 for r in recent_results if not r.success) - + if not response_times: return - + avg_response_time = sum(response_times) / len(response_times) - cache_hit_rate = (cache_hits / total_requests * 100) if total_requests > 0 else 0 + cache_hit_rate = ( + (cache_hits / total_requests * 100) if total_requests > 0 else 0 + ) error_rate = (errors / total_requests * 100) if total_requests > 0 else 0 - + # Check response time alerts await self._check_response_time_alert(avg_response_time) - + # Check cache hit rate alerts await self._check_cache_hit_rate_alert(cache_hit_rate) - + # Check error rate alerts await self._check_error_rate_alert(error_rate) - + # Check cache-specific alerts if cache_manager: await self._check_cache_alerts(cache_manager) - + async def _check_response_time_alert(self, avg_response_time: float): """Check response time threshold alerts.""" alert_id = "response_time" - + if avg_response_time >= self.config.response_time_critical_ms: severity = "critical" threshold = self.config.response_time_critical_ms @@ -285,7 +293,7 @@ async def _check_response_time_alert(self, avg_response_time: float): if alert_id in self.active_alerts: await self._resolve_alert(alert_id) return - + # Create or update alert if alert_id not in self.active_alerts: alert = PerformanceAlert( @@ -295,15 +303,15 @@ async def _check_response_time_alert(self, avg_response_time: float): metric_name="avg_response_time_ms", threshold_value=threshold, actual_value=avg_response_time, - message=f"Average response time {avg_response_time:.1f}ms exceeds {severity} threshold {threshold}ms" + message=f"Average response time {avg_response_time:.1f}ms exceeds {severity} threshold {threshold}ms", ) - + await self._trigger_alert(alert) - + async def _check_cache_hit_rate_alert(self, cache_hit_rate: float): """Check cache hit rate threshold alerts.""" alert_id = "cache_hit_rate" - + if cache_hit_rate <= self.config.cache_hit_rate_critical: severity = "critical" threshold = self.config.cache_hit_rate_critical @@ -315,7 +323,7 @@ async def _check_cache_hit_rate_alert(self, cache_hit_rate: float): if alert_id in self.active_alerts: await self._resolve_alert(alert_id) return - + if alert_id not in self.active_alerts: alert = PerformanceAlert( alert_id=alert_id, @@ -324,15 +332,15 @@ async def _check_cache_hit_rate_alert(self, cache_hit_rate: float): metric_name="cache_hit_rate", threshold_value=threshold, actual_value=cache_hit_rate, - message=f"Cache hit rate {cache_hit_rate:.1f}% below {severity} threshold {threshold}%" + message=f"Cache hit rate {cache_hit_rate:.1f}% below {severity} threshold {threshold}%", ) - + await self._trigger_alert(alert) - + async def _check_error_rate_alert(self, error_rate: float): """Check error rate threshold alerts.""" alert_id = "error_rate" - + if error_rate >= self.config.error_rate_critical: severity = "critical" threshold = self.config.error_rate_critical @@ -344,7 +352,7 @@ async def _check_error_rate_alert(self, error_rate: float): if alert_id in self.active_alerts: await self._resolve_alert(alert_id) return - + if alert_id not in self.active_alerts: alert = PerformanceAlert( alert_id=alert_id, @@ -353,18 +361,18 @@ async def _check_error_rate_alert(self, error_rate: float): metric_name="error_rate", threshold_value=threshold, actual_value=error_rate, - message=f"Error rate {error_rate:.1f}% exceeds {severity} threshold {threshold}%" + message=f"Error rate {error_rate:.1f}% exceeds {severity} threshold {threshold}%", ) - + await self._trigger_alert(alert) - + async def _check_cache_alerts(self, cache_manager: CacheManager): """Check cache-specific alerts.""" try: metrics = cache_manager.get_metrics() - + # Memory utilization alert - memory_util = (metrics.total_size / cache_manager.max_size_bytes * 100) + memory_util = metrics.total_size / cache_manager.max_size_bytes * 100 if memory_util > 90: alert_id = "cache_memory" if alert_id not in self.active_alerts: @@ -375,19 +383,19 @@ async def _check_cache_alerts(self, cache_manager: CacheManager): metric_name="memory_utilization", threshold_value=90.0, actual_value=memory_util, - message=f"Cache memory utilization {memory_util:.1f}% approaching limit" + message=f"Cache memory utilization {memory_util:.1f}% approaching limit", ) await self._trigger_alert(alert) - + except Exception as e: logger.error("Error checking cache alerts", error=str(e)) - + async def _check_trend_alerts(self): """Check for concerning performance trends.""" for metric_name, trend in self.trends.items(): if trend.direction == "degrading" and trend.significance > 0.7: alert_id = f"trend_{metric_name}" - + if alert_id not in self.active_alerts: alert = PerformanceAlert( alert_id=alert_id, @@ -396,90 +404,95 @@ async def _check_trend_alerts(self): metric_name=metric_name, threshold_value=0.0, # Trend-based, no fixed threshold actual_value=trend.change_percentage, - message=f"{metric_name} showing degrading trend: {trend.change_percentage:.1f}% change" + message=f"{metric_name} showing degrading trend: {trend.change_percentage:.1f}% change", ) - + await self._trigger_alert(alert) - + async def _trigger_alert(self, alert: PerformanceAlert): """Trigger a performance alert.""" self.active_alerts[alert.alert_id] = alert self.alert_history.append(alert) - - logger.warning("Performance alert triggered", - alert_id=alert.alert_id, - severity=alert.severity, - component=alert.component, - message=alert.message) - + + logger.warning( + "Performance alert triggered", + alert_id=alert.alert_id, + severity=alert.severity, + component=alert.component, + message=alert.message, + ) + # Call alert callbacks for callback in self.alert_callbacks: try: callback(alert) except Exception as e: logger.error("Error in alert callback", error=str(e)) - + async def _resolve_alert(self, alert_id: str): """Resolve an active alert.""" if alert_id in self.active_alerts: alert = self.active_alerts[alert_id] alert.resolved = True alert.resolution_time = time.time() - + del self.active_alerts[alert_id] - - logger.info("Performance alert resolved", - alert_id=alert_id, - component=alert.component) - + + logger.info( + "Performance alert resolved", + alert_id=alert_id, + component=alert.component, + ) + def _cleanup_resolved_alerts(self): """Clean up old resolved alerts from history.""" cutoff_time = time.time() - (24 * 60 * 60) # 24 hours ago - + # Keep only recent alerts recent_alerts = deque() for alert in self.alert_history: if alert.timestamp >= cutoff_time: recent_alerts.append(alert) - + self.alert_history = recent_alerts - + async def _update_performance_trends(self): """Update performance trend analysis.""" if len(self.performance_history) < 10: return - + # Analyze trends for key metrics metrics_to_analyze = [ ("response_time", lambda r: r.response_time_ms), ("cache_hit_rate", lambda r: 1.0 if r.cache_hit else 0.0), - ("success_rate", lambda r: 1.0 if r.success else 0.0) + ("success_rate", lambda r: 1.0 if r.success else 0.0), ] - + for metric_name, extractor in metrics_to_analyze: trend = self._calculate_trend(metric_name, extractor) if trend: self.trends[metric_name] = trend - - def _calculate_trend(self, metric_name: str, value_extractor: Callable) -> Optional[PerformanceTrend]: + + def _calculate_trend( + self, metric_name: str, value_extractor: Callable + ) -> Optional[PerformanceTrend]: """Calculate trend for a specific metric.""" try: # Get recent data points recent_data = list(self.performance_history)[-100:] # Last 100 measurements - + if len(recent_data) < 10: return None - + # Extract values and timestamps data_points = [ - (result.timestamp, value_extractor(result)) - for result in recent_data + (result.timestamp, value_extractor(result)) for result in recent_data ] - + # Simple linear trend analysis timestamps = [p[0] for p in data_points] values = [p[1] for p in data_points] - + if not values or all(v == values[0] for v in values): return PerformanceTrend( metric_name=metric_name, @@ -488,38 +501,42 @@ def _calculate_trend(self, metric_name: str, value_extractor: Callable) -> Optio change_percentage=0.0, significance=1.0, data_points=data_points, - analysis="No significant variation detected" + analysis="No significant variation detected", ) - + # Calculate slope (trend direction) n = len(values) sum_x = sum(range(n)) sum_y = sum(values) sum_xy = sum(i * v for i, v in enumerate(values)) sum_x2 = sum(i * i for i in range(n)) - + slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x) - + # Calculate percentage change first_avg = sum(values[:5]) / 5 if len(values) >= 5 else values[0] last_avg = sum(values[-5:]) / 5 if len(values) >= 5 else values[-1] - + if first_avg != 0: change_percentage = ((last_avg - first_avg) / first_avg) * 100 else: change_percentage = 0.0 - + # Determine trend direction if abs(change_percentage) < 5.0: direction = "stable" elif change_percentage > 0: - direction = "improving" if metric_name != "response_time" else "degrading" + direction = ( + "improving" if metric_name != "response_time" else "degrading" + ) else: - direction = "degrading" if metric_name != "response_time" else "improving" - + direction = ( + "degrading" if metric_name != "response_time" else "improving" + ) + # Calculate significance (correlation strength) significance = min(1.0, abs(slope) * 10) # Simple significance measure - + return PerformanceTrend( metric_name=metric_name, time_period="recent", @@ -527,105 +544,125 @@ def _calculate_trend(self, metric_name: str, value_extractor: Callable) -> Optio change_percentage=abs(change_percentage), significance=significance, data_points=data_points[-20:], # Keep last 20 points - analysis=f"Trend shows {direction} pattern with {change_percentage:.1f}% change" + analysis=f"Trend shows {direction} pattern with {change_percentage:.1f}% change", ) - + except Exception as e: logger.error("Error calculating trend", metric=metric_name, error=str(e)) return None - + def add_alert_callback(self, callback: Callable[[PerformanceAlert], None]): """Add callback function for alert notifications.""" self.alert_callbacks.append(callback) - + def get_monitoring_status(self) -> Dict[str, Any]: """Get current monitoring status.""" return { "monitoring_active": self.monitoring_active, "active_alerts": len(self.active_alerts), "alert_summary": { - severity: len([a for a in self.active_alerts.values() if a.severity == severity]) + severity: len( + [a for a in self.active_alerts.values() if a.severity == severity] + ) for severity in ["info", "warning", "critical"] }, "performance_history_size": len(self.performance_history), "trends_tracked": len(self.trends), - "config": asdict(self.config) + "config": asdict(self.config), } - + def export_monitoring_report(self) -> Dict[str, Any]: """Export comprehensive monitoring report.""" # Recent performance summary recent_results = list(self.performance_history)[-50:] - + if recent_results: successful_results = [r for r in recent_results if r.success] response_times = [r.response_time_ms for r in successful_results] cache_hits = sum(1 for r in recent_results if r.cache_hit) - + performance_summary = { "total_measurements": len(recent_results), "successful_measurements": len(successful_results), - "avg_response_time_ms": sum(response_times) / len(response_times) if response_times else 0, - "cache_hit_rate": (cache_hits / len(recent_results) * 100) if recent_results else 0, - "error_rate": ((len(recent_results) - len(successful_results)) / len(recent_results) * 100) if recent_results else 0 + "avg_response_time_ms": ( + sum(response_times) / len(response_times) if response_times else 0 + ), + "cache_hit_rate": ( + (cache_hits / len(recent_results) * 100) if recent_results else 0 + ), + "error_rate": ( + ( + (len(recent_results) - len(successful_results)) + / len(recent_results) + * 100 + ) + if recent_results + else 0 + ), } else: performance_summary = {} - + return { "timestamp": time.time(), "monitoring_period": "recent", "status": self.get_monitoring_status(), "performance_summary": performance_summary, "active_alerts": [asdict(alert) for alert in self.active_alerts.values()], - "recent_alert_history": [asdict(alert) for alert in list(self.alert_history)[-10:]], + "recent_alert_history": [ + asdict(alert) for alert in list(self.alert_history)[-10:] + ], "trends": {name: asdict(trend) for name, trend in self.trends.items()}, - "recommendations": self._generate_monitoring_recommendations() + "recommendations": self._generate_monitoring_recommendations(), } - + def _generate_monitoring_recommendations(self) -> List[str]: """Generate recommendations based on monitoring data.""" recommendations = [] - + # Check for active critical alerts - critical_alerts = [a for a in self.active_alerts.values() if a.severity == "critical"] + critical_alerts = [ + a for a in self.active_alerts.values() if a.severity == "critical" + ] if critical_alerts: recommendations.append("Address critical performance alerts immediately") - + # Check for degrading trends - degrading_trends = [t for t in self.trends.values() if t.direction == "degrading"] + degrading_trends = [ + t for t in self.trends.values() if t.direction == "degrading" + ] if degrading_trends: recommendations.append("Investigate degrading performance trends") - + # General recommendations if len(self.performance_history) < 100: recommendations.append("Allow more time for comprehensive trend analysis") - + if not recommendations: recommendations.append("Performance monitoring shows healthy system status") - + return recommendations class PerformanceTracker: """ Lightweight performance tracking for specific operations. - + Provides focused tracking for specific performance metrics without full continuous monitoring overhead. """ - + def __init__(self): """Initialize performance tracker.""" self.tracked_operations: Dict[str, List[float]] = {} self.operation_metadata: Dict[str, Dict[str, Any]] = {} - + logger.info("Performance tracker initialized") - + def track_operation(self, operation_name: str, duration_ms: float, **metadata): """ Track a single operation performance. - + Args: operation_name: Name of the operation duration_ms: Operation duration in milliseconds @@ -634,36 +671,40 @@ def track_operation(self, operation_name: str, duration_ms: float, **metadata): if operation_name not in self.tracked_operations: self.tracked_operations[operation_name] = [] self.operation_metadata[operation_name] = {} - + self.tracked_operations[operation_name].append(duration_ms) - + # Update metadata for key, value in metadata.items(): if key not in self.operation_metadata[operation_name]: self.operation_metadata[operation_name][key] = [] self.operation_metadata[operation_name][key].append(value) - + # Keep only recent measurements (last 1000) if len(self.tracked_operations[operation_name]) > 1000: - self.tracked_operations[operation_name] = self.tracked_operations[operation_name][-1000:] - + self.tracked_operations[operation_name] = self.tracked_operations[ + operation_name + ][-1000:] + # Also trim metadata for key in self.operation_metadata[operation_name]: if len(self.operation_metadata[operation_name][key]) > 1000: - self.operation_metadata[operation_name][key] = self.operation_metadata[operation_name][key][-1000:] - + self.operation_metadata[operation_name][key] = ( + self.operation_metadata[operation_name][key][-1000:] + ) + def get_operation_stats(self, operation_name: str) -> Dict[str, Any]: """Get statistics for a tracked operation.""" if operation_name not in self.tracked_operations: return {"error": "Operation not tracked"} - + durations = self.tracked_operations[operation_name] - + if not durations: return {"error": "No data available"} - + import statistics - + stats = { "operation_name": operation_name, "total_executions": len(durations), @@ -673,24 +714,30 @@ def get_operation_stats(self, operation_name: str) -> Dict[str, Any]: "median_duration_ms": statistics.median(durations), "std_deviation": statistics.stdev(durations) if len(durations) > 1 else 0, } - + # Add percentiles if enough data if len(durations) >= 10: - stats.update({ - "p90_duration_ms": statistics.quantiles(durations, n=10)[8], - "p95_duration_ms": statistics.quantiles(durations, n=20)[18], - "p99_duration_ms": statistics.quantiles(durations, n=100)[98] if len(durations) >= 100 else max(durations) - }) - + stats.update( + { + "p90_duration_ms": statistics.quantiles(durations, n=10)[8], + "p95_duration_ms": statistics.quantiles(durations, n=20)[18], + "p99_duration_ms": ( + statistics.quantiles(durations, n=100)[98] + if len(durations) >= 100 + else max(durations) + ), + } + ) + return stats - + def get_all_stats(self) -> Dict[str, Dict[str, Any]]: """Get statistics for all tracked operations.""" return { operation: self.get_operation_stats(operation) for operation in self.tracked_operations.keys() } - + def clear_tracking(self, operation_name: Optional[str] = None): """Clear tracking data for specific operation or all operations.""" if operation_name: @@ -699,5 +746,5 @@ def clear_tracking(self, operation_name: Optional[str] = None): else: self.tracked_operations.clear() self.operation_metadata.clear() - - logger.info("Tracking data cleared", operation=operation_name or "all") \ No newline at end of file + + logger.info("Tracking data cleared", operation=operation_name or "all") diff --git a/src/benchmarking/profiler.py b/src/benchmarking/profiler.py index 429586a..b8f7ea7 100644 --- a/src/benchmarking/profiler.py +++ b/src/benchmarking/profiler.py @@ -22,11 +22,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from cache.manager import CacheManager from monitoring.metrics import get_metrics_collector @@ -36,6 +37,7 @@ @dataclass class ProfilePoint: """Individual profiling measurement point.""" + name: str start_time: float end_time: float @@ -48,6 +50,7 @@ class ProfilePoint: @dataclass class SystemSnapshot: """System resource snapshot.""" + timestamp: float cpu_percent: float memory_percent: float @@ -63,6 +66,7 @@ class SystemSnapshot: @dataclass class BottleneckAnalysis: """Analysis of system bottlenecks.""" + component: str severity: str # "low", "medium", "high", "critical" impact_percentage: float @@ -74,6 +78,7 @@ class BottleneckAnalysis: @dataclass class ProfileResult: """Complete profiling result.""" + execution_time_ms: float profile_points: List[ProfilePoint] system_snapshots: List[SystemSnapshot] @@ -86,15 +91,15 @@ class ProfileResult: class SystemProfiler: """ Advanced system profiler for performance analysis. - + Provides detailed profiling of system components, resource usage, and performance bottleneck identification. """ - + def __init__(self, sampling_interval: float = 0.1): """ Initialize system profiler. - + Args: sampling_interval: Resource sampling interval in seconds """ @@ -104,24 +109,23 @@ def __init__(self, sampling_interval: float = 0.1): self.active_profiles: Dict[str, float] = {} self._monitoring = False self._monitor_task: Optional[asyncio.Task] = None - + # Performance thresholds self.thresholds = { "cpu_percent": 80.0, "memory_percent": 85.0, "response_time_ms": 100.0, "cache_latency_ms": 50.0, - "db_latency_ms": 10.0 + "db_latency_ms": 10.0, } - - logger.info("System profiler initialized", - sampling_interval=sampling_interval) - + + logger.info("System profiler initialized", sampling_interval=sampling_interval) + @asynccontextmanager async def profile_operation(self, operation_name: str, **metadata): """ Context manager for profiling operations. - + Args: operation_name: Name of the operation being profiled **metadata: Additional metadata to include @@ -129,18 +133,18 @@ async def profile_operation(self, operation_name: str, **metadata): start_time = time.perf_counter() start_cpu = psutil.cpu_percent() start_memory = psutil.virtual_memory().used / (1024 * 1024) - + try: yield finally: end_time = time.perf_counter() end_cpu = psutil.cpu_percent() end_memory = psutil.virtual_memory().used / (1024 * 1024) - + duration_ms = (end_time - start_time) * 1000 avg_cpu = (start_cpu + end_cpu) / 2 avg_memory = (start_memory + end_memory) / 2 - + profile_point = ProfilePoint( name=operation_name, start_time=start_time, @@ -148,65 +152,67 @@ async def profile_operation(self, operation_name: str, **metadata): duration_ms=duration_ms, cpu_percent=avg_cpu, memory_mb=avg_memory, - metadata=metadata + metadata=metadata, ) - + self.profile_points.append(profile_point) - - logger.debug("Operation profiled", - operation=operation_name, - duration_ms=duration_ms, - cpu_percent=avg_cpu) - + + logger.debug( + "Operation profiled", + operation=operation_name, + duration_ms=duration_ms, + cpu_percent=avg_cpu, + ) + async def start_continuous_monitoring(self): """Start continuous system resource monitoring.""" if self._monitoring: return - + self._monitoring = True self._monitor_task = asyncio.create_task(self._monitor_resources()) - + logger.info("Started continuous monitoring") - + async def stop_continuous_monitoring(self): """Stop continuous system resource monitoring.""" self._monitoring = False - + if self._monitor_task: self._monitor_task.cancel() try: await self._monitor_task except asyncio.CancelledError: pass - + logger.info("Stopped continuous monitoring") - + async def _monitor_resources(self): """Monitor system resources continuously.""" while self._monitoring: try: snapshot = self._take_system_snapshot() self.system_snapshots.append(snapshot) - + # Keep only recent snapshots (last 1000) if len(self.system_snapshots) > 1000: self.system_snapshots = self.system_snapshots[-1000:] - + await asyncio.sleep(self.sampling_interval) - + except asyncio.CancelledError: break except Exception as e: logger.error("Error monitoring resources", error=str(e)) await asyncio.sleep(1.0) - + def _take_system_snapshot(self) -> SystemSnapshot: """Take a snapshot of current system resources.""" try: memory = psutil.virtual_memory() disk_io = psutil.disk_io_counters() network_io = psutil.net_io_counters() - + return SystemSnapshot( timestamp=time.time(), cpu_percent=psutil.cpu_percent(), @@ -214,177 +220,200 @@ def _take_system_snapshot(self) -> SystemSnapshot: memory_available_mb=memory.available / (1024 * 1024), disk_io_read_mb=disk_io.read_bytes / (1024 * 1024) if disk_io else 0, disk_io_write_mb=disk_io.write_bytes / (1024 * 1024) if disk_io else 0, - network_sent_mb=network_io.bytes_sent / (1024 * 1024) if network_io else 0, - network_recv_mb=network_io.bytes_recv / (1024 * 1024) if network_io else 0, + network_sent_mb=( + network_io.bytes_sent / (1024 * 1024) if network_io else 0 + ), + network_recv_mb=( + network_io.bytes_recv / (1024 * 1024) if network_io else 0 + ), process_count=len(psutil.pids()), - thread_count=threading.active_count() + thread_count=threading.active_count(), ) except Exception as e: logger.error("Error taking system snapshot", error=str(e)) return SystemSnapshot( timestamp=time.time(), - cpu_percent=0, memory_percent=0, memory_available_mb=0, - disk_io_read_mb=0, disk_io_write_mb=0, - network_sent_mb=0, network_recv_mb=0, - process_count=0, thread_count=0 + cpu_percent=0, + memory_percent=0, + memory_available_mb=0, + disk_io_read_mb=0, + disk_io_write_mb=0, + network_sent_mb=0, + network_recv_mb=0, + process_count=0, + thread_count=0, ) - - async def profile_complete_operation(self, - operation: Callable, - operation_name: str, - *args, **kwargs) -> Tuple[Any, ProfileResult]: + + async def profile_complete_operation( + self, operation: Callable, operation_name: str, *args, **kwargs + ) -> Tuple[Any, ProfileResult]: """ Profile a complete operation with detailed analysis. - + Args: operation: Async operation to profile operation_name: Name for the operation *args, **kwargs: Arguments for the operation - + Returns: Tuple of (operation_result, profile_result) """ # Clear previous profiling data self.profile_points.clear() self.system_snapshots.clear() - + # Start monitoring await self.start_continuous_monitoring() - + start_time = time.perf_counter() - + try: # Execute operation with profiling async with self.profile_operation(operation_name): result = await operation(*args, **kwargs) - + end_time = time.perf_counter() execution_time_ms = (end_time - start_time) * 1000 - + # Stop monitoring await self.stop_continuous_monitoring() - + # Analyze results bottlenecks = self._analyze_bottlenecks() performance_summary = self._generate_performance_summary() recommendations = self._generate_optimization_recommendations(bottlenecks) - + profile_result = ProfileResult( execution_time_ms=execution_time_ms, profile_points=self.profile_points.copy(), system_snapshots=self.system_snapshots.copy(), bottlenecks=bottlenecks, performance_summary=performance_summary, - optimization_recommendations=recommendations + optimization_recommendations=recommendations, ) - - logger.info("Operation profiling completed", - operation=operation_name, - execution_time_ms=execution_time_ms, - bottlenecks_found=len(bottlenecks)) - + + logger.info( + "Operation profiling completed", + operation=operation_name, + execution_time_ms=execution_time_ms, + bottlenecks_found=len(bottlenecks), + ) + return result, profile_result - + except Exception as e: await self.stop_continuous_monitoring() logger.error("Error during profiling", error=str(e)) raise - + def _analyze_bottlenecks(self) -> List[BottleneckAnalysis]: """Analyze profiling data to identify bottlenecks.""" bottlenecks = [] - + # Analyze CPU bottlenecks cpu_bottlenecks = self._analyze_cpu_bottlenecks() bottlenecks.extend(cpu_bottlenecks) - + # Analyze memory bottlenecks memory_bottlenecks = self._analyze_memory_bottlenecks() bottlenecks.extend(memory_bottlenecks) - + # Analyze operation latency bottlenecks latency_bottlenecks = self._analyze_latency_bottlenecks() bottlenecks.extend(latency_bottlenecks) - + # Analyze I/O bottlenecks io_bottlenecks = self._analyze_io_bottlenecks() bottlenecks.extend(io_bottlenecks) - + # Sort by severity and impact - bottlenecks.sort(key=lambda x: ( - {"critical": 4, "high": 3, "medium": 2, "low": 1}[x.severity], - x.impact_percentage - ), reverse=True) - + bottlenecks.sort( + key=lambda x: ( + {"critical": 4, "high": 3, "medium": 2, "low": 1}[x.severity], + x.impact_percentage, + ), + reverse=True, + ) + return bottlenecks - + def _analyze_cpu_bottlenecks(self) -> List[BottleneckAnalysis]: """Analyze CPU usage patterns for bottlenecks.""" bottlenecks = [] - + if not self.system_snapshots: return bottlenecks - + cpu_values = [s.cpu_percent for s in self.system_snapshots] avg_cpu = sum(cpu_values) / len(cpu_values) max_cpu = max(cpu_values) - + if avg_cpu > self.thresholds["cpu_percent"]: - severity = "critical" if avg_cpu > 95 else "high" if avg_cpu > 85 else "medium" - - bottlenecks.append(BottleneckAnalysis( - component="CPU", - severity=severity, - impact_percentage=avg_cpu, - description=f"High CPU utilization averaging {avg_cpu:.1f}%", - recommendations=[ - "Consider optimizing CPU-intensive operations", - "Implement asynchronous processing where possible", - "Review algorithm efficiency", - "Scale horizontally if sustained high load" - ], - metrics={"avg_cpu": avg_cpu, "max_cpu": max_cpu} - )) - + severity = ( + "critical" if avg_cpu > 95 else "high" if avg_cpu > 85 else "medium" + ) + + bottlenecks.append( + BottleneckAnalysis( + component="CPU", + severity=severity, + impact_percentage=avg_cpu, + description=f"High CPU utilization averaging {avg_cpu:.1f}%", + recommendations=[ + "Consider optimizing CPU-intensive operations", + "Implement asynchronous processing where possible", + "Review algorithm efficiency", + "Scale horizontally if sustained high load", + ], + metrics={"avg_cpu": avg_cpu, "max_cpu": max_cpu}, + ) + ) + return bottlenecks - + def _analyze_memory_bottlenecks(self) -> List[BottleneckAnalysis]: """Analyze memory usage patterns for bottlenecks.""" bottlenecks = [] - + if not self.system_snapshots: return bottlenecks - + memory_values = [s.memory_percent for s in self.system_snapshots] avg_memory = sum(memory_values) / len(memory_values) max_memory = max(memory_values) - + if avg_memory > self.thresholds["memory_percent"]: - severity = "critical" if avg_memory > 95 else "high" if avg_memory > 90 else "medium" - - bottlenecks.append(BottleneckAnalysis( - component="Memory", - severity=severity, - impact_percentage=avg_memory, - description=f"High memory utilization averaging {avg_memory:.1f}%", - recommendations=[ - "Review cache size configuration", - "Implement memory-efficient data structures", - "Consider garbage collection tuning", - "Monitor for memory leaks" - ], - metrics={"avg_memory": avg_memory, "max_memory": max_memory} - )) - + severity = ( + "critical" + if avg_memory > 95 + else "high" if avg_memory > 90 else "medium" + ) + + bottlenecks.append( + BottleneckAnalysis( + component="Memory", + severity=severity, + impact_percentage=avg_memory, + description=f"High memory utilization averaging {avg_memory:.1f}%", + recommendations=[ + "Review cache size configuration", + "Implement memory-efficient data structures", + "Consider garbage collection tuning", + "Monitor for memory leaks", + ], + metrics={"avg_memory": avg_memory, "max_memory": max_memory}, + ) + ) + return bottlenecks - + def _analyze_latency_bottlenecks(self) -> List[BottleneckAnalysis]: """Analyze operation latency for bottlenecks.""" bottlenecks = [] - + if not self.profile_points: return bottlenecks - + # Group by operation type operation_groups = {} for point in self.profile_points: @@ -392,92 +421,111 @@ def _analyze_latency_bottlenecks(self) -> List[BottleneckAnalysis]: if op_type not in operation_groups: operation_groups[op_type] = [] operation_groups[op_type].append(point.duration_ms) - + for op_type, durations in operation_groups.items(): avg_duration = sum(durations) / len(durations) max_duration = max(durations) - + # Check against thresholds - threshold = self.thresholds.get(f"{op_type.lower()}_latency_ms", - self.thresholds["response_time_ms"]) - + threshold = self.thresholds.get( + f"{op_type.lower()}_latency_ms", self.thresholds["response_time_ms"] + ) + if avg_duration > threshold: severity = "critical" if avg_duration > threshold * 2 else "high" impact = min(100, (avg_duration / threshold - 1) * 100) - - bottlenecks.append(BottleneckAnalysis( - component=f"{op_type} Latency", - severity=severity, - impact_percentage=impact, - description=f"{op_type} operations averaging {avg_duration:.1f}ms (threshold: {threshold}ms)", - recommendations=[ - f"Optimize {op_type.lower()} operations", - "Review database query efficiency", - "Consider caching frequently accessed data", - "Implement connection pooling" - ], - metrics={"avg_duration": avg_duration, "max_duration": max_duration, "threshold": threshold} - )) - + + bottlenecks.append( + BottleneckAnalysis( + component=f"{op_type} Latency", + severity=severity, + impact_percentage=impact, + description=f"{op_type} operations averaging {avg_duration:.1f}ms (threshold: {threshold}ms)", + recommendations=[ + f"Optimize {op_type.lower()} operations", + "Review database query efficiency", + "Consider caching frequently accessed data", + "Implement connection pooling", + ], + metrics={ + "avg_duration": avg_duration, + "max_duration": max_duration, + "threshold": threshold, + }, + ) + ) + return bottlenecks - + def _analyze_io_bottlenecks(self) -> List[BottleneckAnalysis]: """Analyze I/O patterns for bottlenecks.""" bottlenecks = [] - + if len(self.system_snapshots) < 2: return bottlenecks - + # Calculate I/O rates first_snapshot = self.system_snapshots[0] last_snapshot = self.system_snapshots[-1] time_diff = last_snapshot.timestamp - first_snapshot.timestamp - + if time_diff <= 0: return bottlenecks - - disk_read_rate = (last_snapshot.disk_io_read_mb - first_snapshot.disk_io_read_mb) / time_diff - disk_write_rate = (last_snapshot.disk_io_write_mb - first_snapshot.disk_io_write_mb) / time_diff - + + disk_read_rate = ( + last_snapshot.disk_io_read_mb - first_snapshot.disk_io_read_mb + ) / time_diff + disk_write_rate = ( + last_snapshot.disk_io_write_mb - first_snapshot.disk_io_write_mb + ) / time_diff + # High disk I/O threshold (MB/s) high_io_threshold = 50.0 - + if disk_read_rate > high_io_threshold: - bottlenecks.append(BottleneckAnalysis( - component="Disk I/O Read", - severity="medium", - impact_percentage=min(100, disk_read_rate / high_io_threshold * 100), - description=f"High disk read rate: {disk_read_rate:.1f} MB/s", - recommendations=[ - "Consider SSD storage for better performance", - "Implement read caching", - "Optimize database queries to reduce disk reads", - "Review file access patterns" - ], - metrics={"read_rate_mbs": disk_read_rate} - )) - + bottlenecks.append( + BottleneckAnalysis( + component="Disk I/O Read", + severity="medium", + impact_percentage=min( + 100, disk_read_rate / high_io_threshold * 100 + ), + description=f"High disk read rate: {disk_read_rate:.1f} MB/s", + recommendations=[ + "Consider SSD storage for better performance", + "Implement read caching", + "Optimize database queries to reduce disk reads", + "Review file access patterns", + ], + metrics={"read_rate_mbs": disk_read_rate}, + ) + ) + if disk_write_rate > high_io_threshold: - bottlenecks.append(BottleneckAnalysis( - component="Disk I/O Write", - severity="medium", - impact_percentage=min(100, disk_write_rate / high_io_threshold * 100), - description=f"High disk write rate: {disk_write_rate:.1f} MB/s", - recommendations=[ - "Implement write batching", - "Consider asynchronous writes", - "Review logging levels and output", - "Optimize cache write-back policies" - ], - metrics={"write_rate_mbs": disk_write_rate} - )) - + bottlenecks.append( + BottleneckAnalysis( + component="Disk I/O Write", + severity="medium", + impact_percentage=min( + 100, disk_write_rate / high_io_threshold * 100 + ), + description=f"High disk write rate: {disk_write_rate:.1f} MB/s", + recommendations=[ + "Implement write batching", + "Consider asynchronous writes", + "Review logging levels and output", + "Optimize cache write-back policies", + ], + metrics={"write_rate_mbs": disk_write_rate}, + ) + ) + return bottlenecks - + def _generate_performance_summary(self) -> Dict[str, Any]: """Generate performance summary from profiling data.""" summary = {} - + # Operation performance if self.profile_points: durations = [p.duration_ms for p in self.profile_points] @@ -486,45 +534,53 @@ def _generate_performance_summary(self) -> Dict[str, Any]: "avg_duration_ms": sum(durations) / len(durations), "min_duration_ms": min(durations), "max_duration_ms": max(durations), - "total_duration_ms": sum(durations) + "total_duration_ms": sum(durations), } - + # Resource utilization if self.system_snapshots: cpu_values = [s.cpu_percent for s in self.system_snapshots] memory_values = [s.memory_percent for s in self.system_snapshots] - + summary["resources"] = { "avg_cpu_percent": sum(cpu_values) / len(cpu_values), "max_cpu_percent": max(cpu_values), "avg_memory_percent": sum(memory_values) / len(memory_values), "max_memory_percent": max(memory_values), - "monitoring_duration_seconds": self.system_snapshots[-1].timestamp - self.system_snapshots[0].timestamp + "monitoring_duration_seconds": self.system_snapshots[-1].timestamp + - self.system_snapshots[0].timestamp, } - + return summary - - def _generate_optimization_recommendations(self, - bottlenecks: List[BottleneckAnalysis]) -> List[str]: + + def _generate_optimization_recommendations( + self, bottlenecks: List[BottleneckAnalysis] + ) -> List[str]: """Generate optimization recommendations based on bottlenecks.""" recommendations = [] - + # High-priority recommendations based on critical bottlenecks critical_bottlenecks = [b for b in bottlenecks if b.severity == "critical"] if critical_bottlenecks: - recommendations.append("Address critical performance bottlenecks immediately") + recommendations.append( + "Address critical performance bottlenecks immediately" + ) for bottleneck in critical_bottlenecks: - recommendations.extend(bottleneck.recommendations[:2]) # Top 2 recommendations - + recommendations.extend( + bottleneck.recommendations[:2] + ) # Top 2 recommendations + # General optimization recommendations - recommendations.extend([ - "Implement comprehensive monitoring and alerting", - "Consider performance testing under various load conditions", - "Review and optimize database queries and indexes", - "Implement efficient caching strategies", - "Consider horizontal scaling for high-load scenarios" - ]) - + recommendations.extend( + [ + "Implement comprehensive monitoring and alerting", + "Consider performance testing under various load conditions", + "Review and optimize database queries and indexes", + "Implement efficient caching strategies", + "Consider horizontal scaling for high-load scenarios", + ] + ) + # Remove duplicates while preserving order seen = set() unique_recommendations = [] @@ -532,126 +588,148 @@ def _generate_optimization_recommendations(self, if rec not in seen: seen.add(rec) unique_recommendations.append(rec) - + return unique_recommendations[:10] # Return top 10 recommendations class BottleneckAnalyzer: """ Specialized analyzer for identifying and categorizing bottlenecks. - + Provides advanced analysis of system bottlenecks with actionable optimization recommendations. """ - + def __init__(self): """Initialize bottleneck analyzer.""" self.analysis_history: List[BottleneckAnalysis] = [] logger.info("Bottleneck analyzer initialized") - - def analyze_cache_performance(self, cache_manager: CacheManager) -> List[BottleneckAnalysis]: + + def analyze_cache_performance( + self, cache_manager: CacheManager + ) -> List[BottleneckAnalysis]: """Analyze cache performance for bottlenecks.""" bottlenecks = [] - + try: metrics = cache_manager.get_metrics() - + # Cache hit rate analysis if metrics.hit_rate < 60.0: # Below target severity = "high" if metrics.hit_rate < 40 else "medium" - - bottlenecks.append(BottleneckAnalysis( - component="Cache Hit Rate", - severity=severity, - impact_percentage=100 - metrics.hit_rate, - description=f"Cache hit rate {metrics.hit_rate:.1f}% below optimal", - recommendations=[ - "Review cache warming strategies", - "Increase cache size if memory allows", - "Optimize cache eviction policies", - "Analyze query patterns for better caching" - ], - metrics={"hit_rate": metrics.hit_rate, "target": 60.0} - )) - + + bottlenecks.append( + BottleneckAnalysis( + component="Cache Hit Rate", + severity=severity, + impact_percentage=100 - metrics.hit_rate, + description=f"Cache hit rate {metrics.hit_rate:.1f}% below optimal", + recommendations=[ + "Review cache warming strategies", + "Increase cache size if memory allows", + "Optimize cache eviction policies", + "Analyze query patterns for better caching", + ], + metrics={"hit_rate": metrics.hit_rate, "target": 60.0}, + ) + ) + # Cache memory utilization if metrics.total_size > cache_manager.max_size_bytes * 0.9: - bottlenecks.append(BottleneckAnalysis( - component="Cache Memory", - severity="medium", - impact_percentage=90.0, - description="Cache approaching memory limit", - recommendations=[ - "Increase cache size allocation", - "Implement more aggressive eviction", - "Review cached content for optimization", - "Consider distributed caching" - ], - metrics={"utilization": metrics.total_size / cache_manager.max_size_bytes} - )) - + bottlenecks.append( + BottleneckAnalysis( + component="Cache Memory", + severity="medium", + impact_percentage=90.0, + description="Cache approaching memory limit", + recommendations=[ + "Increase cache size allocation", + "Implement more aggressive eviction", + "Review cached content for optimization", + "Consider distributed caching", + ], + metrics={ + "utilization": metrics.total_size + / cache_manager.max_size_bytes + }, + ) + ) + except Exception as e: logger.error("Error analyzing cache performance", error=str(e)) - + return bottlenecks - - def analyze_query_patterns(self, recent_queries: List[str]) -> List[BottleneckAnalysis]: + + def analyze_query_patterns( + self, recent_queries: List[str] + ) -> List[BottleneckAnalysis]: """Analyze query patterns for potential bottlenecks.""" bottlenecks = [] - + if not recent_queries: return bottlenecks - + # Analyze query complexity complex_queries = [q for q in recent_queries if len(q) > 200] if len(complex_queries) > len(recent_queries) * 0.3: # >30% complex queries - bottlenecks.append(BottleneckAnalysis( - component="Query Complexity", - severity="medium", - impact_percentage=len(complex_queries) / len(recent_queries) * 100, - description=f"{len(complex_queries)} of {len(recent_queries)} queries are complex", - recommendations=[ - "Break down complex queries into simpler components", - "Implement query preprocessing and optimization", - "Consider query result caching", - "Review query construction patterns" - ], - metrics={"complex_ratio": len(complex_queries) / len(recent_queries)} - )) - + bottlenecks.append( + BottleneckAnalysis( + component="Query Complexity", + severity="medium", + impact_percentage=len(complex_queries) / len(recent_queries) * 100, + description=f"{len(complex_queries)} of {len(recent_queries)} queries are complex", + recommendations=[ + "Break down complex queries into simpler components", + "Implement query preprocessing and optimization", + "Consider query result caching", + "Review query construction patterns", + ], + metrics={ + "complex_ratio": len(complex_queries) / len(recent_queries) + }, + ) + ) + return bottlenecks - - def generate_bottleneck_report(self, - bottlenecks: List[BottleneckAnalysis]) -> Dict[str, Any]: + + def generate_bottleneck_report( + self, bottlenecks: List[BottleneckAnalysis] + ) -> Dict[str, Any]: """Generate comprehensive bottleneck analysis report.""" if not bottlenecks: return { "summary": "No significant bottlenecks detected", "total_bottlenecks": 0, "severity_breakdown": {}, - "recommendations": ["Continue monitoring for performance trends"] + "recommendations": ["Continue monitoring for performance trends"], } - + # Severity breakdown severity_counts = {} for bottleneck in bottlenecks: - severity_counts[bottleneck.severity] = severity_counts.get(bottleneck.severity, 0) + 1 - + severity_counts[bottleneck.severity] = ( + severity_counts.get(bottleneck.severity, 0) + 1 + ) + # Top recommendations all_recommendations = [] for bottleneck in bottlenecks: all_recommendations.extend(bottleneck.recommendations) - + # Remove duplicates and prioritize unique_recommendations = list(dict.fromkeys(all_recommendations)) - + report = { "timestamp": time.time(), "summary": f"Identified {len(bottlenecks)} performance bottlenecks", "total_bottlenecks": len(bottlenecks), "severity_breakdown": severity_counts, - "critical_components": [b.component for b in bottlenecks if b.severity == "critical"], - "high_impact_components": [b.component for b in bottlenecks if b.impact_percentage > 70], + "critical_components": [ + b.component for b in bottlenecks if b.severity == "critical" + ], + "high_impact_components": [ + b.component for b in bottlenecks if b.impact_percentage > 70 + ], "recommendations": unique_recommendations[:10], "detailed_analysis": [ { @@ -659,10 +737,10 @@ def generate_bottleneck_report(self, "severity": b.severity, "impact": b.impact_percentage, "description": b.description, - "top_recommendations": b.recommendations[:3] + "top_recommendations": b.recommendations[:3], } for b in bottlenecks - ] + ], } - - return report \ No newline at end of file + + return report diff --git a/src/benchmarking/visualization.py b/src/benchmarking/visualization.py index 4f641db..a2d9b97 100644 --- a/src/benchmarking/visualization.py +++ b/src/benchmarking/visualization.py @@ -22,11 +22,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from benchmarking.framework import BenchmarkResult, BenchmarkSummary from benchmarking.comparisons import ComparisonResult from benchmarking.profiler import ProfileResult, BottleneckAnalysis @@ -38,6 +39,7 @@ @dataclass class ChartData: """Data structure for chart generation.""" + chart_type: str # "line", "bar", "scatter", "histogram" title: str x_label: str @@ -49,6 +51,7 @@ class ChartData: @dataclass class ReportSection: """Individual section of a benchmark report.""" + section_id: str title: str content_type: str # "text", "table", "chart", "metrics" @@ -59,6 +62,7 @@ class ReportSection: @dataclass class BenchmarkReport: """Complete benchmark report structure.""" + report_id: str title: str generated_at: float @@ -72,49 +76,49 @@ class BenchmarkReport: class BenchmarkVisualizer: """ Comprehensive visualization system for benchmark data. - + Generates charts, tables, and visual reports for performance analysis. """ - + def __init__(self): """Initialize benchmark visualizer.""" self.color_scheme = { - "fact": "#2E8B57", # Sea Green - "rag": "#CD5C5C", # Indian Red + "fact": "#2E8B57", # Sea Green + "rag": "#CD5C5C", # Indian Red "cache_hit": "#32CD32", # Lime Green - "cache_miss": "#FF6347", # Tomato - "warning": "#FFD700", # Gold - "critical": "#DC143C", # Crimson - "success": "#228B22", # Forest Green - "neutral": "#708090" # Slate Gray + "cache_miss": "#FF6347", # Tomato + "warning": "#FFD700", # Gold + "critical": "#DC143C", # Crimson + "success": "#228B22", # Forest Green + "neutral": "#708090", # Slate Gray } - + logger.info("Benchmark visualizer initialized") - - def create_latency_comparison_chart(self, - fact_results: List[BenchmarkResult], - rag_results: List[BenchmarkResult]) -> ChartData: + + def create_latency_comparison_chart( + self, fact_results: List[BenchmarkResult], rag_results: List[BenchmarkResult] + ) -> ChartData: """Create latency comparison chart between FACT and RAG.""" # Extract latency data fact_latencies = [r.response_time_ms for r in fact_results if r.success] rag_latencies = [r.response_time_ms for r in rag_results if r.success] - + # Create comparison data data_series = [ { "name": "FACT System", "data": fact_latencies, "color": self.color_scheme["fact"], - "type": "box_plot" + "type": "box_plot", }, { "name": "Traditional RAG", "data": rag_latencies, - "color": self.color_scheme["rag"], - "type": "box_plot" - } + "color": self.color_scheme["rag"], + "type": "box_plot", + }, ] - + return ChartData( chart_type="box_plot", title="Response Time Comparison: FACT vs Traditional RAG", @@ -122,39 +126,64 @@ def create_latency_comparison_chart(self, y_label="Response Time (ms)", data_series=data_series, metadata={ - "fact_avg": sum(fact_latencies) / len(fact_latencies) if fact_latencies else 0, - "rag_avg": sum(rag_latencies) / len(rag_latencies) if rag_latencies else 0, - "improvement_factor": (sum(rag_latencies) / len(rag_latencies)) / (sum(fact_latencies) / len(fact_latencies)) if fact_latencies and rag_latencies else 0 - } + "fact_avg": ( + sum(fact_latencies) / len(fact_latencies) if fact_latencies else 0 + ), + "rag_avg": ( + sum(rag_latencies) / len(rag_latencies) if rag_latencies else 0 + ), + "improvement_factor": ( + (sum(rag_latencies) / len(rag_latencies)) + / (sum(fact_latencies) / len(fact_latencies)) + if fact_latencies and rag_latencies + else 0 + ), + }, ) - + def create_cache_performance_chart(self, benchmark_summary) -> ChartData: """Create cache hit/miss performance visualization.""" # Handle both BenchmarkSummary and List[BenchmarkResult] - if hasattr(benchmark_summary, 'avg_hit_latency_ms'): + if hasattr(benchmark_summary, "avg_hit_latency_ms"): # BenchmarkSummary object - hit_latencies = [benchmark_summary.avg_hit_latency_ms] if benchmark_summary.cache_hits > 0 else [] - miss_latencies = [benchmark_summary.avg_miss_latency_ms] if benchmark_summary.cache_misses > 0 else [] + hit_latencies = ( + [benchmark_summary.avg_hit_latency_ms] + if benchmark_summary.cache_hits > 0 + else [] + ) + miss_latencies = ( + [benchmark_summary.avg_miss_latency_ms] + if benchmark_summary.cache_misses > 0 + else [] + ) else: # List of BenchmarkResult objects - hit_latencies = [r.response_time_ms for r in benchmark_summary if r.success and r.cache_hit] - miss_latencies = [r.response_time_ms for r in benchmark_summary if r.success and not r.cache_hit] - + hit_latencies = [ + r.response_time_ms + for r in benchmark_summary + if r.success and r.cache_hit + ] + miss_latencies = [ + r.response_time_ms + for r in benchmark_summary + if r.success and not r.cache_hit + ] + data_series = [ { "name": "Cache Hits", "data": hit_latencies, "color": self.color_scheme["cache_hit"], - "type": "histogram" + "type": "histogram", }, { - "name": "Cache Misses", + "name": "Cache Misses", "data": miss_latencies, "color": self.color_scheme["cache_miss"], - "type": "histogram" - } + "type": "histogram", + }, ] - + return ChartData( chart_type="histogram", title="Response Time Distribution: Cache Hits vs Misses", @@ -162,41 +191,55 @@ def create_cache_performance_chart(self, benchmark_summary) -> ChartData: y_label="Frequency", data_series=data_series, metadata={ - "hit_avg": sum(hit_latencies) / len(hit_latencies) if hit_latencies else 0, - "miss_avg": sum(miss_latencies) / len(miss_latencies) if miss_latencies else 0, + "hit_avg": ( + sum(hit_latencies) / len(hit_latencies) if hit_latencies else 0 + ), + "miss_avg": ( + sum(miss_latencies) / len(miss_latencies) if miss_latencies else 0 + ), "hit_count": len(hit_latencies), "miss_count": len(miss_latencies), - "cache_hit_rate": len(hit_latencies) / (len(hit_latencies) + len(miss_latencies)) * 100 if (len(hit_latencies) + len(miss_latencies)) > 0 else 0 - } + "cache_hit_rate": ( + len(hit_latencies) + / (len(hit_latencies) + len(miss_latencies)) + * 100 + if (len(hit_latencies) + len(miss_latencies)) > 0 + else 0 + ), + }, ) - - def create_performance_timeline_chart(self, results: List[BenchmarkResult]) -> ChartData: + + def create_performance_timeline_chart( + self, results: List[BenchmarkResult] + ) -> ChartData: """Create performance over time timeline chart.""" # Sort results by timestamp sorted_results = sorted(results, key=lambda r: r.timestamp) - + # Create time series data timestamps = [r.timestamp for r in sorted_results] response_times = [r.response_time_ms for r in sorted_results] cache_hits = [1 if r.cache_hit else 0 for r in sorted_results] - + data_series = [ { "name": "Response Time", "data": list(zip(timestamps, response_times)), "color": self.color_scheme["fact"], "type": "line", - "y_axis": "left" + "y_axis": "left", }, { "name": "Cache Hit Rate", - "data": list(zip(timestamps, self._calculate_rolling_cache_hit_rate(cache_hits))), + "data": list( + zip(timestamps, self._calculate_rolling_cache_hit_rate(cache_hits)) + ), "color": self.color_scheme["cache_hit"], - "type": "line", - "y_axis": "right" - } + "type": "line", + "y_axis": "right", + }, ] - + return ChartData( chart_type="line", title="Performance Timeline", @@ -205,29 +248,41 @@ def create_performance_timeline_chart(self, results: List[BenchmarkResult]) -> C data_series=data_series, metadata={ "total_measurements": len(results), - "time_span_hours": (max(timestamps) - min(timestamps)) / 3600 if timestamps else 0 - } + "time_span_hours": ( + (max(timestamps) - min(timestamps)) / 3600 if timestamps else 0 + ), + }, ) - - def create_cost_comparison_chart(self, comparison_result: ComparisonResult) -> ChartData: + + def create_cost_comparison_chart( + self, comparison_result: ComparisonResult + ) -> ChartData: """Create cost comparison visualization.""" fact_cost = comparison_result.fact_metrics.avg_token_cost rag_cost = comparison_result.rag_metrics.avg_token_cost - + data_series = [ { "name": "Cost Comparison", "data": [ - {"category": "FACT System", "value": fact_cost, "color": self.color_scheme["fact"]}, - {"category": "Traditional RAG", "value": rag_cost, "color": self.color_scheme["rag"]} + { + "category": "FACT System", + "value": fact_cost, + "color": self.color_scheme["fact"], + }, + { + "category": "Traditional RAG", + "value": rag_cost, + "color": self.color_scheme["rag"], + }, ], - "type": "bar" + "type": "bar", } ] - + savings = rag_cost - fact_cost if rag_cost > fact_cost else 0 savings_percentage = (savings / rag_cost * 100) if rag_cost > 0 else 0 - + return ChartData( chart_type="bar", title="Token Cost Comparison per Query", @@ -238,11 +293,13 @@ def create_cost_comparison_chart(self, comparison_result: ComparisonResult) -> C "fact_cost": fact_cost, "rag_cost": rag_cost, "absolute_savings": savings, - "percentage_savings": savings_percentage - } + "percentage_savings": savings_percentage, + }, ) - - def create_bottleneck_analysis_chart(self, bottlenecks: List[BottleneckAnalysis]) -> ChartData: + + def create_bottleneck_analysis_chart( + self, bottlenecks: List[BottleneckAnalysis] + ) -> ChartData: """Create bottleneck analysis visualization.""" if not bottlenecks: return ChartData( @@ -251,20 +308,22 @@ def create_bottleneck_analysis_chart(self, bottlenecks: List[BottleneckAnalysis] x_label="Components", y_label="Impact (%)", data_series=[], - metadata={"status": "healthy"} + metadata={"status": "healthy"}, ) - + # Sort bottlenecks by impact - sorted_bottlenecks = sorted(bottlenecks, key=lambda b: b.impact_percentage, reverse=True) - + sorted_bottlenecks = sorted( + bottlenecks, key=lambda b: b.impact_percentage, reverse=True + ) + # Create severity color mapping severity_colors = { "critical": self.color_scheme["critical"], "high": "#FF8C00", # Dark Orange "medium": self.color_scheme["warning"], - "low": "#90EE90" # Light Green + "low": "#90EE90", # Light Green } - + data_series = [ { "name": "Bottleneck Impact", @@ -272,15 +331,17 @@ def create_bottleneck_analysis_chart(self, bottlenecks: List[BottleneckAnalysis] { "category": b.component, "value": b.impact_percentage, - "color": severity_colors.get(b.severity, self.color_scheme["neutral"]), - "severity": b.severity + "color": severity_colors.get( + b.severity, self.color_scheme["neutral"] + ), + "severity": b.severity, } for b in sorted_bottlenecks[:10] # Top 10 bottlenecks ], - "type": "bar" + "type": "bar", } ] - + return ChartData( chart_type="bar", title="Performance Bottleneck Analysis", @@ -289,23 +350,27 @@ def create_bottleneck_analysis_chart(self, bottlenecks: List[BottleneckAnalysis] data_series=data_series, metadata={ "total_bottlenecks": len(bottlenecks), - "critical_count": sum(1 for b in bottlenecks if b.severity == "critical"), - "high_count": sum(1 for b in bottlenecks if b.severity == "high") - } + "critical_count": sum( + 1 for b in bottlenecks if b.severity == "critical" + ), + "high_count": sum(1 for b in bottlenecks if b.severity == "high"), + }, ) - - def _calculate_rolling_cache_hit_rate(self, cache_hits: List[int], window_size: int = 10) -> List[float]: + + def _calculate_rolling_cache_hit_rate( + self, cache_hits: List[int], window_size: int = 10 + ) -> List[float]: """Calculate rolling cache hit rate.""" rolling_rates = [] - + for i in range(len(cache_hits)): start_idx = max(0, i - window_size + 1) - window_data = cache_hits[start_idx:i + 1] + window_data = cache_hits[start_idx : i + 1] hit_rate = sum(window_data) / len(window_data) * 100 rolling_rates.append(hit_rate) - + return rolling_rates - + def export_chart_data_json(self, chart: ChartData) -> str: """Export chart data as JSON for external visualization tools.""" return json.dumps(asdict(chart), indent=2, default=str) @@ -314,76 +379,94 @@ def export_chart_data_json(self, chart: ChartData) -> str: class ReportGenerator: """ Comprehensive report generation system. - + Creates detailed HTML, JSON, and text reports from benchmark data. """ - + def __init__(self, visualizer: Optional[BenchmarkVisualizer] = None): """ Initialize report generator. - + Args: visualizer: Benchmark visualizer instance """ self.visualizer = visualizer or BenchmarkVisualizer() - + logger.info("Report generator initialized") - - def generate_comprehensive_report(self, - benchmark_summary: BenchmarkSummary, - comparison_result: Optional[ComparisonResult] = None, - profile_result: Optional[ProfileResult] = None, - alerts: Optional[List[PerformanceAlert]] = None) -> BenchmarkReport: + + def generate_comprehensive_report( + self, + benchmark_summary: BenchmarkSummary, + comparison_result: Optional[ComparisonResult] = None, + profile_result: Optional[ProfileResult] = None, + alerts: Optional[List[PerformanceAlert]] = None, + ) -> BenchmarkReport: """ Generate comprehensive benchmark report. - + Args: benchmark_summary: Main benchmark results comparison_result: Optional comparison with RAG profile_result: Optional profiling results alerts: Optional performance alerts - + Returns: Complete benchmark report """ report_id = f"fact_benchmark_{int(time.time())}" - + # Generate report sections sections = [] charts = [] - + # Executive Summary - sections.append(self._create_executive_summary_section(benchmark_summary, comparison_result)) - + sections.append( + self._create_executive_summary_section(benchmark_summary, comparison_result) + ) + # Performance Metrics Section sections.append(self._create_performance_metrics_section(benchmark_summary)) - + # Target Compliance Section sections.append(self._create_target_compliance_section(benchmark_summary)) - + # Comparison Analysis (if available) if comparison_result: sections.append(self._create_comparison_analysis_section(comparison_result)) - charts.append(self.visualizer.create_cost_comparison_chart(comparison_result)) - + charts.append( + self.visualizer.create_cost_comparison_chart(comparison_result) + ) + # Bottleneck Analysis (if available) if profile_result: sections.append(self._create_bottleneck_analysis_section(profile_result)) - charts.append(self.visualizer.create_bottleneck_analysis_chart(profile_result.bottlenecks)) - + charts.append( + self.visualizer.create_bottleneck_analysis_chart( + profile_result.bottlenecks + ) + ) + # Alerts Section (if available) if alerts: sections.append(self._create_alerts_section(alerts)) - + # Recommendations Section - sections.append(self._create_recommendations_section(benchmark_summary, comparison_result, profile_result)) - + sections.append( + self._create_recommendations_section( + benchmark_summary, comparison_result, profile_result + ) + ) + # Generate overall recommendations - recommendations = self._generate_overall_recommendations(benchmark_summary, comparison_result, profile_result, alerts) - + recommendations = self._generate_overall_recommendations( + benchmark_summary, comparison_result, profile_result, alerts + ) + # Create summary - summary = self._create_report_summary(benchmark_summary, comparison_result, profile_result) - + summary = self._create_report_summary( + benchmark_summary, comparison_result, profile_result + ) + return BenchmarkReport( report_id=report_id, title="FACT System Performance Benchmark Report", @@ -394,25 +477,39 @@ def generate_comprehensive_report(self, recommendations=recommendations, raw_data={ "benchmark_summary": asdict(benchmark_summary), - "comparison_result": asdict(comparison_result) if comparison_result else None, + "comparison_result": ( + asdict(comparison_result) if comparison_result else None + ), "profile_result": asdict(profile_result) if profile_result else None, - "alerts": [asdict(alert) for alert in alerts] if alerts else None - } + "alerts": [asdict(alert) for alert in alerts] if alerts else None, + }, ) - - def _create_executive_summary_section(self, - benchmark_summary: BenchmarkSummary, - comparison_result: Optional[ComparisonResult]) -> ReportSection: + + def _create_executive_summary_section( + self, + benchmark_summary: BenchmarkSummary, + comparison_result: Optional[ComparisonResult], + ) -> ReportSection: """Create executive summary section.""" # Calculate key metrics - overall_performance = "EXCELLENT" if benchmark_summary.avg_response_time_ms < 50 else \ - "GOOD" if benchmark_summary.avg_response_time_ms < 100 else \ - "NEEDS_IMPROVEMENT" - - cache_performance = "EXCELLENT" if benchmark_summary.cache_hit_rate > 70 else \ - "GOOD" if benchmark_summary.cache_hit_rate > 50 else \ - "NEEDS_IMPROVEMENT" - + overall_performance = ( + "EXCELLENT" + if benchmark_summary.avg_response_time_ms < 50 + else ( + "GOOD" + if benchmark_summary.avg_response_time_ms < 100 + else "NEEDS_IMPROVEMENT" + ) + ) + + cache_performance = ( + "EXCELLENT" + if benchmark_summary.cache_hit_rate > 70 + else ( + "GOOD" if benchmark_summary.cache_hit_rate > 50 else "NEEDS_IMPROVEMENT" + ) + ) + content = { "overall_assessment": overall_performance, "cache_assessment": cache_performance, @@ -420,32 +517,38 @@ def _create_executive_summary_section(self, "avg_response_time_ms": benchmark_summary.avg_response_time_ms, "cache_hit_rate": benchmark_summary.cache_hit_rate, "error_rate": benchmark_summary.error_rate, - "throughput_qps": benchmark_summary.throughput_qps + "throughput_qps": benchmark_summary.throughput_qps, }, "target_compliance": { "hit_latency_target_met": benchmark_summary.hit_latency_target_met, "miss_latency_target_met": benchmark_summary.miss_latency_target_met, "cost_reduction_target_met": benchmark_summary.cost_reduction_target_met, - "cache_hit_rate_target_met": benchmark_summary.cache_hit_rate_target_met - } + "cache_hit_rate_target_met": benchmark_summary.cache_hit_rate_target_met, + }, } - + if comparison_result: content["improvement_summary"] = { - "latency_improvement": comparison_result.improvement_factors.get("latency", 1.0), - "cost_savings_percentage": comparison_result.cost_savings.get("percentage", 0.0), - "recommendation": comparison_result.recommendation + "latency_improvement": comparison_result.improvement_factors.get( + "latency", 1.0 + ), + "cost_savings_percentage": comparison_result.cost_savings.get( + "percentage", 0.0 + ), + "recommendation": comparison_result.recommendation, } - + return ReportSection( section_id="executive_summary", title="Executive Summary", content_type="metrics", content=content, - priority=1 + priority=1, ) - - def _create_performance_metrics_section(self, benchmark_summary: BenchmarkSummary) -> ReportSection: + + def _create_performance_metrics_section( + self, benchmark_summary: BenchmarkSummary + ) -> ReportSection: """Create detailed performance metrics section.""" content = { "latency_metrics": { @@ -454,84 +557,106 @@ def _create_performance_metrics_section(self, benchmark_summary: BenchmarkSummar "maximum_ms": benchmark_summary.max_response_time_ms, "p50_ms": benchmark_summary.p50_response_time_ms, "p95_ms": benchmark_summary.p95_response_time_ms, - "p99_ms": benchmark_summary.p99_response_time_ms + "p99_ms": benchmark_summary.p99_response_time_ms, }, "cache_metrics": { "hit_rate_percent": benchmark_summary.cache_hit_rate, "hit_count": benchmark_summary.cache_hits, "miss_count": benchmark_summary.cache_misses, "avg_hit_latency_ms": benchmark_summary.avg_hit_latency_ms, - "avg_miss_latency_ms": benchmark_summary.avg_miss_latency_ms + "avg_miss_latency_ms": benchmark_summary.avg_miss_latency_ms, }, "reliability_metrics": { - "success_rate_percent": (benchmark_summary.successful_queries / benchmark_summary.total_queries * 100) if benchmark_summary.total_queries > 0 else 0, + "success_rate_percent": ( + ( + benchmark_summary.successful_queries + / benchmark_summary.total_queries + * 100 + ) + if benchmark_summary.total_queries > 0 + else 0 + ), "error_rate_percent": benchmark_summary.error_rate, "total_queries": benchmark_summary.total_queries, - "failed_queries": benchmark_summary.failed_queries + "failed_queries": benchmark_summary.failed_queries, }, "cost_metrics": { "total_cost_usd": benchmark_summary.total_token_cost, "avg_cost_per_query_usd": benchmark_summary.avg_token_cost, "estimated_savings_usd": benchmark_summary.estimated_savings, - "cost_reduction_percentage": benchmark_summary.cost_reduction_percentage - } + "cost_reduction_percentage": benchmark_summary.cost_reduction_percentage, + }, } - + return ReportSection( section_id="performance_metrics", title="Detailed Performance Metrics", content_type="table", content=content, - priority=1 + priority=1, ) - - def _create_target_compliance_section(self, benchmark_summary: BenchmarkSummary) -> ReportSection: + + def _create_target_compliance_section( + self, benchmark_summary: BenchmarkSummary + ) -> ReportSection: """Create target compliance analysis section.""" targets = { "cache_hit_latency": { "target_ms": 48.0, "actual_ms": benchmark_summary.avg_hit_latency_ms, "met": benchmark_summary.hit_latency_target_met, - "status": "PASS" if benchmark_summary.hit_latency_target_met else "FAIL" + "status": ( + "PASS" if benchmark_summary.hit_latency_target_met else "FAIL" + ), }, "cache_miss_latency": { "target_ms": 140.0, "actual_ms": benchmark_summary.avg_miss_latency_ms, "met": benchmark_summary.miss_latency_target_met, - "status": "PASS" if benchmark_summary.miss_latency_target_met else "FAIL" + "status": ( + "PASS" if benchmark_summary.miss_latency_target_met else "FAIL" + ), }, "cost_reduction": { "target_percent": 75.0, "actual_percent": benchmark_summary.cost_reduction_percentage, "met": benchmark_summary.cost_reduction_target_met, - "status": "PASS" if benchmark_summary.cost_reduction_target_met else "FAIL" + "status": ( + "PASS" if benchmark_summary.cost_reduction_target_met else "FAIL" + ), }, "cache_hit_rate": { "target_percent": 60.0, "actual_percent": benchmark_summary.cache_hit_rate, "met": benchmark_summary.cache_hit_rate_target_met, - "status": "PASS" if benchmark_summary.cache_hit_rate_target_met else "FAIL" - } + "status": ( + "PASS" if benchmark_summary.cache_hit_rate_target_met else "FAIL" + ), + }, } - + overall_compliance = all(target["met"] for target in targets.values()) - + content = { "overall_compliance": overall_compliance, "overall_status": "PASS" if overall_compliance else "FAIL", "target_details": targets, - "compliance_score": sum(1 for target in targets.values() if target["met"]) / len(targets) * 100 + "compliance_score": sum(1 for target in targets.values() if target["met"]) + / len(targets) + * 100, } - + return ReportSection( section_id="target_compliance", title="Performance Target Compliance", content_type="table", content=content, - priority=1 + priority=1, ) - - def _create_comparison_analysis_section(self, comparison_result: ComparisonResult) -> ReportSection: + + def _create_comparison_analysis_section( + self, comparison_result: ComparisonResult + ) -> ReportSection: """Create RAG comparison analysis section.""" content = { "performance_improvements": comparison_result.improvement_factors, @@ -539,121 +664,156 @@ def _create_comparison_analysis_section(self, comparison_result: ComparisonResul "fact_metrics": asdict(comparison_result.fact_metrics), "rag_metrics": asdict(comparison_result.rag_metrics), "detailed_analysis": comparison_result.performance_analysis, - "recommendation": comparison_result.recommendation + "recommendation": comparison_result.recommendation, } - + return ReportSection( section_id="comparison_analysis", title="FACT vs Traditional RAG Comparison", content_type="table", content=content, - priority=1 + priority=1, ) - - def _create_bottleneck_analysis_section(self, profile_result: ProfileResult) -> ReportSection: + + def _create_bottleneck_analysis_section( + self, profile_result: ProfileResult + ) -> ReportSection: """Create bottleneck analysis section.""" - critical_bottlenecks = [b for b in profile_result.bottlenecks if b.severity == "critical"] - high_bottlenecks = [b for b in profile_result.bottlenecks if b.severity == "high"] - + critical_bottlenecks = [ + b for b in profile_result.bottlenecks if b.severity == "critical" + ] + high_bottlenecks = [ + b for b in profile_result.bottlenecks if b.severity == "high" + ] + content = { "total_bottlenecks": len(profile_result.bottlenecks), "critical_count": len(critical_bottlenecks), "high_count": len(high_bottlenecks), "bottleneck_details": [asdict(b) for b in profile_result.bottlenecks], "performance_summary": profile_result.performance_summary, - "optimization_recommendations": profile_result.optimization_recommendations + "optimization_recommendations": profile_result.optimization_recommendations, } - + return ReportSection( section_id="bottleneck_analysis", title="Performance Bottleneck Analysis", content_type="table", content=content, - priority=2 + priority=2, ) - + def _create_alerts_section(self, alerts: List[PerformanceAlert]) -> ReportSection: """Create performance alerts section.""" critical_alerts = [a for a in alerts if a.severity == "critical"] warning_alerts = [a for a in alerts if a.severity == "warning"] - + content = { "total_alerts": len(alerts), "critical_count": len(critical_alerts), "warning_count": len(warning_alerts), "alert_details": [asdict(alert) for alert in alerts], - "urgent_actions_required": len(critical_alerts) > 0 + "urgent_actions_required": len(critical_alerts) > 0, } - + return ReportSection( section_id="performance_alerts", title="Performance Alerts", content_type="table", content=content, - priority=1 if critical_alerts else 2 + priority=1 if critical_alerts else 2, ) - - def _create_recommendations_section(self, - benchmark_summary: BenchmarkSummary, - comparison_result: Optional[ComparisonResult], - profile_result: Optional[ProfileResult]) -> ReportSection: + + def _create_recommendations_section( + self, + benchmark_summary: BenchmarkSummary, + comparison_result: Optional[ComparisonResult], + profile_result: Optional[ProfileResult], + ) -> ReportSection: """Create recommendations section.""" recommendations = [] - + # Performance-based recommendations if benchmark_summary.avg_response_time_ms > 100: - recommendations.append("Investigate high response times - consider caching optimization") - + recommendations.append( + "Investigate high response times - consider caching optimization" + ) + if benchmark_summary.cache_hit_rate < 60: - recommendations.append("Improve cache hit rate through better warming strategies") - + recommendations.append( + "Improve cache hit rate through better warming strategies" + ) + if benchmark_summary.error_rate > 5: - recommendations.append("Address system reliability issues - high error rate detected") - + recommendations.append( + "Address system reliability issues - high error rate detected" + ) + # Comparison-based recommendations - if comparison_result and comparison_result.improvement_factors.get("latency", 1) < 2: - recommendations.append("Consider additional optimization - latency improvement below target") - + if ( + comparison_result + and comparison_result.improvement_factors.get("latency", 1) < 2 + ): + recommendations.append( + "Consider additional optimization - latency improvement below target" + ) + # Bottleneck-based recommendations if profile_result: recommendations.extend(profile_result.optimization_recommendations[:3]) - + content = { "priority_recommendations": recommendations[:5], "all_recommendations": recommendations, - "implementation_priority": "HIGH" if benchmark_summary.error_rate > 10 else "MEDIUM" + "implementation_priority": ( + "HIGH" if benchmark_summary.error_rate > 10 else "MEDIUM" + ), } - + return ReportSection( section_id="recommendations", title="Optimization Recommendations", content_type="text", content=content, - priority=1 + priority=1, ) - - def _create_report_summary(self, - benchmark_summary: BenchmarkSummary, - comparison_result: Optional[ComparisonResult], - profile_result: Optional[ProfileResult]) -> Dict[str, Any]: + + def _create_report_summary( + self, + benchmark_summary: BenchmarkSummary, + comparison_result: Optional[ComparisonResult], + profile_result: Optional[ProfileResult], + ) -> Dict[str, Any]: """Create overall report summary.""" summary = { "benchmark_execution": { "total_queries": benchmark_summary.total_queries, "execution_time_seconds": benchmark_summary.execution_time_seconds, - "success_rate": (benchmark_summary.successful_queries / benchmark_summary.total_queries * 100) if benchmark_summary.total_queries > 0 else 0 + "success_rate": ( + ( + benchmark_summary.successful_queries + / benchmark_summary.total_queries + * 100 + ) + if benchmark_summary.total_queries > 0 + else 0 + ), }, "performance_grade": self._calculate_performance_grade(benchmark_summary), - "key_findings": self._extract_key_findings(benchmark_summary, comparison_result, profile_result), - "action_required": self._determine_action_required(benchmark_summary, profile_result) + "key_findings": self._extract_key_findings( + benchmark_summary, comparison_result, profile_result + ), + "action_required": self._determine_action_required( + benchmark_summary, profile_result + ), } - + return summary - + def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> str: """Calculate overall performance grade.""" score = 0 - + # Latency score (40% weight) if benchmark_summary.avg_response_time_ms <= 50: score += 40 @@ -663,7 +823,7 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s score += 20 else: score += 10 - + # Cache hit rate score (30% weight) if benchmark_summary.cache_hit_rate >= 70: score += 30 @@ -673,7 +833,7 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s score += 20 else: score += 10 - + # Error rate score (20% weight) if benchmark_summary.error_rate <= 1: score += 20 @@ -683,7 +843,7 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s score += 10 else: score += 5 - + # Cost efficiency score (10% weight) if benchmark_summary.cost_reduction_percentage >= 80: score += 10 @@ -693,7 +853,7 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s score += 6 else: score += 3 - + # Convert to grade if score >= 90: return "A+" @@ -709,14 +869,16 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s return "C" else: return "D" - - def _extract_key_findings(self, - benchmark_summary: BenchmarkSummary, - comparison_result: Optional[ComparisonResult], - profile_result: Optional[ProfileResult]) -> List[str]: + + def _extract_key_findings( + self, + benchmark_summary: BenchmarkSummary, + comparison_result: Optional[ComparisonResult], + profile_result: Optional[ProfileResult], + ) -> List[str]: """Extract key findings from benchmark results.""" findings = [] - + # Performance findings if benchmark_summary.avg_response_time_ms <= 48: findings.append("Excellent response time performance achieved") @@ -724,7 +886,7 @@ def _extract_key_findings(self, findings.append("Good response time performance within targets") else: findings.append("Response time performance needs improvement") - + # Cache findings if benchmark_summary.cache_hit_rate >= 70: findings.append("Cache performance excellent with high hit rates") @@ -732,100 +894,127 @@ def _extract_key_findings(self, findings.append("Cache performance good, meeting minimum targets") else: findings.append("Cache performance below optimal, requires attention") - + # Cost findings if benchmark_summary.cost_reduction_percentage >= 80: - findings.append("Significant cost savings achieved compared to traditional methods") + findings.append( + "Significant cost savings achieved compared to traditional methods" + ) elif benchmark_summary.cost_reduction_percentage >= 60: findings.append("Good cost efficiency with substantial savings") else: findings.append("Cost efficiency below expectations") - + # Comparison findings if comparison_result: - latency_improvement = comparison_result.improvement_factors.get("latency", 1.0) + latency_improvement = comparison_result.improvement_factors.get( + "latency", 1.0 + ) if latency_improvement >= 4.0: findings.append("Outstanding latency improvement over traditional RAG") elif latency_improvement >= 2.0: findings.append("Significant latency improvement demonstrated") else: findings.append("Latency improvement present but below expectations") - + return findings[:5] # Return top 5 findings - - def _determine_action_required(self, - benchmark_summary: BenchmarkSummary, - profile_result: Optional[ProfileResult]) -> str: + + def _determine_action_required( + self, + benchmark_summary: BenchmarkSummary, + profile_result: Optional[ProfileResult], + ) -> str: """Determine required actions based on results.""" critical_issues = [] - + if benchmark_summary.error_rate > 10: critical_issues.append("High error rate") - + if benchmark_summary.avg_response_time_ms > 200: critical_issues.append("Poor response times") - + if benchmark_summary.cache_hit_rate < 40: critical_issues.append("Low cache efficiency") - + if profile_result: - critical_bottlenecks = [b for b in profile_result.bottlenecks if b.severity == "critical"] + critical_bottlenecks = [ + b for b in profile_result.bottlenecks if b.severity == "critical" + ] if critical_bottlenecks: critical_issues.append("Critical performance bottlenecks") - + if critical_issues: return f"IMMEDIATE ACTION REQUIRED: {', '.join(critical_issues)}" - elif benchmark_summary.avg_response_time_ms > 100 or benchmark_summary.cache_hit_rate < 60: + elif ( + benchmark_summary.avg_response_time_ms > 100 + or benchmark_summary.cache_hit_rate < 60 + ): return "OPTIMIZATION RECOMMENDED: Performance improvements needed" else: return "MONITORING: System performing within acceptable parameters" - - def _generate_overall_recommendations(self, - benchmark_summary: BenchmarkSummary, - comparison_result: Optional[ComparisonResult], - profile_result: Optional[ProfileResult], - alerts: Optional[List[PerformanceAlert]]) -> List[str]: + + def _generate_overall_recommendations( + self, + benchmark_summary: BenchmarkSummary, + comparison_result: Optional[ComparisonResult], + profile_result: Optional[ProfileResult], + alerts: Optional[List[PerformanceAlert]], + ) -> List[str]: """Generate overall optimization recommendations.""" recommendations = [] - + # Critical alerts take priority if alerts: critical_alerts = [a for a in alerts if a.severity == "critical"] if critical_alerts: - recommendations.append("Address critical performance alerts immediately") - + recommendations.append( + "Address critical performance alerts immediately" + ) + # Performance-based recommendations if benchmark_summary.avg_response_time_ms > 100: - recommendations.append("Optimize response time through caching and query optimization") - + recommendations.append( + "Optimize response time through caching and query optimization" + ) + if benchmark_summary.cache_hit_rate < 60: - recommendations.append("Implement cache warming and improve hit rate strategies") - + recommendations.append( + "Implement cache warming and improve hit rate strategies" + ) + if benchmark_summary.error_rate > 5: recommendations.append("Investigate and resolve system reliability issues") - + # Bottleneck recommendations if profile_result and profile_result.bottlenecks: - top_bottleneck = max(profile_result.bottlenecks, key=lambda b: b.impact_percentage) - recommendations.append(f"Address {top_bottleneck.component.lower()} bottleneck") - + top_bottleneck = max( + profile_result.bottlenecks, key=lambda b: b.impact_percentage + ) + recommendations.append( + f"Address {top_bottleneck.component.lower()} bottleneck" + ) + # Cost optimization if benchmark_summary.cost_reduction_percentage < 70: - recommendations.append("Improve cost efficiency through better caching strategies") - + recommendations.append( + "Improve cost efficiency through better caching strategies" + ) + # General recommendations - recommendations.extend([ - "Implement continuous performance monitoring", - "Establish performance regression testing", - "Document optimization strategies and results" - ]) - + recommendations.extend( + [ + "Implement continuous performance monitoring", + "Establish performance regression testing", + "Document optimization strategies and results", + ] + ) + return recommendations[:8] # Return top 8 recommendations - + def export_report_json(self, report: BenchmarkReport) -> str: """Export report as JSON.""" return json.dumps(asdict(report), indent=2, default=str) - + def export_report_summary_text(self, report: BenchmarkReport) -> str: """Export report summary as text.""" lines = [] @@ -835,19 +1024,25 @@ def export_report_summary_text(self, report: BenchmarkReport) -> str: lines.append(f"Generated: {datetime.fromtimestamp(report.generated_at)}") lines.append(f"Report ID: {report.report_id}") lines.append("") - + # Executive Summary - exec_summary = next((s for s in report.sections if s.section_id == "executive_summary"), None) + exec_summary = next( + (s for s in report.sections if s.section_id == "executive_summary"), None + ) if exec_summary: lines.append("EXECUTIVE SUMMARY") lines.append("-" * 20) content = exec_summary.content lines.append(f"Overall Performance: {content['overall_assessment']}") lines.append(f"Cache Performance: {content['cache_assessment']}") - lines.append(f"Average Response Time: {content['key_metrics']['avg_response_time_ms']:.1f}ms") - lines.append(f"Cache Hit Rate: {content['key_metrics']['cache_hit_rate']:.1f}%") + lines.append( + f"Average Response Time: {content['key_metrics']['avg_response_time_ms']:.1f}ms" + ) + lines.append( + f"Cache Hit Rate: {content['key_metrics']['cache_hit_rate']:.1f}%" + ) lines.append("") - + # Key Recommendations if report.recommendations: lines.append("KEY RECOMMENDATIONS") @@ -855,10 +1050,10 @@ def export_report_summary_text(self, report: BenchmarkReport) -> str: for i, rec in enumerate(report.recommendations[:5], 1): lines.append(f"{i}. {rec}") lines.append("") - + # Performance Grade if "performance_grade" in report.summary: lines.append(f"PERFORMANCE GRADE: {report.summary['performance_grade']}") lines.append("") - - return "\n".join(lines) \ No newline at end of file + + return "\n".join(lines) diff --git a/src/cache/validation.py b/src/cache/validation.py index 0e1796c..7cd9590 100644 --- a/src/cache/validation.py +++ b/src/cache/validation.py @@ -582,7 +582,7 @@ async def auto_repair_cache(self, validation_result: ValidationResult) -> Dict[s repair_summary["warnings_addressed"] += 1 logger.info("Auto-repair completed", **repair_summary) - return freed_space + return repair_summary except Exception as e: logger.error("Auto-repair failed", error=str(e)) diff --git a/src/core/agentic_flow.py b/src/core/agentic_flow.py index 1ffc727..eaca05a 100644 --- a/src/core/agentic_flow.py +++ b/src/core/agentic_flow.py @@ -6,7 +6,7 @@ The FACT algorithm combines: - Fast cache-first query resolution -- Access pattern optimization +- Access pattern optimization - Caching strategy adaptation - Token-efficient LLM interactions """ @@ -32,6 +32,7 @@ @dataclass class FACTQueryContext: """Context for FACT query processing.""" + query_id: str user_query: str query_hash: str @@ -39,7 +40,7 @@ class FACTQueryContext: cache_mode: str = "read" priority: float = 1.0 metadata: Dict[str, Any] = None - + def __post_init__(self): if self.metadata is None: self.metadata = {} @@ -48,6 +49,7 @@ def __post_init__(self): @dataclass class FACTResponse: """FACT algorithm response with performance metrics.""" + content: str cache_hit: bool execution_time_ms: float @@ -57,7 +59,7 @@ class FACTResponse: strategy_used: str query_context: FACTQueryContext performance_metrics: Dict[str, Any] = None - + def __post_init__(self): if self.performance_metrics is None: self.performance_metrics = {} @@ -66,22 +68,22 @@ def __post_init__(self): class FACTAlgorithm: """ Core FACT Algorithm Implementation - + Implements the Fast Access Caching Technology algorithm that optimizes LLM interactions through intelligent caching, access pattern analysis, and adaptive strategy selection. - + Key Components: 1. Cache-First Query Resolution - 2. Token-Efficient Content Processing + 2. Token-Efficient Content Processing 3. Adaptive Caching Strategies 4. Performance Monitoring and Optimization """ - + def __init__(self, config: Optional[Config] = None): """ Initialize FACT algorithm with configuration. - + Args: config: Optional configuration instance """ @@ -91,20 +93,20 @@ def __init__(self, config: Optional[Config] = None): self.metrics_collector = get_metrics_collector() self.tool_registry = get_tool_registry() self._initialized = False - + # FACT algorithm parameters self.cache_hit_target_ms = 50 # Target latency for cache hits - self.cache_miss_target_ms = 140 # Target latency for cache misses + self.cache_miss_target_ms = 140 # Target latency for cache misses self.min_cache_tokens = 500 # Minimum tokens for caching self.token_efficiency_target = 100.0 # Tokens per KB target - + logger.info("FACT algorithm initialized") - + async def initialize(self) -> None: """Initialize FACT algorithm components.""" if self._initialized: return - + try: # Initialize cache manager with FACT-specific configuration cache_config = { @@ -113,132 +115,144 @@ async def initialize(self) -> None: "max_size": "10MB", "ttl_seconds": 3600, # 1 hour TTL "hit_target_ms": self.cache_hit_target_ms, - "miss_target_ms": self.cache_miss_target_ms + "miss_target_ms": self.cache_miss_target_ms, } - + self.cache_manager = get_cache_manager(cache_config) - + # Initialize cache optimizer with adaptive strategy self.cache_optimizer = get_cache_optimizer(CacheStrategy.ADAPTIVE) - + # Start background optimization asyncio.create_task(self._run_background_optimization()) - + self._initialized = True logger.info("FACT algorithm initialization completed") - + except Exception as e: logger.error("FACT algorithm initialization failed", error=str(e)) raise FACTError(f"FACT initialization failed: {e}") - - async def process_query(self, user_query: str, context: Optional[Dict[str, Any]] = None) -> FACTResponse: + + async def process_query( + self, user_query: str, context: Optional[Dict[str, Any]] = None + ) -> FACTResponse: """ Process query using FACT algorithm. - + This is the main entry point for the FACT algorithm, implementing: 1. Query normalization and hashing 2. Cache-first lookup with performance tracking 3. LLM fallback with tool integration 4. Response caching with strategy optimization 5. Performance metrics collection - + Args: user_query: User's natural language query context: Optional context for query processing - + Returns: FACTResponse with content and performance metrics """ if not self._initialized: await self.initialize() - + start_time = time.perf_counter() - + # Create query context query_context = self._create_query_context(user_query, context) - + try: - logger.info("FACT query processing started", - query_id=query_context.query_id, - query_hash=query_context.query_hash[:16]) - + logger.info( + "FACT query processing started", + query_id=query_context.query_id, + query_hash=query_context.query_hash[:16], + ) + # Phase 1: Cache-first lookup cached_response = await self._cache_lookup_phase(query_context) if cached_response: return cached_response - + # Phase 2: LLM processing with tool integration llm_response = await self._llm_processing_phase(query_context) - + # Phase 3: Cache storage and optimization await self._cache_storage_phase(query_context, llm_response) - + # Record performance metrics execution_time = (time.perf_counter() - start_time) * 1000 self._record_query_metrics(query_context, llm_response, execution_time) - + return llm_response - + except Exception as e: execution_time = (time.perf_counter() - start_time) * 1000 - logger.error("FACT query processing failed", - query_id=query_context.query_id, - error=str(e), - execution_time_ms=execution_time) - + logger.error( + "FACT query processing failed", + query_id=query_context.query_id, + error=str(e), + execution_time_ms=execution_time, + ) + # Record failure metrics self.metrics_collector.record_tool_execution( tool_name="fact_algorithm", success=False, execution_time=execution_time, error_type=type(e).__name__, - metadata={"query_id": query_context.query_id} + metadata={"query_id": query_context.query_id}, ) - + raise FACTError(f"FACT query processing failed: {e}") - - def _create_query_context(self, user_query: str, context: Optional[Dict[str, Any]]) -> FACTQueryContext: + + def _create_query_context( + self, user_query: str, context: Optional[Dict[str, Any]] + ) -> FACTQueryContext: """Create query context with normalization and hashing.""" # Normalize query for consistent hashing normalized_query = user_query.strip().lower() - + # Generate deterministic hash query_hash = self.cache_manager.generate_hash(normalized_query) - + # Create unique query ID query_id = f"fact_{int(time.time() * 1000)}_{query_hash[:8]}" - + return FACTQueryContext( query_id=query_id, user_query=user_query, query_hash=query_hash, timestamp=time.time(), - metadata=context or {} + metadata=context or {}, ) - - async def _cache_lookup_phase(self, query_context: FACTQueryContext) -> Optional[FACTResponse]: + + async def _cache_lookup_phase( + self, query_context: FACTQueryContext + ) -> Optional[FACTResponse]: """ Phase 1: Cache-first lookup with performance tracking. - + Implements fast cache lookup with: - Performance monitoring - Cache hit optimization - Access pattern tracking """ cache_start = time.perf_counter() - + try: # Attempt cache lookup cached_entry = self.cache_manager.get(query_context.query_hash) - + cache_latency = (time.perf_counter() - cache_start) * 1000 - + if cached_entry: - logger.info("FACT cache hit", - query_id=query_context.query_id, - cache_latency_ms=cache_latency, - access_count=cached_entry.access_count) - + logger.info( + "FACT cache hit", + query_id=query_context.query_id, + cache_latency_ms=cache_latency, + access_count=cached_entry.access_count, + ) + # Create cache hit response response = FACTResponse( content=cached_entry.content, @@ -252,10 +266,10 @@ async def _cache_lookup_phase(self, query_context: FACTQueryContext) -> Optional performance_metrics={ "cache_latency_ms": cache_latency, "access_count": cached_entry.access_count, - "cache_age_seconds": time.time() - cached_entry.created_at - } + "cache_age_seconds": time.time() - cached_entry.created_at, + }, ) - + # Record cache hit metrics self.metrics_collector.record_tool_execution( tool_name="fact_cache_hit", @@ -264,59 +278,66 @@ async def _cache_lookup_phase(self, query_context: FACTQueryContext) -> Optional metadata={ "query_id": query_context.query_id, "token_count": cached_entry.token_count, - "access_count": cached_entry.access_count - } + "access_count": cached_entry.access_count, + }, ) - + return response else: - logger.info("FACT cache miss", - query_id=query_context.query_id, - cache_latency_ms=cache_latency) - + logger.info( + "FACT cache miss", + query_id=query_context.query_id, + cache_latency_ms=cache_latency, + ) + # Record cache miss metrics self.metrics_collector.record_tool_execution( tool_name="fact_cache_miss", success=True, execution_time=cache_latency, - metadata={"query_id": query_context.query_id} + metadata={"query_id": query_context.query_id}, ) - + return None - + except Exception as e: cache_latency = (time.perf_counter() - cache_start) * 1000 - logger.warning("Cache lookup failed", - query_id=query_context.query_id, - error=str(e), - cache_latency_ms=cache_latency) + logger.warning( + "Cache lookup failed", + query_id=query_context.query_id, + error=str(e), + cache_latency_ms=cache_latency, + ) return None - - async def _llm_processing_phase(self, query_context: FACTQueryContext) -> FACTResponse: + + async def _llm_processing_phase( + self, query_context: FACTQueryContext + ) -> FACTResponse: """ Phase 2: LLM processing with tool integration. - + Implements intelligent LLM interaction with: - Tool execution optimization - Token efficiency monitoring - Response quality assessment """ llm_start = time.perf_counter() - + try: # Get FACT driver for LLM processing from .driver import get_driver + driver = await get_driver(self.config) - + # Process query through driver response_content = await driver.process_query(query_context.user_query) - + llm_latency = (time.perf_counter() - llm_start) * 1000 - + # Calculate token count and efficiency token_count = CacheEntry._count_tokens(response_content) cache_efficiency = self._calculate_content_efficiency(response_content) - + # Create LLM response response = FACTResponse( content=response_content, @@ -330,43 +351,51 @@ async def _llm_processing_phase(self, query_context: FACTQueryContext) -> FACTRe performance_metrics={ "llm_latency_ms": llm_latency, "token_count": token_count, - "cache_efficiency": cache_efficiency - } + "cache_efficiency": cache_efficiency, + }, ) - - logger.info("FACT LLM processing completed", - query_id=query_context.query_id, - llm_latency_ms=llm_latency, - token_count=token_count, - cache_efficiency=cache_efficiency) - + + logger.info( + "FACT LLM processing completed", + query_id=query_context.query_id, + llm_latency_ms=llm_latency, + token_count=token_count, + cache_efficiency=cache_efficiency, + ) + return response - + except Exception as e: llm_latency = (time.perf_counter() - llm_start) * 1000 - logger.error("LLM processing failed", - query_id=query_context.query_id, - error=str(e), - llm_latency_ms=llm_latency) + logger.error( + "LLM processing failed", + query_id=query_context.query_id, + error=str(e), + llm_latency_ms=llm_latency, + ) raise ToolExecutionError(f"LLM processing failed: {e}") - - async def _cache_storage_phase(self, query_context: FACTQueryContext, response: FACTResponse) -> None: + + async def _cache_storage_phase( + self, query_context: FACTQueryContext, response: FACTResponse + ) -> None: """ Phase 3: Cache storage with strategy optimization. - + Implements intelligent caching with: - Content quality assessment - Storage strategy optimization - Cache efficiency monitoring """ if not response.content or response.token_count < self.min_cache_tokens: - logger.debug("Content not cached - insufficient tokens", - query_id=query_context.query_id, - token_count=response.token_count) + logger.debug( + "Content not cached - insufficient tokens", + query_id=query_context.query_id, + token_count=response.token_count, + ) return - + storage_start = time.perf_counter() - + try: # Check if content should be cached based on strategy should_cache = self.cache_optimizer.should_cache_content( @@ -374,21 +403,25 @@ async def _cache_storage_phase(self, query_context: FACTQueryContext, response: { "query": query_context.user_query, "token_count": response.token_count, - "execution_time": response.execution_time_ms - } + "execution_time": response.execution_time_ms, + }, ) - + if should_cache: # Store in cache - cache_entry = self.cache_manager.store(query_context.query_hash, response.content) - + cache_entry = self.cache_manager.store( + query_context.query_hash, response.content + ) + storage_latency = (time.perf_counter() - storage_start) * 1000 - - logger.info("Content cached successfully", - query_id=query_context.query_id, - token_count=cache_entry.token_count, - storage_latency_ms=storage_latency) - + + logger.info( + "Content cached successfully", + query_id=query_context.query_id, + token_count=cache_entry.token_count, + storage_latency_ms=storage_latency, + ) + # Record cache storage metrics self.metrics_collector.record_tool_execution( tool_name="fact_cache_store", @@ -396,41 +429,52 @@ async def _cache_storage_phase(self, query_context: FACTQueryContext, response: execution_time=storage_latency, metadata={ "query_id": query_context.query_id, - "token_count": cache_entry.token_count - } + "token_count": cache_entry.token_count, + }, ) else: - logger.debug("Content not cached - strategy decision", - query_id=query_context.query_id) - + logger.debug( + "Content not cached - strategy decision", + query_id=query_context.query_id, + ) + except CacheError as e: storage_latency = (time.perf_counter() - storage_start) * 1000 - logger.warning("Cache storage failed", - query_id=query_context.query_id, - error=str(e), - storage_latency_ms=storage_latency) - + logger.warning( + "Cache storage failed", + query_id=query_context.query_id, + error=str(e), + storage_latency_ms=storage_latency, + ) + def _calculate_cache_efficiency(self, cache_entry: CacheEntry) -> float: """Calculate cache efficiency score for an entry.""" - content_size_kb = len(cache_entry.content.encode('utf-8')) / 1024 + content_size_kb = len(cache_entry.content.encode("utf-8")) / 1024 return cache_entry.token_count / content_size_kb if content_size_kb > 0 else 0.0 - + def _calculate_content_efficiency(self, content: str) -> float: """Calculate content efficiency for new content.""" if not content: return 0.0 - + token_count = CacheEntry._count_tokens(content) - content_size_kb = len(content.encode('utf-8')) / 1024 + content_size_kb = len(content.encode("utf-8")) / 1024 return token_count / content_size_kb if content_size_kb > 0 else 0.0 - + def _count_tool_calls(self, content: str) -> int: """Estimate number of tool calls from response content.""" # Simple heuristic - count mentions of tool execution patterns tool_indicators = ["executed", "query", "result", "data"] - return sum(1 for indicator in tool_indicators if indicator.lower() in content.lower()) - - def _record_query_metrics(self, query_context: FACTQueryContext, response: FACTResponse, execution_time: float) -> None: + return sum( + 1 for indicator in tool_indicators if indicator.lower() in content.lower() + ) + + def _record_query_metrics( + self, + query_context: FACTQueryContext, + response: FACTResponse, + execution_time: float, + ) -> None: """Record comprehensive query processing metrics.""" self.metrics_collector.record_tool_execution( tool_name="fact_query_complete", @@ -442,42 +486,45 @@ def _record_query_metrics(self, query_context: FACTQueryContext, response: FACTR "token_count": response.token_count, "cache_efficiency": response.cache_efficiency, "strategy_used": response.strategy_used, - "tool_calls_count": response.tool_calls_count - } + "tool_calls_count": response.tool_calls_count, + }, ) - + async def _run_background_optimization(self) -> None: """Run background cache optimization.""" logger.info("Starting FACT background optimization") - + while True: try: await asyncio.sleep(300) # Run every 5 minutes - + if self.cache_manager and self.cache_optimizer: - optimization_results = await self.cache_optimizer.optimize_cache(self.cache_manager) - - logger.debug("Background optimization completed", - **optimization_results) - + optimization_results = await self.cache_optimizer.optimize_cache( + self.cache_manager + ) + + logger.debug( + "Background optimization completed", **optimization_results + ) + except asyncio.CancelledError: logger.info("Background optimization cancelled") break except Exception as e: logger.error("Background optimization failed", error=str(e)) # Continue running despite errors - + def get_algorithm_metrics(self) -> Dict[str, Any]: """Get FACT algorithm performance metrics.""" if not self.cache_manager: return {} - + # Get cache metrics cache_metrics = self.cache_manager.get_metrics() - + # Get system metrics system_metrics = self.metrics_collector.get_system_metrics() - + # Calculate FACT-specific metrics fact_metrics = { "algorithm": "FACT", @@ -494,10 +541,10 @@ def get_algorithm_metrics(self) -> Dict[str, Any]: "cache_hit_target_ms": self.cache_hit_target_ms, "cache_miss_target_ms": self.cache_miss_target_ms, "min_cache_tokens": self.min_cache_tokens, - "token_efficiency_target": self.token_efficiency_target - } + "token_efficiency_target": self.token_efficiency_target, + }, } - + return fact_metrics @@ -508,32 +555,34 @@ def get_algorithm_metrics(self) -> Dict[str, Any]: async def get_fact_algorithm(config: Optional[Config] = None) -> FACTAlgorithm: """ Get or create the global FACT algorithm instance. - + Args: config: Optional configuration - + Returns: Initialized FACT algorithm instance """ global _fact_algorithm - + if _fact_algorithm is None: _fact_algorithm = FACTAlgorithm(config) await _fact_algorithm.initialize() - + return _fact_algorithm -async def process_fact_query(user_query: str, context: Optional[Dict[str, Any]] = None) -> FACTResponse: +async def process_fact_query( + user_query: str, context: Optional[Dict[str, Any]] = None +) -> FACTResponse: """ Process a query using the FACT algorithm. - + Args: user_query: User's natural language query context: Optional query context - + Returns: FACT response with performance metrics """ algorithm = await get_fact_algorithm() - return await algorithm.process_query(user_query, context) \ No newline at end of file + return await algorithm.process_query(user_query, context) diff --git a/src/core/cli.py b/src/core/cli.py index 4d5ba5b..16b41ff 100644 --- a/src/core/cli.py +++ b/src/core/cli.py @@ -16,128 +16,129 @@ from .config import get_config from .errors import FACTError, create_user_friendly_message - logger = structlog.get_logger(__name__) class FACTCLi: """ Interactive command-line interface for the FACT system. - + Provides user interaction, query processing, and system management through a console-based interface. """ - + def __init__(self): """Initialize CLI interface.""" self.driver = None self.running = False - + async def initialize(self) -> None: """Initialize the CLI and underlying FACT system.""" try: logger.info("Initializing FACT CLI") self.driver = await get_driver() logger.info("FACT CLI initialized successfully") - + except Exception as e: logger.error("CLI initialization failed", error=str(e)) print(f"āŒ Initialization failed: {create_user_friendly_message(e)}") sys.exit(1) - + async def run_interactive(self) -> None: """ Run the interactive CLI loop. - + Handles user input, query processing, and graceful shutdown. """ if not self.driver: await self.initialize() - + self.running = True - + # Print welcome message print("šŸš€ FACT System - Fast-Access Cached Tools") - print("šŸ’” Ask questions about financial data. Type 'help' for commands or Ctrl+C to exit.") + print( + "šŸ’” Ask questions about financial data. Type 'help' for commands or Ctrl+C to exit." + ) print() - + # Show system status await self._show_status() - + while self.running: try: # Get user input user_input = input("\nšŸ’¬ > ").strip() - + if not user_input: continue - + # Handle special commands if await self._handle_command(user_input): continue - + # Process query through FACT system print("šŸ¤” Processing your query...") - + response = await self.driver.process_query(user_input) - + # Display response print("\nšŸ“Š Response:") print(response) - + except (EOFError, KeyboardInterrupt): print("\nšŸ‘‹ Goodbye!") break - + except Exception as e: error_message = create_user_friendly_message(e) print(f"\nāŒ Error: {error_message}") logger.error("CLI error", error=str(e)) - + await self._shutdown() - + async def _handle_command(self, input_text: str) -> bool: """ Handle special CLI commands. - + Args: input_text: User input to check for commands - + Returns: True if input was a command, False otherwise """ command = input_text.lower().strip() - - if command in ['help', '?']: + + if command in ["help", "?"]: await self._show_help() return True - - elif command in ['status', 'stats']: + + elif command in ["status", "stats"]: await self._show_status() return True - - elif command in ['tools', 'list-tools']: + + elif command in ["tools", "list-tools"]: await self._show_tools() return True - - elif command in ['schema', 'db-schema']: + + elif command in ["schema", "db-schema"]: await self._show_schema() return True - - elif command in ['samples', 'examples']: + + elif command in ["samples", "examples"]: await self._show_sample_queries() return True - - elif command in ['metrics', 'performance']: + + elif command in ["metrics", "performance"]: await self._show_metrics() return True - - elif command in ['exit', 'quit', 'q']: + + elif command in ["exit", "quit", "q"]: self.running = False return True - + return False - + async def _show_help(self) -> None: """Display help information.""" help_text = """ @@ -161,57 +162,62 @@ async def _show_help(self) -> None: The system will automatically use SQL tools to retrieve data and provide answers. """ print(help_text) - + async def _show_status(self) -> None: """Display system status information.""" try: if not self.driver: print("āŒ System not initialized") return - + config = get_config() metrics = self.driver.get_metrics() - + print("šŸ” System Status:") - print(f" • Status: {'āœ… Ready' if metrics['initialized'] else 'āŒ Not Ready'}") + print( + f" • Status: {'āœ… Ready' if metrics['initialized'] else 'āŒ Not Ready'}" + ) print(f" • Database: {config.database_path}") print(f" • Model: {config.claude_model}") print(f" • Cache Prefix: {config.cache_prefix}") - print(f" • Tools Registered: {len(self.driver.tool_registry.list_tools())}") - + print( + f" • Tools Registered: {len(self.driver.tool_registry.list_tools())}" + ) + except Exception as e: print(f"āŒ Failed to get status: {e}") - + async def _show_tools(self) -> None: """Display available tools.""" try: if not self.driver: print("āŒ System not initialized") return - + tool_names = self.driver.tool_registry.list_tools() - + print("šŸ› ļø Available Tools:") for tool_name in tool_names: tool_def = self.driver.tool_registry.get_tool(tool_name) print(f" • {tool_name}: {tool_def.description}") - + print(f"\nTotal: {len(tool_names)} tools") - + except Exception as e: print(f"āŒ Failed to list tools: {e}") - + async def _show_schema(self) -> None: """Display database schema information.""" try: if not self.driver: print("āŒ System not initialized") return - + # Use the SQL tool to get schema from ..tools.connectors.sql import sql_get_schema + schema_info = await sql_get_schema() - + if schema_info.get("status") == "success": print("šŸ—„ļø Database Schema:") for table in schema_info["tables"]: @@ -219,41 +225,44 @@ async def _show_schema(self) -> None: for column in table["columns"]: nullable = "NULL" if column["nullable"] else "NOT NULL" pk = " (PRIMARY KEY)" if column["primary_key"] else "" - print(f" • {column['name']}: {column['type']} {nullable}{pk}") + print( + f" • {column['name']}: {column['type']} {nullable}{pk}" + ) else: print(f"āŒ Failed to get schema: {schema_info.get('error')}") - + except Exception as e: print(f"āŒ Failed to show schema: {e}") - + async def _show_sample_queries(self) -> None: """Display sample SQL queries.""" try: if not self.driver: print("āŒ System not initialized") return - + # Use the SQL tool to get sample queries from ..tools.connectors.sql import sql_get_sample_queries + samples = sql_get_sample_queries() - + print("šŸ“ Sample Queries:") for i, sample in enumerate(samples["sample_queries"], 1): print(f"\n {i}. {sample['description']}") print(f" {sample['query']}") - + except Exception as e: print(f"āŒ Failed to show samples: {e}") - + async def _show_metrics(self) -> None: """Display performance metrics.""" try: if not self.driver: print("āŒ System not initialized") return - + metrics = self.driver.get_metrics() - + print("šŸ“Š Performance Metrics:") print(f" • Total Queries: {metrics['total_queries']}") print(f" • Cache Hit Rate: {metrics['cache_hit_rate']:.1f}%") @@ -261,17 +270,17 @@ async def _show_metrics(self) -> None: print(f" • Error Rate: {metrics['error_rate']:.1f}%") print(f" • Cache Hits: {metrics['cache_hits']}") print(f" • Cache Misses: {metrics['cache_misses']}") - + except Exception as e: print(f"āŒ Failed to show metrics: {e}") - + async def _shutdown(self) -> None: """Shutdown the CLI and underlying systems.""" try: print("\nšŸ”„ Shutting down...") await shutdown_driver() print("āœ… Shutdown complete") - + except Exception as e: print(f"āŒ Shutdown error: {e}") @@ -279,41 +288,29 @@ async def _shutdown(self) -> None: async def main() -> None: """ Main entry point for the FACT CLI application. - + Handles command-line arguments and starts the appropriate mode. """ parser = argparse.ArgumentParser( description="FACT System - Fast-Access Cached Tools", - formatter_class=argparse.RawDescriptionHelpFormatter - ) - - parser.add_argument( - "--version", - action="version", - version="FACT System v1.0.0" - ) - - parser.add_argument( - "--query", - type=str, - help="Process a single query and exit" - ) - - parser.add_argument( - "--config", - type=str, - help="Path to configuration file" + formatter_class=argparse.RawDescriptionHelpFormatter, ) - + + parser.add_argument("--version", action="version", version="FACT System v1.0.0") + + parser.add_argument("--query", type=str, help="Process a single query and exit") + + parser.add_argument("--config", type=str, help="Path to configuration file") + parser.add_argument( "--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", - help="Set logging level" + help="Set logging level", ) - + args = parser.parse_args() - + # Configure logging structlog.configure( wrapper_class=structlog.make_filtering_bound_logger( @@ -325,14 +322,14 @@ async def main() -> None: structlog.stdlib.add_logger_name, structlog.stdlib.add_log_level, structlog.processors.TimeStamper(fmt="ISO"), - structlog.dev.ConsoleRenderer() + structlog.dev.ConsoleRenderer(), ], cache_logger_on_first_use=True, ) - + try: cli = FACTCLi() - + if args.query: # Single query mode await cli.initialize() @@ -341,7 +338,7 @@ async def main() -> None: else: # Interactive mode await cli.run_interactive() - + except KeyboardInterrupt: print("\nšŸ‘‹ Interrupted by user") sys.exit(0) @@ -352,4 +349,4 @@ async def main() -> None: if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/src/core/config.py b/src/core/config.py index 6161771..6d63681 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -11,32 +11,33 @@ from dotenv import load_dotenv import structlog - logger = structlog.get_logger(__name__) class ConfigurationError(Exception): """Raised when configuration validation fails.""" + pass class ConnectionError(Exception): """Raised when API connectivity tests fail.""" + pass class Config: """ Central configuration management for the FACT system. - + Handles environment variable loading, validation, and provides structured access to system configuration parameters. """ - + def __init__(self, env_file: Optional[str] = None): """ Initialize configuration from environment variables. - + Args: env_file: Optional path to .env file (defaults to .env) """ @@ -44,7 +45,7 @@ def __init__(self, env_file: Optional[str] = None): self._config: Dict[str, Any] = {} self._load_environment() self._validate_required_keys() - + def _load_environment(self) -> None: """Load environment variables from .env file if it exists.""" if os.path.exists(self.env_file): @@ -52,22 +53,19 @@ def _load_environment(self) -> None: logger.info("Loaded environment configuration", file=self.env_file) else: logger.warning("No .env file found, using system environment") - + def _validate_required_keys(self) -> None: """ Validate that all required configuration keys are present and valid. - + Raises: ConfigurationError: If any required keys are missing or invalid """ - required_keys = [ - "ANTHROPIC_API_KEY", - "ARCADE_API_KEY" - ] - + required_keys = ["ANTHROPIC_API_KEY", "ARCADE_API_KEY"] + missing_keys = [] invalid_keys = [] - + for key in required_keys: value = os.getenv(key) if not value: @@ -76,19 +74,19 @@ def _validate_required_keys(self) -> None: missing_keys.append(key) # Treat whitespace-only as missing elif self._is_placeholder_key(value.strip()): invalid_keys.append(key) - + if missing_keys: raise ConfigurationError( f"Missing required configuration keys: {', '.join(missing_keys)}" ) - + if invalid_keys: raise ConfigurationError( f"Invalid placeholder values for keys: {', '.join(invalid_keys)}. Please set real API keys." ) - + logger.info("Configuration validation passed") - + def _is_placeholder_key(self, value: str) -> bool: """Check if a configuration value is a placeholder.""" placeholder_patterns = [ @@ -98,35 +96,35 @@ def _is_placeholder_key(self, value: str) -> bool: "placeholder", "changeme", "todo", - "fix_me" + "fix_me", ] return any(pattern in value.lower() for pattern in placeholder_patterns) - + @property def anthropic_api_key(self) -> str: """Get Anthropic API key.""" return os.getenv("ANTHROPIC_API_KEY", "") - + @property def arcade_api_key(self) -> str: """Get Arcade API key.""" return os.getenv("ARCADE_API_KEY", "") - + @property def arcade_base_url(self) -> str: """Get Arcade base URL.""" return os.getenv("ARCADE_BASE_URL", "https://api.arcade-ai.com") - + @property def database_path(self) -> str: """Get database file path.""" return os.getenv("DATABASE_PATH", "data/fact_demo.db") - + @property def cache_prefix(self) -> str: """Get cache prefix for Claude caching.""" return os.getenv("CACHE_PREFIX", "fact_v1") - + @property def cache_config(self) -> Dict[str, Any]: """Get cache configuration dictionary.""" @@ -136,9 +134,9 @@ def cache_config(self) -> Dict[str, Any]: "max_size": os.getenv("CACHE_MAX_SIZE", "10MB"), "ttl_seconds": int(os.getenv("CACHE_TTL_SECONDS", "3600")), "hit_target_ms": float(os.getenv("CACHE_HIT_TARGET_MS", "30")), - "miss_target_ms": float(os.getenv("CACHE_MISS_TARGET_MS", "120")) + "miss_target_ms": float(os.getenv("CACHE_MISS_TARGET_MS", "120")), } - + @property def system_prompt(self) -> str: """Get system prompt for Claude.""" @@ -161,32 +159,33 @@ def system_prompt(self) -> str: Example: If asked "What's TechCorp's revenue?" immediately execute: SELECT revenue FROM financial_records WHERE company_id = (SELECT id FROM companies WHERE name LIKE '%TechCorp%') -Always show real data, not placeholders or descriptions of what you would do.""" +Always show real data, not placeholders or descriptions of what you would do.""", ) - + @property def claude_model(self) -> str: """Get Claude model name.""" return os.getenv("CLAUDE_MODEL", "claude-3-haiku-20240307") + @property def max_retries(self) -> int: """Get maximum retry attempts for failed operations.""" return int(os.getenv("MAX_RETRIES", "3")) - + @property def request_timeout(self) -> int: """Get request timeout in seconds.""" return int(os.getenv("REQUEST_TIMEOUT", "30")) - + @property def log_level(self) -> str: """Get logging level.""" return os.getenv("LOG_LEVEL", "INFO") - + def to_dict(self) -> Dict[str, Any]: """ Export configuration as dictionary (excluding sensitive data). - + Returns: Dictionary of non-sensitive configuration values """ @@ -207,7 +206,7 @@ def to_dict(self) -> Dict[str, Any]: def get_config() -> Config: """ Get global configuration instance. - + Returns: Configured Config instance """ @@ -217,10 +216,10 @@ def get_config() -> Config: def validate_configuration(config: Config) -> None: """ Validate configuration and test connectivity to required services. - + Args: config: Configuration instance to validate - + Raises: ConfigurationError: If configuration is invalid ConnectionError: If service connectivity tests fail @@ -228,10 +227,10 @@ def validate_configuration(config: Config) -> None: try: # Basic configuration validation is done in Config.__init__ logger.info("Configuration validation completed successfully") - + # Log configuration summary (without sensitive data) logger.info("Configuration summary", config=config.to_dict()) - + except Exception as e: logger.error("Configuration validation failed", error=str(e)) - raise ConfigurationError(f"Configuration validation failed: {e}") \ No newline at end of file + raise ConfigurationError(f"Configuration validation failed: {e}") diff --git a/src/core/conversation.py b/src/core/conversation.py index 5b89d07..f9fc846 100644 --- a/src/core/conversation.py +++ b/src/core/conversation.py @@ -10,13 +10,13 @@ from dataclasses import dataclass, field import structlog - logger = structlog.get_logger(__name__) @dataclass class ConversationTurn: """Represents a single conversation turn.""" + timestamp: float user_input: str assistant_response: str @@ -28,49 +28,54 @@ class ConversationTurn: @dataclass class ConversationContext: """Manages conversation context and state.""" + conversation_id: str turns: List[ConversationTurn] = field(default_factory=list) current_topic: Optional[str] = None pending_actions: List[str] = field(default_factory=list) database_context: Dict[str, Any] = field(default_factory=dict) - - def add_turn(self, user_input: str, assistant_response: str, - tool_calls: List[Dict[str, Any]] = None, - tool_results: List[Dict[str, Any]] = None) -> None: + + def add_turn( + self, + user_input: str, + assistant_response: str, + tool_calls: List[Dict[str, Any]] = None, + tool_results: List[Dict[str, Any]] = None, + ) -> None: """Add a new conversation turn.""" turn = ConversationTurn( timestamp=time.time(), user_input=user_input, assistant_response=assistant_response, tool_calls=tool_calls or [], - tool_results=tool_results or [] + tool_results=tool_results or [], ) self.turns.append(turn) - + # Detect incomplete responses and add pending actions incomplete_actions = self.detect_incomplete_response(assistant_response) for action in incomplete_actions: self.add_pending_action(action) - + # Keep only last 10 turns to manage memory if len(self.turns) > 10: self.turns = self.turns[-10:] - + def get_context_summary(self) -> str: """Generate a context summary for the LLM.""" if not self.turns: return "" - + context_parts = [] - + # Add current topic if available if self.current_topic: context_parts.append(f"Current conversation topic: {self.current_topic}") - + # Add pending actions if self.pending_actions: context_parts.append(f"Pending actions: {', '.join(self.pending_actions)}") - + # Add recent conversation history (last 3 turns) recent_turns = self.turns[-3:] if len(recent_turns) > 1: # Only add if there's actual history @@ -78,10 +83,12 @@ def get_context_summary(self) -> str: for i, turn in enumerate(recent_turns[:-1], 1): # Exclude current turn context_parts.append(f" {i}. User: {turn.user_input[:100]}...") if turn.assistant_response: - context_parts.append(f" Assistant: {turn.assistant_response[:100]}...") - + context_parts.append( + f" Assistant: {turn.assistant_response[:100]}..." + ) + return "\n".join(context_parts) if context_parts else "" - + def detect_topic(self, user_input: str) -> None: """Detect and update current conversation topic.""" # Simple keyword-based topic detection @@ -89,32 +96,32 @@ def detect_topic(self, user_input: str) -> None: "revenue": ["revenue", "sales", "income", "earnings"], "companies": ["company", "companies", "business", "organization"], "financial_analysis": ["compare", "trend", "analysis", "performance"], - "database_schema": ["schema", "tables", "structure", "database"] + "database_schema": ["schema", "tables", "structure", "database"], } - + user_lower = user_input.lower() for topic, keywords in topics.items(): if any(keyword in user_lower for keyword in keywords): self.current_topic = topic break - + def add_pending_action(self, action: str) -> None: """Add a pending action to be completed.""" if action not in self.pending_actions: self.pending_actions.append(action) logger.info("Added pending action", action=action) - + def complete_action(self, action: str) -> None: """Mark an action as completed.""" if action in self.pending_actions: self.pending_actions.remove(action) logger.info("Completed action", action=action) - + def clear_pending_actions(self) -> None: """Clear all pending actions.""" self.pending_actions.clear() logger.info("Cleared all pending actions") - + def detect_incomplete_response(self, response: str) -> List[str]: """Detect if response is incomplete and suggest follow-up actions.""" incomplete_indicators = [ @@ -127,82 +134,86 @@ def detect_incomplete_response(self, response: str) -> List[str]: "I'll show", "I'll analyze", "Looking at", - "I see that" + "I see that", ] - + suggested_actions = [] response_lower = response.lower() - + # Check for incomplete indicators for indicator in incomplete_indicators: if indicator.lower() in response_lower: # Extract what was promised if "retrieve" in response_lower or "query" in response_lower: - suggested_actions.append("Execute the database query to get the data") + suggested_actions.append( + "Execute the database query to get the data" + ) elif "show" in response_lower or "display" in response_lower: suggested_actions.append("Display the requested information") elif "analyze" in response_lower or "compare" in response_lower: suggested_actions.append("Perform the analysis as mentioned") break - + # Check for specific incomplete patterns if "companies in the technology sector" in response_lower: suggested_actions.append("Query and display Technology sector companies") elif "revenue trends" in response_lower: - suggested_actions.append("Query and analyze revenue trends across companies") + suggested_actions.append( + "Query and analyze revenue trends across companies" + ) elif "compare" in response_lower and "companies" in response_lower: suggested_actions.append("Compare companies as requested") - + return suggested_actions class ConversationManager: """Manages conversation contexts and provides enhanced prompting.""" - + def __init__(self): """Initialize conversation manager.""" self.conversations: Dict[str, ConversationContext] = {} self.current_conversation_id: Optional[str] = None - + def start_conversation(self, conversation_id: Optional[str] = None) -> str: """Start a new conversation or get existing one.""" if conversation_id is None: conversation_id = f"conv_{int(time.time() * 1000)}" - + if conversation_id not in self.conversations: self.conversations[conversation_id] = ConversationContext(conversation_id) logger.info("Started new conversation", conversation_id=conversation_id) - + self.current_conversation_id = conversation_id return conversation_id - + def get_current_context(self) -> Optional[ConversationContext]: """Get current conversation context.""" if self.current_conversation_id: return self.conversations.get(self.current_conversation_id) return None - + def enhance_system_prompt(self, base_prompt: str, user_input: str) -> str: """Enhance system prompt with conversation context.""" context = self.get_current_context() if not context: return base_prompt - + # Update topic based on current input context.detect_topic(user_input) - + # Get context summary context_summary = context.get_context_summary() - + enhanced_prompt = base_prompt - + if context_summary: enhanced_prompt += f"\n\nCONVERSATION CONTEXT:\n{context_summary}" - + # Add specific guidance based on context if context.pending_actions: enhanced_prompt += "\n\nIMPORTANT: You have pending actions to complete. Make sure to follow through with the necessary tool calls to provide complete answers." - + # Add topic-specific guidance if context.current_topic == "revenue": enhanced_prompt += "\n\nCONTEXT: User is interested in revenue/financial data. Use SQL tools to get specific numbers and trends." @@ -210,25 +221,25 @@ def enhance_system_prompt(self, base_prompt: str, user_input: str) -> str: enhanced_prompt += "\n\nCONTEXT: User is asking about companies. Use SQL tools to get company information and details." elif context.current_topic == "financial_analysis": enhanced_prompt += "\n\nCONTEXT: User wants financial analysis/comparisons. Execute queries to get data, then provide insights and comparisons." - + enhanced_prompt += "\n\nCRITICAL: Always complete your analysis. If you identify relevant data to retrieve, execute the SQL queries and provide the actual results, not just acknowledgments." - + return enhanced_prompt - + def should_auto_continue(self) -> bool: """Check if system should automatically continue with pending actions.""" context = self.get_current_context() return bool(context and context.pending_actions) - + def generate_follow_up_prompt(self) -> Optional[str]: """Generate a follow-up prompt to complete pending actions.""" context = self.get_current_context() if not context or not context.pending_actions: return None - + # Generate prompt based on pending actions action = context.pending_actions[0] # Take first pending action - + if "query" in action.lower() and "technology sector" in action.lower(): return "Show me all companies in the Technology sector" elif "revenue trends" in action.lower(): @@ -239,16 +250,22 @@ def generate_follow_up_prompt(self) -> Optional[str]: return "Please execute the database query to get the data" else: return "Please continue and complete the previous response" - - def add_turn(self, user_input: str, assistant_response: str, - tool_calls: List[Dict[str, Any]] = None, - tool_results: List[Dict[str, Any]] = None) -> None: + + def add_turn( + self, + user_input: str, + assistant_response: str, + tool_calls: List[Dict[str, Any]] = None, + tool_results: List[Dict[str, Any]] = None, + ) -> None: """Add a turn to current conversation.""" context = self.get_current_context() if context: context.add_turn(user_input, assistant_response, tool_calls, tool_results) - - def detect_incomplete_response(self, user_input: str, assistant_response: str) -> bool: + + def detect_incomplete_response( + self, user_input: str, assistant_response: str + ) -> bool: """Detect if the assistant response seems incomplete.""" # Check for common incomplete response patterns incomplete_indicators = [ @@ -257,9 +274,9 @@ def detect_incomplete_response(self, user_input: str, assistant_response: str) - "I see that there is", "I'll retrieve", "Let me get", - "I'll check" + "I'll check", ] - + # Check if response contains tool identification without execution if any(indicator in assistant_response for indicator in incomplete_indicators): # Check if it actually contains meaningful data/results @@ -269,17 +286,17 @@ def detect_incomplete_response(self, user_input: str, assistant_response: str) - "shows", "Total:", "revenue of", - "companies are" + "companies are", ] if not any(data in assistant_response for data in data_indicators): return True - + # Check for very short responses that don't answer the question if len(assistant_response.strip()) < 100 and "?" not in user_input: return True - + return False - + def get_continuation_prompt(self) -> str: """Get a prompt to encourage continuation of incomplete responses.""" return """Continue with your analysis. Execute the necessary SQL queries to get the actual data and provide specific results to the user. Don't just describe what you will do - do it and show the results.""" @@ -294,4 +311,4 @@ def get_conversation_manager() -> ConversationManager: global _conversation_manager if _conversation_manager is None: _conversation_manager = ConversationManager() - return _conversation_manager \ No newline at end of file + return _conversation_manager diff --git a/src/core/driver.py b/src/core/driver.py index a3a4860..b9fc604 100644 --- a/src/core/driver.py +++ b/src/core/driver.py @@ -16,10 +16,17 @@ from .config import Config, get_config, validate_configuration from .errors import ( - FACTError, ConfigurationError, ConnectionError, ToolExecutionError, - classify_error, create_user_friendly_message, log_error_with_context, - provide_graceful_degradation, CacheError + FACTError, + ConfigurationError, + ConnectionError, + ToolExecutionError, + classify_error, + create_user_friendly_message, + log_error_with_context, + provide_graceful_degradation, + CacheError, ) + try: # Try relative imports first (when used as package) from ..db.connection import DatabaseManager @@ -32,11 +39,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from db.connection import DatabaseManager from tools.decorators import get_tool_registry from tools.connectors.sql import initialize_sql_tool @@ -47,17 +55,18 @@ logger = structlog.get_logger(__name__) + class FACTDriver: """ - + Manages cache control, query processing, tool execution, and system coordination following the FACT architecture principles. """ - + def __init__(self, config: Optional[Config] = None): """ Initialize FACT driver with configuration. - + Args: config: Optional configuration instance (creates default if None) """ @@ -66,19 +75,19 @@ def __init__(self, config: Optional[Config] = None): self.tool_registry = get_tool_registry() self.cache_system: Optional[FACTCacheSystem] = None self._initialized = False - + # Monitoring and metrics self.metrics_collector = get_metrics_collector() - + # Cache resilience components self.cache_circuit_breaker: Optional[CacheCircuitBreaker] = None self.resilient_cache: Optional[ResilientCacheWrapper] = None self._cache_degraded = False - + async def initialize(self) -> None: """ Initialize the FACT system components. - + Raises: ConfigurationError: If configuration is invalid ConnectionError: If service connections fail @@ -86,54 +95,56 @@ async def initialize(self) -> None: if self._initialized: logger.info("FACT driver already initialized") return - + try: logger.info("Initializing FACT system") - + # Validate configuration validate_configuration(self.config) - + # Initialize database await self._initialize_database() - + # Initialize cache system await self._initialize_cache() - + # Initialize tools await self._initialize_tools() - + # Test connections await self._test_connections() - + self._initialized = True logger.info("FACT system initialized successfully") - + except Exception as e: logger.error("FACT system initialization failed", error=str(e)) raise ConfigurationError(f"System initialization failed: {e}") - + async def process_query(self, user_input: str) -> str: """ Process a user query through the FACT pipeline. - + Args: user_input: Natural language query from user - + Returns: Generated response string - + Raises: FACTError: If query processing fails """ if not self._initialized: await self.initialize() - + query_id = f"query_{int(time.time() * 1000)}" start_time = time.time() - + try: - logger.info("Processing user query", query_id=query_id, query=user_input[:100]) - + logger.info( + "Processing user query", query_id=query_id, query=user_input[:100] + ) + # Step 1: Check cache first (cache-first pattern) with resilience cached_response = None if not self._cache_degraded: @@ -146,336 +157,361 @@ async def process_query(self, user_input: str) -> str: cached_response = cache_entry.content elif self.cache_system: # Fallback to direct cache system - cached_response = await self.cache_system.get_cached_response(user_input) - + cached_response = await self.cache_system.get_cached_response( + user_input + ) + except CacheError as e: if "CIRCUIT_BREAKER" in str(e.error_code): - logger.info("Cache circuit breaker active - proceeding without cache", - query_id=query_id, - circuit_state=self.cache_circuit_breaker.get_state().value if self.cache_circuit_breaker else "unknown") + logger.info( + "Cache circuit breaker active - proceeding without cache", + query_id=query_id, + circuit_state=( + self.cache_circuit_breaker.get_state().value + if self.cache_circuit_breaker + else "unknown" + ), + ) else: - logger.warning("Cache operation failed - continuing without cache", - query_id=query_id, - error=str(e)) + logger.warning( + "Cache operation failed - continuing without cache", + query_id=query_id, + error=str(e), + ) except Exception as e: - logger.warning("Unexpected cache error - continuing without cache", - query_id=query_id, - error=str(e)) - + logger.warning( + "Unexpected cache error - continuing without cache", + query_id=query_id, + error=str(e), + ) + if cached_response: # Cache hit - return cached response end_time = time.time() latency = (end_time - start_time) * 1000 - - logger.info("Cache hit - returning cached response", - query_id=query_id, - latency_ms=latency) - + + logger.info( + "Cache hit - returning cached response", + query_id=query_id, + latency_ms=latency, + ) + return cached_response - + # Step 2: Cache miss - process query normally logger.info("Cache miss - processing query with LLM", query_id=query_id) - + # Prepare messages for LLM - messages = [ - { - "role": "user", - "content": user_input - } - ] - + messages = [{"role": "user", "content": user_input}] + # Get tool schemas for LLM tool_schemas = self.tool_registry.export_all_schemas() - + # Make initial LLM call with cache control response = await self._call_llm_with_cache( - messages=messages, - tools=tool_schemas, - cache_mode="read" + messages=messages, tools=tool_schemas, cache_mode="read" ) - + # Handle tool calls if present (Anthropic format) tool_use_blocks = [] - if hasattr(response, 'content') and response.content: + if hasattr(response, "content") and response.content: for block in response.content: - if hasattr(block, 'type') and block.type == 'tool_use': + if hasattr(block, "type") and block.type == "tool_use": tool_use_blocks.append(block) - + if tool_use_blocks: logger.info("Processing tool calls", count=len(tool_use_blocks)) - + # Add assistant message with tool_use blocks to conversation - assistant_message = { - "role": "assistant", - "content": response.content - } + assistant_message = {"role": "assistant", "content": response.content} messages.append(assistant_message) - + # Execute tool calls tool_results = await self._execute_tool_calls(tool_use_blocks) - + # Add tool results to message history messages.extend(tool_results) - + # Get final response with tool results response = await self._call_llm_with_cache( - messages=messages, - tools=tool_schemas, - cache_mode="read" + messages=messages, tools=tool_schemas, cache_mode="read" ) - + # Check if the final response also has tool calls final_tool_blocks = [] - if hasattr(response, 'content') and response.content: + if hasattr(response, "content") and response.content: for block in response.content: - if hasattr(block, 'type') and block.type == 'tool_use': + if hasattr(block, "type") and block.type == "tool_use": final_tool_blocks.append(block) - + # Execute any final tool calls and continue until we get text response max_iterations = 5 # Prevent infinite loops iteration = 0 - + while final_tool_blocks and iteration < max_iterations: - logger.info("Processing final tool calls", count=len(final_tool_blocks), iteration=iteration) - + logger.info( + "Processing final tool calls", + count=len(final_tool_blocks), + iteration=iteration, + ) + # Add assistant message with tool_use blocks assistant_message = { "role": "assistant", - "content": response.content + "content": response.content, } messages.append(assistant_message) - + # Execute final tool calls - final_tool_results = await self._execute_tool_calls(final_tool_blocks) - + final_tool_results = await self._execute_tool_calls( + final_tool_blocks + ) + # Add final tool results messages.extend(final_tool_results) - + # Get final response response = await self._call_llm_with_cache( - messages=messages, - tools=tool_schemas, - cache_mode="read" + messages=messages, tools=tool_schemas, cache_mode="read" ) - + # Check if this response also has tool calls final_tool_blocks = [] - if hasattr(response, 'content') and response.content: + if hasattr(response, "content") and response.content: for block in response.content: - if hasattr(block, 'type') and block.type == 'tool_use': + if hasattr(block, "type") and block.type == "tool_use": final_tool_blocks.append(block) - + iteration += 1 - + # Extract response content from Anthropic SDK response response_text = "" - if hasattr(response, 'content') and response.content: + if hasattr(response, "content") and response.content: for block in response.content: - if hasattr(block, 'type') and block.type == 'text': + if hasattr(block, "type") and block.type == "text": response_text += block.text - + # If no text content found, provide informative message if not response_text: logger.warning("No text response received from LLM", query_id=query_id) response_text = "I apologize, but I was unable to generate a proper response. Please try rephrasing your question." - + # Step 3: Store response in cache for future use with resilience if not self._cache_degraded and response_text: try: if self.resilient_cache: # Use resilient cache with circuit breaker query_hash = self.resilient_cache.generate_hash(user_input) - cache_entry = await self.resilient_cache.store(query_hash, response_text) + cache_entry = await self.resilient_cache.store( + query_hash, response_text + ) if cache_entry: logger.debug("Response stored in cache", query_id=query_id) else: - logger.debug("Response not suitable for caching", query_id=query_id) + logger.debug( + "Response not suitable for caching", query_id=query_id + ) elif self.cache_system: # Fallback to direct cache system - cache_stored = await self.cache_system.store_response(user_input, response_text) + cache_stored = await self.cache_system.store_response( + user_input, response_text + ) if cache_stored: logger.debug("Response stored in cache", query_id=query_id) else: - logger.debug("Response not suitable for caching", query_id=query_id) - + logger.debug( + "Response not suitable for caching", query_id=query_id + ) + except CacheError as e: if "CIRCUIT_BREAKER" in str(e.error_code): - logger.debug("Cache storage skipped - circuit breaker active", query_id=query_id) + logger.debug( + "Cache storage skipped - circuit breaker active", + query_id=query_id, + ) else: - logger.warning("Cache storage failed - continuing without cache storage", - query_id=query_id, - error=str(e)) + logger.warning( + "Cache storage failed - continuing without cache storage", + query_id=query_id, + error=str(e), + ) except Exception as e: - logger.warning("Unexpected cache storage error", - query_id=query_id, - error=str(e)) - + logger.warning( + "Unexpected cache storage error", + query_id=query_id, + error=str(e), + ) + # Log performance metrics end_time = time.time() latency = (end_time - start_time) * 1000 - - logger.info("Query processed successfully", - query_id=query_id, - latency_ms=latency, - response_length=len(response_text)) - + + logger.info( + "Query processed successfully", + query_id=query_id, + latency_ms=latency, + response_length=len(response_text), + ) + return response_text - + except Exception as e: end_time = time.time() latency = (end_time - start_time) * 1000 - + # Record error in metrics collector self.metrics_collector.record_tool_execution( tool_name="fact_query", success=False, execution_time=latency, error_type=type(e).__name__, - metadata={"query_id": query_id} + metadata={"query_id": query_id}, ) - + # Log error with context - log_error_with_context(e, { - "query_id": query_id, - "user_input": user_input[:100], - "latency_ms": latency - }) - + log_error_with_context( + e, + { + "query_id": query_id, + "user_input": user_input[:100], + "latency_ms": latency, + }, + ) + # Handle error with graceful degradation error_category = classify_error(e) - + if error_category in ["connectivity", "tool_execution"]: return provide_graceful_degradation(error_category) else: return create_user_friendly_message(e) - + async def _initialize_database(self) -> None: """Initialize database connection and schema.""" try: self.database_manager = DatabaseManager(self.config.database_path) await self.database_manager.initialize_database() logger.info("Database initialized successfully") - + except Exception as e: logger.error("Database initialization failed", error=str(e)) raise ConfigurationError(f"Database initialization failed: {e}") - + async def _initialize_cache(self) -> None: """Initialize cache system with resilience and circuit breaker protection.""" try: cache_config = self.config.cache_config - + # Initialize base cache system self.cache_system = await initialize_cache_system( - config=cache_config, - enable_background_tasks=True + config=cache_config, enable_background_tasks=True ) - + # Initialize circuit breaker for cache resilience from ..cache.resilience import CircuitBreakerConfig - + circuit_config = CircuitBreakerConfig( failure_threshold=5, # Open after 5 failures success_threshold=3, # Close after 3 successes timeout_seconds=60.0, # Wait 60s before retry rolling_window_seconds=300.0, # 5-minute window gradual_recovery=True, - recovery_factor=0.5 # 50% of requests during recovery + recovery_factor=0.5, # 50% of requests during recovery ) - + self.cache_circuit_breaker = CacheCircuitBreaker(circuit_config) - + # Wrap cache system with resilient wrapper - if hasattr(self.cache_system, 'cache_manager'): + if hasattr(self.cache_system, "cache_manager"): self.resilient_cache = ResilientCacheWrapper( - self.cache_system.cache_manager, - self.cache_circuit_breaker + self.cache_system.cache_manager, self.cache_circuit_breaker ) - + # Start health monitoring await self.resilient_cache.start_monitoring() - - logger.info("Cache system with resilience initialized successfully", - prefix=cache_config["prefix"], - max_size=cache_config["max_size"], - circuit_breaker_enabled=True) - + + logger.info( + "Cache system with resilience initialized successfully", + prefix=cache_config["prefix"], + max_size=cache_config["max_size"], + circuit_breaker_enabled=True, + ) + except Exception as e: logger.error("Cache system initialization failed", error=str(e)) # Enable graceful degradation - continue without cache self._cache_degraded = True logger.warning("Continuing with cache degradation mode") - + async def _initialize_tools(self) -> None: """Initialize and register system tools.""" try: # Initialize SQL tool with database manager initialize_sql_tool(self.database_manager) - + # Log registered tools tool_info = self.tool_registry.get_tool_info() logger.info("Tools initialized", **tool_info) - + except Exception as e: logger.error("Tool initialization failed", error=str(e)) raise ConfigurationError(f"Tool initialization failed: {e}") - + async def _test_connections(self) -> None: """Test connections to external services.""" # Skip API validation if configured for testing if os.getenv("SKIP_API_VALIDATION", "false").lower() == "true": logger.info("Skipping API validation for testing environment") return - + try: # Test database connection if self.database_manager: await self.database_manager.get_database_info() logger.info("Database connection test passed") - + # Test LLM connection with direct Anthropic SDK client = anthropic.Anthropic(api_key=self.config.anthropic_api_key) test_response = client.messages.create( model=self.config.claude_model, messages=[{"role": "user", "content": "Test"}], - max_tokens=10 + max_tokens=10, ) - + if test_response: logger.info("LLM connection test passed") - + except Exception as e: logger.error("Connection test failed", error=str(e)) raise ConnectionError(f"Service connection test failed: {e}") - - async def _call_llm_with_cache(self, - messages: List[Dict[str, Any]], - tools: List[Dict[str, Any]], - cache_mode: str = "read") -> Any: + + async def _call_llm_with_cache( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + cache_mode: str = "read", + ) -> Any: """ Call LLM with cache control and tool support. - + Args: messages: Message history for LLM tools: Available tools for LLM cache_mode: Cache mode ('read' or 'write') - + Returns: LLM response object """ try: # Configure cache control - cache_control = { - "mode": cache_mode, - "prefix": self.config.cache_prefix - } - + cache_control = {"mode": cache_mode, "prefix": self.config.cache_prefix} + # Track cache behavior # Cache hits/misses can be tracked via tool execution metadata if needed - + # Make LLM call with direct Anthropic SDK client = anthropic.Anthropic(api_key=self.config.anthropic_api_key) - + # Anthropic API requires system prompt as separate parameter, not in messages response = client.messages.create( model=self.config.claude_model, @@ -486,46 +522,47 @@ async def _call_llm_with_cache(self, tools=tools if tools else None, tool_choice={"type": "any"} if tools else None, ) - + return response - + except Exception as e: logger.error("LLM call failed", error=str(e), cache_mode=cache_mode) raise ConnectionError(f"LLM call failed: {e}") - + async def _execute_tool_calls(self, tool_calls: List[Any]) -> List[Dict[str, Any]]: """ Execute tool calls and format results as messages. - + Args: tool_calls: List of tool calls from LLM - + Returns: List of tool result messages """ tool_messages = [] - + for call in tool_calls: try: # Record tool execution in metrics_collector - + # Extract tool information (Anthropic format) tool_name = call.name tool_args = call.input - + # tool_args should already be a dict in Anthropic format if isinstance(tool_args, str): import json + tool_args = json.loads(tool_args) # Get tool definition tool_definition = self.tool_registry.get_tool(tool_name) - + # Execute tool with proper async handling if asyncio.iscoroutinefunction(tool_definition.function): result = await tool_definition.function(**tool_args) else: result = tool_definition.function(**tool_args) - + # Format as tool message (Anthropic format) tool_message = { "role": "user", @@ -533,33 +570,35 @@ async def _execute_tool_calls(self, tool_calls: List[Any]) -> List[Dict[str, Any { "type": "tool_result", "tool_use_id": call.id, - "content": str(result) if not isinstance(result, str) else result + "content": ( + str(result) if not isinstance(result, str) else result + ), } - ] + ], } - - logger.info("Tool executed successfully", - tool_name=tool_name, - execution_time=result.get("execution_time_ms", 0)) + + logger.info( + "Tool executed successfully", + tool_name=tool_name, + execution_time=result.get("execution_time_ms", 0), + ) self.metrics_collector.record_tool_execution( tool_name=tool_name, success=True, execution_time=result.get("execution_time_ms", 0), - metadata={"args": tool_args} + metadata={"args": tool_args}, ) - + except Exception as e: - logger.error("Tool execution failed", - tool_name=tool_name, - error=str(e)) + logger.error("Tool execution failed", tool_name=tool_name, error=str(e)) self.metrics_collector.record_tool_execution( tool_name=tool_name, success=False, execution_time=0, error_type=str(e), - metadata={"args": tool_args} + metadata={"args": tool_args}, ) - + # Format error as tool message (Anthropic format) tool_message = { "role": "user", @@ -567,29 +606,31 @@ async def _execute_tool_calls(self, tool_calls: List[Any]) -> List[Dict[str, Any { "type": "tool_result", "tool_use_id": call.id, - "content": str({ - "error": "Tool execution failed", - "details": str(e), - "status": "failed" - }) + "content": str( + { + "error": "Tool execution failed", + "details": str(e), + "status": "failed", + } + ), } - ] + ], } - + tool_messages.append(tool_message) - + return tool_messages - + def get_metrics(self) -> Dict[str, Any]: """ Get system performance metrics. - + Returns: Dictionary containing performance metrics """ # Use the unified metrics collector for system metrics sys_metrics = self.metrics_collector.get_system_metrics() - + # Get cache metrics including circuit breaker metrics cache_metrics = {} if not self._cache_degraded: @@ -597,11 +638,11 @@ def get_metrics(self) -> Dict[str, Any]: if self.resilient_cache: # Get comprehensive metrics from resilient cache resilient_metrics = self.resilient_cache.get_metrics() - + # Extract cache metrics cache_data = resilient_metrics.get("cache", {}) circuit_data = resilient_metrics.get("circuit_breaker", {}) - + cache_metrics = { "cache_hit_rate": cache_data.get("hit_rate", 0), "cache_hits": cache_data.get("cache_hits", 0), @@ -609,13 +650,20 @@ def get_metrics(self) -> Dict[str, Any]: "cache_total_entries": cache_data.get("total_entries", 0), "cache_total_size": cache_data.get("total_size", 0), "cache_token_efficiency": cache_data.get("token_efficiency", 0), - # Circuit breaker metrics "circuit_breaker_state": circuit_data.get("state", "unknown"), - "circuit_breaker_failures": circuit_data.get("failure_count", 0), - "circuit_breaker_successes": circuit_data.get("success_count", 0), - "circuit_breaker_failure_rate": circuit_data.get("failure_rate", 0), - "circuit_breaker_state_changes": circuit_data.get("state_changes", 0) + "circuit_breaker_failures": circuit_data.get( + "failure_count", 0 + ), + "circuit_breaker_successes": circuit_data.get( + "success_count", 0 + ), + "circuit_breaker_failure_rate": circuit_data.get( + "failure_rate", 0 + ), + "circuit_breaker_state_changes": circuit_data.get( + "state_changes", 0 + ), } elif self.cache_system: # Fallback to basic cache metrics @@ -627,7 +675,7 @@ def get_metrics(self) -> Dict[str, Any]: "cache_total_entries": basic_cache_metrics.total_entries, "cache_total_size": basic_cache_metrics.total_size, "cache_token_efficiency": basic_cache_metrics.token_efficiency, - "circuit_breaker_state": "disabled" + "circuit_breaker_state": "disabled", } except Exception as e: logger.warning("Failed to get cache metrics", error=str(e)) @@ -635,7 +683,7 @@ def get_metrics(self) -> Dict[str, Any]: "cache_hit_rate": 0, "cache_hits": 0, "cache_misses": 0, - "circuit_breaker_state": "error" + "circuit_breaker_state": "error", } else: cache_metrics = { @@ -643,21 +691,21 @@ def get_metrics(self) -> Dict[str, Any]: "cache_hits": 0, "cache_misses": 0, "cache_degraded": True, - "circuit_breaker_state": "degraded" + "circuit_breaker_state": "degraded", } - + return { "total_queries": sys_metrics.total_executions, "tool_executions": sys_metrics.total_executions, "error_rate": sys_metrics.error_rate, "initialized": self._initialized, - **cache_metrics + **cache_metrics, } - + async def shutdown(self) -> None: """Gracefully shutdown the FACT system.""" logger.info("Shutting down FACT system") - + # Shutdown resilient cache monitoring if self.resilient_cache: try: @@ -665,7 +713,7 @@ async def shutdown(self) -> None: self.resilient_cache = None except Exception as e: logger.warning("Error stopping cache monitoring", error=str(e)) - + # Shutdown cache circuit breaker if self.cache_circuit_breaker: try: @@ -673,17 +721,17 @@ async def shutdown(self) -> None: self.cache_circuit_breaker = None except Exception as e: logger.warning("Error stopping circuit breaker", error=str(e)) - + # Shutdown cache system if self.cache_system: await self.cache_system.shutdown() self.cache_system = None - + # Close database connections if self.database_manager: # Database manager handles its own cleanup pass - + self._initialized = False self._cache_degraded = False logger.info("FACT system shutdown complete") @@ -696,26 +744,26 @@ async def shutdown(self) -> None: async def get_driver(config: Optional[Config] = None) -> FACTDriver: """ Get or create the global FACT driver instance. - + Args: config: Optional configuration (only used for first creation) - + Returns: Initialized FACTDriver instance """ global _driver_instance - + if _driver_instance is None: _driver_instance = FACTDriver(config) await _driver_instance.initialize() - + return _driver_instance async def shutdown_driver() -> None: """Shutdown the global driver instance.""" global _driver_instance - + if _driver_instance: await _driver_instance.shutdown() _driver_instance = None @@ -724,14 +772,14 @@ async def shutdown_driver() -> None: async def process_user_query(query: str) -> str: """ Process user query using the global driver instance. - + This is a compatibility wrapper for the benchmarking framework. - + Args: query: User query string - + Returns: Response string """ driver = await get_driver() - return await driver.process_query(query) \ No newline at end of file + return await driver.process_query(query) diff --git a/src/core/errors.py b/src/core/errors.py index f08ca62..9dd8fc9 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -8,17 +8,21 @@ from typing import Optional, Dict, Any import structlog - logger = structlog.get_logger(__name__) class FACTError(Exception): """Base exception class for all FACT system errors.""" - - def __init__(self, message: str, error_code: Optional[str] = None, context: Optional[Dict[str, Any]] = None): + + def __init__( + self, + message: str, + error_code: Optional[str] = None, + context: Optional[Dict[str, Any]] = None, + ): """ Initialize FACT error. - + Args: message: Human-readable error message error_code: Optional error code for categorization @@ -32,91 +36,106 @@ def __init__(self, message: str, error_code: Optional[str] = None, context: Opti class ConfigurationError(FACTError): """Raised when system configuration is invalid or missing.""" + pass class ConnectionError(FACTError): """Raised when API connectivity tests fail.""" + pass class AuthenticationError(FACTError): """Raised when API authentication fails.""" + pass class ValidationError(FACTError): """Raised when input validation fails.""" + pass class ToolExecutionError(FACTError): """Raised when tool execution fails.""" + pass class ToolValidationError(FACTError): """Raised when tool validation fails.""" + pass class ToolNotFoundError(FACTError): """Raised when a requested tool is not found.""" + pass class UnauthorizedError(FACTError): """Raised when user lacks authorization for a tool.""" + pass class InvalidArgumentsError(FACTError): """Raised when tool arguments are invalid.""" + pass class DatabaseError(FACTError): """Raised when database operations fail.""" + pass class SecurityError(FACTError): """Raised when security violations are detected.""" + pass class InvalidSQLError(FACTError): """Raised when SQL statements are invalid or dangerous.""" + pass class CacheError(FACTError): """Raised when cache operations fail.""" + pass class ToolRegistrationError(FACTError): """Raised when tool registration fails.""" + pass class FinalRetryError(FACTError): """Raised when maximum retry attempts are exceeded.""" + pass def classify_error(error: Exception) -> str: """ Classify an error into a category for handling strategies. - + Args: error: Exception to classify - + Returns: String category of the error """ error_type = type(error) - + # Map error types to categories error_categories = { ConfigurationError: "configuration", @@ -135,29 +154,31 @@ def classify_error(error: Exception) -> str: ToolRegistrationError: "tool_registration", FinalRetryError: "connectivity", } - + category = error_categories.get(error_type, "unknown") - - logger.debug("Error classified", - error_type=error_type.__name__, - category=category, - message=str(error)) - + + logger.debug( + "Error classified", + error_type=error_type.__name__, + category=category, + message=str(error), + ) + return category def create_user_friendly_message(error: Exception) -> str: """ Create a user-friendly error message from an exception. - + Args: error: Exception to convert - + Returns: User-friendly error message """ category = classify_error(error) - + # Default user-friendly messages by category friendly_messages = { "configuration": "System configuration error. Please check your setup.", @@ -171,18 +192,20 @@ def create_user_friendly_message(error: Exception) -> str: "tool_registration": "Tool registration failed. Please check your tool configuration.", "unknown": "An unexpected error occurred. Please try again later.", } - + # Use specific message if available, otherwise use category default - if hasattr(error, 'message'): + if hasattr(error, "message"): return error.message - + return friendly_messages.get(category, str(error)) -def log_error_with_context(error: Exception, context: Optional[Dict[str, Any]] = None) -> None: +def log_error_with_context( + error: Exception, context: Optional[Dict[str, Any]] = None +) -> None: """ Log error with full context information. - + Args: error: Exception to log context: Additional context information @@ -192,25 +215,25 @@ def log_error_with_context(error: Exception, context: Optional[Dict[str, Any]] = "error_message": str(error), "error_category": classify_error(error), } - + # Add specific error context if available - if hasattr(error, 'context'): + if hasattr(error, "context"): error_context.update(error.context) - + # Add provided context if context: error_context.update(context) - + logger.error("Error occurred", **error_context) def provide_graceful_degradation(failed_component: str) -> str: """ Provide graceful degradation message for failed components. - + Args: failed_component: Name of the component that failed - + Returns: Graceful degradation message """ @@ -221,14 +244,13 @@ def provide_graceful_degradation(failed_component: str) -> str: "anthropic": "Claude API is temporarily unavailable. Please try again later.", "arcade": "Tool execution service is temporarily unavailable. Please try again later.", } - + message = degradation_messages.get( - failed_component, - "System is experiencing issues. Please try again later." + failed_component, "System is experiencing issues. Please try again later." ) - - logger.warning("Graceful degradation activated", - component=failed_component, - message=message) - - return message \ No newline at end of file + + logger.warning( + "Graceful degradation activated", component=failed_component, message=message + ) + + return message diff --git a/src/db/connection.py b/src/db/connection.py index 4d01985..4a1f181 100644 --- a/src/db/connection.py +++ b/src/db/connection.py @@ -22,24 +22,25 @@ SAMPLE_COMPANIES, SAMPLE_FINANCIAL_RECORDS, QueryResult, - validate_schema_integrity + validate_schema_integrity, ) except ImportError: # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from core.errors import DatabaseError, SecurityError, InvalidSQLError from db.models import ( DATABASE_SCHEMA, SAMPLE_COMPANIES, SAMPLE_FINANCIAL_RECORDS, QueryResult, - validate_schema_integrity + validate_schema_integrity, ) @@ -49,14 +50,14 @@ class AsyncConnectionPool: """ Async connection pool for SQLite database connections. - + Provides connection reuse and management to reduce connection overhead. """ - + def __init__(self, database_path: str, pool_size: int = 10): """ Initialize connection pool. - + Args: database_path: Path to SQLite database pool_size: Maximum number of connections in pool @@ -66,7 +67,7 @@ def __init__(self, database_path: str, pool_size: int = 10): self.pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size) self.created_connections = 0 self._lock = asyncio.Lock() - + async def get_connection(self) -> aiosqlite.Connection: """Get a connection from the pool or create a new one.""" try: @@ -80,15 +81,17 @@ async def get_connection(self) -> aiosqlite.Connection: if self.created_connections < self.pool_size: conn = await aiosqlite.connect(self.database_path) self.created_connections += 1 - logger.debug("Created new pooled connection", - total_connections=self.created_connections) + logger.debug( + "Created new pooled connection", + total_connections=self.created_connections, + ) return conn else: # Wait for a connection to become available conn = await self.pool.get() logger.debug("Retrieved connection from pool after wait") return conn - + async def return_connection(self, conn: aiosqlite.Connection): """Return a connection to the pool.""" try: @@ -99,9 +102,10 @@ async def return_connection(self, conn: aiosqlite.Connection): await conn.close() async with self._lock: self.created_connections -= 1 - logger.debug("Closed excess connection", - total_connections=self.created_connections) - + logger.debug( + "Closed excess connection", total_connections=self.created_connections + ) + async def close_all(self): """Close all connections in the pool.""" connections_closed = 0 @@ -112,25 +116,27 @@ async def close_all(self): connections_closed += 1 except asyncio.QueueEmpty: break - + async with self._lock: self.created_connections = 0 - - logger.info("Closed all pooled connections", connections_closed=connections_closed) + + logger.info( + "Closed all pooled connections", connections_closed=connections_closed + ) class DatabaseManager: """ Manages SQLite database connections and operations for the FACT system. - + Provides secure database access with read-only query validation, connection pooling, and performance monitoring. """ - + def __init__(self, database_path: str, pool_size: int = 10): """ Initialize database manager with connection pooling. - + Args: database_path: Path to SQLite database file pool_size: Maximum number of connections in pool @@ -141,18 +147,18 @@ def __init__(self, database_path: str, pool_size: int = 10): self._ensure_directory_exists() self._query_plan_cache = {} # Cache for validated queries self._cache_max_size = 1000 - + def _ensure_directory_exists(self) -> None: """Ensure the database directory exists.""" db_dir = os.path.dirname(self.database_path) if db_dir and not os.path.exists(db_dir): os.makedirs(db_dir, exist_ok=True) logger.info("Created database directory", path=db_dir) - + async def initialize_database(self) -> None: """ Initialize database with schema and sample data. - + Raises: DatabaseError: If database initialization fails """ @@ -161,41 +167,50 @@ async def initialize_database(self) -> None: # Execute schema creation await db.executescript(DATABASE_SCHEMA) await db.commit() - + # Check if data already exists cursor = await db.execute("SELECT COUNT(*) FROM companies") company_count = (await cursor.fetchone())[0] await cursor.close() - + # Check if financial_data table has data cursor = await db.execute("SELECT COUNT(*) FROM financial_data") financial_data_count = (await cursor.fetchone())[0] await cursor.close() - + # Check if benchmarks table has data cursor = await db.execute("SELECT COUNT(*) FROM benchmarks") benchmarks_count = (await cursor.fetchone())[0] await cursor.close() - + if company_count == 0: # Batch insert sample companies for better performance - await db.executemany(""" + await db.executemany( + """ INSERT INTO companies (name, symbol, sector, founded_year, employees, market_cap) VALUES (:name, :symbol, :sector, :founded_year, :employees, :market_cap) - """, SAMPLE_COMPANIES) - + """, + SAMPLE_COMPANIES, + ) + # Batch insert sample financial records - await db.executemany(""" + await db.executemany( + """ INSERT INTO financial_records (company_id, quarter, year, revenue, profit, expenses) VALUES (:company_id, :quarter, :year, :revenue, :profit, :expenses) - """, SAMPLE_FINANCIAL_RECORDS) - + """, + SAMPLE_FINANCIAL_RECORDS, + ) + # Batch insert into financial_data table for validation compatibility - await db.executemany(""" + await db.executemany( + """ INSERT INTO financial_data (company_id, quarter, year, revenue, profit, expenses) VALUES (:company_id, :quarter, :year, :revenue, :profit, :expenses) - """, SAMPLE_FINANCIAL_RECORDS) - + """, + SAMPLE_FINANCIAL_RECORDS, + ) + # Insert sample benchmark data sample_benchmarks = [ { @@ -205,7 +220,7 @@ async def initialize_database(self) -> None: "cache_hit_rate": 0.0, "average_response_time_ms": 0.0, "success_rate": 1.0, - "notes": "Initial system setup benchmark" + "notes": "Initial system setup benchmark", }, { "test_name": "cache_warming", @@ -214,33 +229,39 @@ async def initialize_database(self) -> None: "cache_hit_rate": 0.0, "average_response_time_ms": 18.02, "success_rate": 1.0, - "notes": "Cache warming performance test" - } + "notes": "Cache warming performance test", + }, ] - + for benchmark in sample_benchmarks: - await db.execute(""" + await db.execute( + """ INSERT INTO benchmarks (test_name, duration_ms, queries_executed, cache_hit_rate, average_response_time_ms, success_rate, notes) VALUES (:test_name, :duration_ms, :queries_executed, :cache_hit_rate, :average_response_time_ms, :success_rate, :notes) - """, benchmark) - + """, + benchmark, + ) + await db.commit() logger.info("Database initialized with sample data") - + # Handle the case where companies exist but other tables don't elif financial_data_count == 0: logger.info("Adding missing financial_data records") # Insert sample financial records into financial_data table for record in SAMPLE_FINANCIAL_RECORDS: - await db.execute(""" + await db.execute( + """ INSERT INTO financial_data (company_id, quarter, year, revenue, profit, expenses) VALUES (:company_id, :quarter, :year, :revenue, :profit, :expenses) - """, record) + """, + record, + ) await db.commit() logger.info("Financial data populated") - + if benchmarks_count == 0: logger.info("Adding missing benchmark records") # Insert sample benchmark data @@ -252,7 +273,7 @@ async def initialize_database(self) -> None: "cache_hit_rate": 0.0, "average_response_time_ms": 0.0, "success_rate": 1.0, - "notes": "Initial system setup benchmark" + "notes": "Initial system setup benchmark", }, { "test_name": "cache_warming", @@ -261,108 +282,139 @@ async def initialize_database(self) -> None: "cache_hit_rate": 0.0, "average_response_time_ms": 18.02, "success_rate": 1.0, - "notes": "Cache warming performance test" - } + "notes": "Cache warming performance test", + }, ] - + for benchmark in sample_benchmarks: - await db.execute(""" + await db.execute( + """ INSERT INTO benchmarks (test_name, duration_ms, queries_executed, cache_hit_rate, average_response_time_ms, success_rate, notes) VALUES (:test_name, :duration_ms, :queries_executed, :cache_hit_rate, :average_response_time_ms, :success_rate, :notes) - """, benchmark) + """, + benchmark, + ) await db.commit() logger.info("Benchmark data populated") else: - logger.info("Database already contains data, skipping sample data insertion") - + logger.info( + "Database already contains data, skipping sample data insertion" + ) + # Validate schema integrity - cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table'") + cursor = await db.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ) tables = await cursor.fetchall() await cursor.close() - + tables_info = [{"name": table[0]} for table in tables] if not validate_schema_integrity(tables_info): raise DatabaseError("Database schema validation failed") - + except Exception as e: logger.error("Database initialization failed", error=str(e)) raise DatabaseError(f"Failed to initialize database: {e}") + def validate_sql_query(self, statement: str) -> None: """ Validate SQL query for security and syntax with caching. - + Args: statement: SQL statement to validate - + Raises: SecurityError: If statement contains dangerous operations InvalidSQLError: If statement has syntax errors """ # Generate cache key for query validation import hashlib + query_hash = hashlib.md5(statement.encode()).hexdigest() - + # Check cache first if query_hash in self._query_plan_cache: logger.debug("Query validation cache hit", statement=statement[:100]) return - + normalized_statement = statement.lower().strip() - + # Security check: allow SELECT statements and safe PRAGMA queries is_select = normalized_statement.startswith("select") is_safe_pragma = normalized_statement.startswith("pragma table_info") - + if not (is_select or is_safe_pragma): - raise SecurityError("Only SELECT statements and PRAGMA table_info queries are allowed") - + raise SecurityError( + "Only SELECT statements and PRAGMA table_info queries are allowed" + ) + # Enhanced dangerous keyword detection (excluding safe pragma) dangerous_keywords = [ - "drop", "delete", "update", "insert", "alter", "create", - "truncate", "replace", "merge", "exec", "execute", - "attach", "detach", "vacuum", "reindex", "analyze" + "drop", + "delete", + "update", + "insert", + "alter", + "create", + "truncate", + "replace", + "merge", + "exec", + "execute", + "attach", + "detach", + "vacuum", + "reindex", + "analyze", ] - + # For PRAGMA queries, only allow table_info - if normalized_statement.startswith("pragma") and not normalized_statement.startswith("pragma table_info"): + if normalized_statement.startswith( + "pragma" + ) and not normalized_statement.startswith("pragma table_info"): raise SecurityError("Only PRAGMA table_info queries are allowed") - + # Check for dangerous keywords with word boundaries import re + for keyword in dangerous_keywords: - pattern = r'\b' + re.escape(keyword) + r'\b' + pattern = r"\b" + re.escape(keyword) + r"\b" if re.search(pattern, normalized_statement, re.IGNORECASE): raise SecurityError(f"Dangerous SQL keyword detected: {keyword}") - + # Check for SQL injection patterns # Check for SQL injection patterns (skip for safe PRAGMA queries) if not normalized_statement.startswith("pragma table_info"): injection_patterns = [ - r'--', # SQL comments - r'/\*.*?\*/', # Multi-line comments - r';\s*\w+', # Multiple statements - r'\bunion\s+select\b', # Union injection - r'\bor\s+1\s*=\s*1\b', # Always true conditions - r'\band\s+1\s*=\s*1\b', # Always true conditions - r'\bor\s+\'.*?\'\s*=\s*\'.*?\'', # Suspicious OR with string comparisons - r'\'.*?\'\s*or\s*\'.*?\'', # Injection with OR between quotes - r'\\x[0-9a-f]{2}', # Hex encoding + r"--", # SQL comments + r"/\*.*?\*/", # Multi-line comments + r";\s*\w+", # Multiple statements + r"\bunion\s+select\b", # Union injection + r"\bor\s+1\s*=\s*1\b", # Always true conditions + r"\band\s+1\s*=\s*1\b", # Always true conditions + r"\bor\s+\'.*?\'\s*=\s*\'.*?\'", # Suspicious OR with string comparisons + r"\'.*?\'\s*or\s*\'.*?\'", # Injection with OR between quotes + r"\\x[0-9a-f]{2}", # Hex encoding ] - + for pattern in injection_patterns: if re.search(pattern, normalized_statement, re.IGNORECASE): - raise SecurityError(f"Potential SQL injection pattern detected: {pattern} in query: {normalized_statement[:100]}") + raise SecurityError( + f"Potential SQL injection pattern detected: {pattern} in query: {normalized_statement[:100]}" + ) # Limit query complexity if len(statement) > 5000: raise SecurityError("Query too long - potential DoS attack") - + # Count nested subqueries to prevent complex injection attacks - subquery_count = normalized_statement.count('select') + subquery_count = normalized_statement.count("select") if subquery_count > 5: - raise SecurityError("Too many nested subqueries - potential injection attack") - + raise SecurityError( + "Too many nested subqueries - potential injection attack" + ) + # Basic syntax validation using actual database connection try: # Parse SQL to check syntax (without executing) @@ -370,59 +422,65 @@ def validate_sql_query(self, statement: str) -> None: conn.execute(f"EXPLAIN QUERY PLAN {statement}") except sqlite3.Error as e: raise InvalidSQLError(f"SQL syntax error: {e}") - + # Cache successful validation if len(self._query_plan_cache) >= self._cache_max_size: # Simple eviction: remove oldest entries oldest_keys = list(self._query_plan_cache.keys())[:100] for key in oldest_keys: del self._query_plan_cache[key] - + self._query_plan_cache[query_hash] = time.time() - logger.debug("SQL query validation passed and cached", statement=statement[:100]) - + logger.debug( + "SQL query validation passed and cached", statement=statement[:100] + ) + def _is_valid_table_name(self, table_name: str) -> bool: """ Validate table name to prevent SQL injection. - + Args: table_name: Table name to validate - + Returns: True if table name is valid """ import re - + # Table name should only contain alphanumeric characters and underscores - if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', table_name): + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", table_name): return False - + # Reasonable length limit if len(table_name) > 64: return False - + # Check against known system tables that should not be accessed forbidden_tables = { - 'sqlite_master', 'sqlite_temp_master', 'sqlite_sequence', - 'sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4' + "sqlite_master", + "sqlite_temp_master", + "sqlite_sequence", + "sqlite_stat1", + "sqlite_stat2", + "sqlite_stat3", + "sqlite_stat4", } - + if table_name.lower() in forbidden_tables: return False - + return True - logger.debug("SQL query validation passed", statement=statement[:100]) - + async def execute_query(self, statement: str) -> QueryResult: """ Execute a validated SQL query using connection pool. - + Args: statement: SQL SELECT statement to execute - + Returns: QueryResult containing rows, metadata, and timing - + Raises: DatabaseError: If query execution fails SecurityError: If statement violates security rules @@ -430,71 +488,75 @@ async def execute_query(self, statement: str) -> QueryResult: """ # Validate query first (with caching) self.validate_sql_query(statement) - + start_time = time.time() - + try: # Get connection from pool db = await self.connection_pool.get_connection() - + try: # Enable row factory for dictionary-like access db.row_factory = aiosqlite.Row - + cursor = await db.execute(statement) rows = await cursor.fetchall() - + # Convert rows to dictionaries structured_results = [] columns = [] - + if rows: # Get column names from first row columns = list(rows[0].keys()) - + # Convert each row to dictionary for row in rows: row_dict = {col: row[col] for col in columns} structured_results.append(row_dict) - + await cursor.close() - + end_time = time.time() execution_time_ms = (end_time - start_time) * 1000 - + result = QueryResult( rows=structured_results, row_count=len(structured_results), columns=columns, - execution_time_ms=execution_time_ms + execution_time_ms=execution_time_ms, ) - - logger.info("Query executed successfully", - statement=statement[:100], - row_count=result.row_count, - execution_time_ms=execution_time_ms) - + + logger.info( + "Query executed successfully", + statement=statement[:100], + row_count=result.row_count, + execution_time_ms=execution_time_ms, + ) + return result - + finally: # Always return connection to pool await self.connection_pool.return_connection(db) - + except Exception as e: end_time = time.time() execution_time_ms = (end_time - start_time) * 1000 - - logger.error("Query execution failed", - statement=statement[:100], - error=str(e), - execution_time_ms=execution_time_ms) - + + logger.error( + "Query execution failed", + statement=statement[:100], + error=str(e), + execution_time_ms=execution_time_ms, + ) + raise DatabaseError(f"Query execution failed: {e}") - + async def get_database_info(self) -> Dict[str, Any]: """ Get database metadata and statistics. - + Returns: Dictionary containing database information """ @@ -507,35 +569,41 @@ async def get_database_info(self) -> Dict[str, Any]: """) tables = await cursor.fetchall() await cursor.close() - + table_info = {} for (table_name,) in tables: # Validate table name to prevent injection if not self._is_valid_table_name(table_name): - logger.warning("Invalid table name detected", table_name=table_name) + logger.warning( + "Invalid table name detected", table_name=table_name + ) continue - + # Get row count for each table using parameterized query # Note: Table names cannot be parameterized, so we validate them first - cursor = await db.execute(f"SELECT COUNT(*) FROM \"{table_name}\"") + cursor = await db.execute(f'SELECT COUNT(*) FROM "{table_name}"') count = (await cursor.fetchone())[0] await cursor.close() table_info[table_name] = {"row_count": count} - + # Get database file size - file_size = os.path.getsize(self.database_path) if os.path.exists(self.database_path) else 0 - + file_size = ( + os.path.getsize(self.database_path) + if os.path.exists(self.database_path) + else 0 + ) + return { "database_path": self.database_path, "file_size_bytes": file_size, "tables": table_info, - "total_tables": len(tables) + "total_tables": len(tables), } - + except Exception as e: logger.error("Failed to get database info", error=str(e)) raise DatabaseError(f"Failed to get database info: {e}") - + async def cleanup(self): """ Cleanup database resources including connection pool. @@ -545,12 +613,12 @@ async def cleanup(self): logger.info("Database manager cleanup completed") except Exception as e: logger.error("Database cleanup failed", error=str(e)) - + @asynccontextmanager async def get_connection(self): """ Get an async database connection context manager. - + Yields: aiosqlite.Connection: Database connection """ @@ -568,11 +636,11 @@ async def get_connection(self): def create_database_manager(database_path: str) -> DatabaseManager: """ Create and initialize a database manager instance. - + Args: database_path: Path to SQLite database file - + Returns: Configured DatabaseManager instance """ - return DatabaseManager(database_path) \ No newline at end of file + return DatabaseManager(database_path) diff --git a/src/db/models.py b/src/db/models.py index 859b9a2..41e4a03 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -10,13 +10,13 @@ from datetime import datetime import structlog - logger = structlog.get_logger(__name__) @dataclass class FinancialRecord: """Represents a financial record in the database.""" + id: int company: str quarter: str @@ -31,6 +31,7 @@ class FinancialRecord: @dataclass class Company: """Represents a company in the database.""" + id: int name: str symbol: str @@ -45,6 +46,7 @@ class Company: @dataclass class QueryResult: """Represents the result of a database query.""" + rows: List[Dict[str, Any]] row_count: int columns: List[str] @@ -131,7 +133,7 @@ class QueryResult: "sector": "Technology", "founded_year": 1995, "employees": 50000, - "market_cap": 250000000000.0 + "market_cap": 250000000000.0, }, { "name": "FinanceFirst LLC", @@ -139,7 +141,7 @@ class QueryResult: "sector": "Financial Services", "founded_year": 1988, "employees": 25000, - "market_cap": 125000000000.0 + "market_cap": 125000000000.0, }, { "name": "HealthTech Solutions", @@ -147,7 +149,7 @@ class QueryResult: "sector": "Healthcare", "founded_year": 2005, "employees": 15000, - "market_cap": 75000000000.0 + "market_cap": 75000000000.0, }, { "name": "Green Energy Corp", @@ -155,7 +157,7 @@ class QueryResult: "sector": "Energy", "founded_year": 2010, "employees": 8000, - "market_cap": 45000000000.0 + "market_cap": 45000000000.0, }, { "name": "RetailMax Group", @@ -163,52 +165,223 @@ class QueryResult: "sector": "Retail", "founded_year": 1975, "employees": 120000, - "market_cap": 95000000000.0 - } + "market_cap": 95000000000.0, + }, ] SAMPLE_FINANCIAL_RECORDS = [ # TechCorp Inc. (company_id: 1) - {"company_id": 1, "quarter": "Q1", "year": 2025, "revenue": 25000000000.0, "profit": 5000000000.0, "expenses": 20000000000.0}, - {"company_id": 1, "quarter": "Q4", "year": 2024, "revenue": 28000000000.0, "profit": 6000000000.0, "expenses": 22000000000.0}, - {"company_id": 1, "quarter": "Q3", "year": 2024, "revenue": 26000000000.0, "profit": 5500000000.0, "expenses": 20500000000.0}, - {"company_id": 1, "quarter": "Q2", "year": 2024, "revenue": 24000000000.0, "profit": 4800000000.0, "expenses": 19200000000.0}, - {"company_id": 1, "quarter": "Q1", "year": 2024, "revenue": 23000000000.0, "profit": 4600000000.0, "expenses": 18400000000.0}, - + { + "company_id": 1, + "quarter": "Q1", + "year": 2025, + "revenue": 25000000000.0, + "profit": 5000000000.0, + "expenses": 20000000000.0, + }, + { + "company_id": 1, + "quarter": "Q4", + "year": 2024, + "revenue": 28000000000.0, + "profit": 6000000000.0, + "expenses": 22000000000.0, + }, + { + "company_id": 1, + "quarter": "Q3", + "year": 2024, + "revenue": 26000000000.0, + "profit": 5500000000.0, + "expenses": 20500000000.0, + }, + { + "company_id": 1, + "quarter": "Q2", + "year": 2024, + "revenue": 24000000000.0, + "profit": 4800000000.0, + "expenses": 19200000000.0, + }, + { + "company_id": 1, + "quarter": "Q1", + "year": 2024, + "revenue": 23000000000.0, + "profit": 4600000000.0, + "expenses": 18400000000.0, + }, # FinanceFirst LLC (company_id: 2) - {"company_id": 2, "quarter": "Q1", "year": 2025, "revenue": 12000000000.0, "profit": 3000000000.0, "expenses": 9000000000.0}, - {"company_id": 2, "quarter": "Q4", "year": 2024, "revenue": 13500000000.0, "profit": 3200000000.0, "expenses": 10300000000.0}, - {"company_id": 2, "quarter": "Q3", "year": 2024, "revenue": 13000000000.0, "profit": 3100000000.0, "expenses": 9900000000.0}, - {"company_id": 2, "quarter": "Q2", "year": 2024, "revenue": 12500000000.0, "profit": 2900000000.0, "expenses": 9600000000.0}, - {"company_id": 2, "quarter": "Q1", "year": 2024, "revenue": 11800000000.0, "profit": 2800000000.0, "expenses": 9000000000.0}, - + { + "company_id": 2, + "quarter": "Q1", + "year": 2025, + "revenue": 12000000000.0, + "profit": 3000000000.0, + "expenses": 9000000000.0, + }, + { + "company_id": 2, + "quarter": "Q4", + "year": 2024, + "revenue": 13500000000.0, + "profit": 3200000000.0, + "expenses": 10300000000.0, + }, + { + "company_id": 2, + "quarter": "Q3", + "year": 2024, + "revenue": 13000000000.0, + "profit": 3100000000.0, + "expenses": 9900000000.0, + }, + { + "company_id": 2, + "quarter": "Q2", + "year": 2024, + "revenue": 12500000000.0, + "profit": 2900000000.0, + "expenses": 9600000000.0, + }, + { + "company_id": 2, + "quarter": "Q1", + "year": 2024, + "revenue": 11800000000.0, + "profit": 2800000000.0, + "expenses": 9000000000.0, + }, # HealthTech Solutions (company_id: 3) - {"company_id": 3, "quarter": "Q1", "year": 2025, "revenue": 8000000000.0, "profit": 1200000000.0, "expenses": 6800000000.0}, - {"company_id": 3, "quarter": "Q4", "year": 2024, "revenue": 8500000000.0, "profit": 1300000000.0, "expenses": 7200000000.0}, - {"company_id": 3, "quarter": "Q3", "year": 2024, "revenue": 8200000000.0, "profit": 1250000000.0, "expenses": 6950000000.0}, - {"company_id": 3, "quarter": "Q2", "year": 2024, "revenue": 7800000000.0, "profit": 1150000000.0, "expenses": 6650000000.0}, - {"company_id": 3, "quarter": "Q1", "year": 2024, "revenue": 7500000000.0, "profit": 1100000000.0, "expenses": 6400000000.0}, - + { + "company_id": 3, + "quarter": "Q1", + "year": 2025, + "revenue": 8000000000.0, + "profit": 1200000000.0, + "expenses": 6800000000.0, + }, + { + "company_id": 3, + "quarter": "Q4", + "year": 2024, + "revenue": 8500000000.0, + "profit": 1300000000.0, + "expenses": 7200000000.0, + }, + { + "company_id": 3, + "quarter": "Q3", + "year": 2024, + "revenue": 8200000000.0, + "profit": 1250000000.0, + "expenses": 6950000000.0, + }, + { + "company_id": 3, + "quarter": "Q2", + "year": 2024, + "revenue": 7800000000.0, + "profit": 1150000000.0, + "expenses": 6650000000.0, + }, + { + "company_id": 3, + "quarter": "Q1", + "year": 2024, + "revenue": 7500000000.0, + "profit": 1100000000.0, + "expenses": 6400000000.0, + }, # Green Energy Corp (company_id: 4) - {"company_id": 4, "quarter": "Q1", "year": 2025, "revenue": 5500000000.0, "profit": 800000000.0, "expenses": 4700000000.0}, - {"company_id": 4, "quarter": "Q4", "year": 2024, "revenue": 6000000000.0, "profit": 900000000.0, "expenses": 5100000000.0}, - {"company_id": 4, "quarter": "Q3", "year": 2024, "revenue": 5800000000.0, "profit": 850000000.0, "expenses": 4950000000.0}, - {"company_id": 4, "quarter": "Q2", "year": 2024, "revenue": 5200000000.0, "profit": 750000000.0, "expenses": 4450000000.0}, - {"company_id": 4, "quarter": "Q1", "year": 2024, "revenue": 4800000000.0, "profit": 680000000.0, "expenses": 4120000000.0}, - + { + "company_id": 4, + "quarter": "Q1", + "year": 2025, + "revenue": 5500000000.0, + "profit": 800000000.0, + "expenses": 4700000000.0, + }, + { + "company_id": 4, + "quarter": "Q4", + "year": 2024, + "revenue": 6000000000.0, + "profit": 900000000.0, + "expenses": 5100000000.0, + }, + { + "company_id": 4, + "quarter": "Q3", + "year": 2024, + "revenue": 5800000000.0, + "profit": 850000000.0, + "expenses": 4950000000.0, + }, + { + "company_id": 4, + "quarter": "Q2", + "year": 2024, + "revenue": 5200000000.0, + "profit": 750000000.0, + "expenses": 4450000000.0, + }, + { + "company_id": 4, + "quarter": "Q1", + "year": 2024, + "revenue": 4800000000.0, + "profit": 680000000.0, + "expenses": 4120000000.0, + }, # RetailMax Group (company_id: 5) - {"company_id": 5, "quarter": "Q1", "year": 2025, "revenue": 18000000000.0, "profit": 1800000000.0, "expenses": 16200000000.0}, - {"company_id": 5, "quarter": "Q4", "year": 2024, "revenue": 22000000000.0, "profit": 2400000000.0, "expenses": 19600000000.0}, # Holiday season - {"company_id": 5, "quarter": "Q3", "year": 2024, "revenue": 19000000000.0, "profit": 2000000000.0, "expenses": 17000000000.0}, - {"company_id": 5, "quarter": "Q2", "year": 2024, "revenue": 17500000000.0, "profit": 1750000000.0, "expenses": 15750000000.0}, - {"company_id": 5, "quarter": "Q1", "year": 2024, "revenue": 16800000000.0, "profit": 1650000000.0, "expenses": 15150000000.0}, + { + "company_id": 5, + "quarter": "Q1", + "year": 2025, + "revenue": 18000000000.0, + "profit": 1800000000.0, + "expenses": 16200000000.0, + }, + { + "company_id": 5, + "quarter": "Q4", + "year": 2024, + "revenue": 22000000000.0, + "profit": 2400000000.0, + "expenses": 19600000000.0, + }, # Holiday season + { + "company_id": 5, + "quarter": "Q3", + "year": 2024, + "revenue": 19000000000.0, + "profit": 2000000000.0, + "expenses": 17000000000.0, + }, + { + "company_id": 5, + "quarter": "Q2", + "year": 2024, + "revenue": 17500000000.0, + "profit": 1750000000.0, + "expenses": 15750000000.0, + }, + { + "company_id": 5, + "quarter": "Q1", + "year": 2024, + "revenue": 16800000000.0, + "profit": 1650000000.0, + "expenses": 15150000000.0, + }, ] def get_sample_queries() -> List[str]: """ Get list of sample queries for testing and demonstration. - + Returns: List of sample SQL queries """ @@ -219,27 +392,27 @@ def get_sample_queries() -> List[str]: "SELECT sector, COUNT(*) as company_count FROM companies GROUP BY sector", "SELECT c.name, f.quarter, f.year, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE c.symbol = 'TECH' ORDER BY f.year DESC, f.quarter DESC", "SELECT AVG(revenue) as avg_revenue, AVG(profit) as avg_profit FROM financial_records WHERE year = 2024", - "SELECT c.name, c.market_cap, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025 AND f.quarter = 'Q1' ORDER BY c.market_cap DESC" + "SELECT c.name, c.market_cap, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025 AND f.quarter = 'Q1' ORDER BY c.market_cap DESC", ] def validate_schema_integrity(tables_info: List[Dict[str, Any]]) -> bool: """ Validate that the database schema matches expected structure. - + Args: tables_info: Information about database tables - + Returns: True if schema is valid, False otherwise """ expected_tables = {"companies", "financial_records", "financial_data", "benchmarks"} actual_tables = {table["name"] for table in tables_info} - + if not expected_tables.issubset(actual_tables): missing_tables = expected_tables - actual_tables logger.error("Missing database tables", missing=list(missing_tables)) return False - + logger.info("Database schema validation passed") - return True \ No newline at end of file + return True diff --git a/src/monitoring/__init__.py b/src/monitoring/__init__.py index 6c4fae1..759f49f 100644 --- a/src/monitoring/__init__.py +++ b/src/monitoring/__init__.py @@ -7,14 +7,24 @@ # Use try/except to handle both relative and absolute imports try: - from .metrics import MetricsCollector, get_metrics_collector, SystemMetrics, ToolExecutionMetric + from .metrics import ( + MetricsCollector, + get_metrics_collector, + SystemMetrics, + ToolExecutionMetric, + ) except ImportError: # Fallback to absolute imports when called from scripts - from monitoring.metrics import MetricsCollector, get_metrics_collector, SystemMetrics, ToolExecutionMetric + from monitoring.metrics import ( + MetricsCollector, + get_metrics_collector, + SystemMetrics, + ToolExecutionMetric, + ) __all__ = [ - 'MetricsCollector', - 'get_metrics_collector', - 'SystemMetrics', - 'ToolExecutionMetric' -] \ No newline at end of file + "MetricsCollector", + "get_metrics_collector", + "SystemMetrics", + "ToolExecutionMetric", +] diff --git a/src/monitoring/metrics.py b/src/monitoring/metrics.py index 83c01d8..1f3fe52 100644 --- a/src/monitoring/metrics.py +++ b/src/monitoring/metrics.py @@ -18,6 +18,7 @@ @dataclass class ToolExecutionMetric: """Represents a single tool execution metric.""" + tool_name: str success: bool execution_time_ms: float @@ -30,11 +31,12 @@ class ToolExecutionMetric: @dataclass class SystemMetrics: """Aggregated system metrics.""" + total_executions: int = 0 successful_executions: int = 0 failed_executions: int = 0 average_execution_time: float = 0.0 - min_execution_time: float = float('inf') + min_execution_time: float = float("inf") max_execution_time: float = 0.0 error_rate: float = 0.0 executions_per_minute: float = 0.0 @@ -45,47 +47,51 @@ class SystemMetrics: class MetricsCollector: """ Collects and aggregates performance metrics for the FACT system. - + Provides real-time metrics collection, aggregation, and reporting for tool execution performance and system health monitoring. """ - + def __init__(self, max_history: int = 10000): """ Initialize metrics collector. - + Args: max_history: Maximum number of metrics to keep in memory """ self.max_history = max_history self.metrics_history: deque = deque(maxlen=max_history) - self.tool_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: { - 'count': 0, - 'success_count': 0, - 'total_time': 0.0, - 'min_time': float('inf'), - 'max_time': 0.0, - 'recent_executions': deque(maxlen=100) - }) - + self.tool_stats: Dict[str, Dict[str, Any]] = defaultdict( + lambda: { + "count": 0, + "success_count": 0, + "total_time": 0.0, + "min_time": float("inf"), + "max_time": 0.0, + "recent_executions": deque(maxlen=100), + } + ) + # Thread safety self._lock = threading.RLock() - + # Start time for rate calculations self._start_time = time.time() - + logger.info("MetricsCollector initialized", max_history=max_history) - - def record_tool_execution(self, - tool_name: str, - success: bool, - execution_time: float, - user_id: Optional[str] = None, - error_type: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None) -> None: + + def record_tool_execution( + self, + tool_name: str, + success: bool, + execution_time: float, + user_id: Optional[str] = None, + error_type: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: """ Record a tool execution metric. - + Args: tool_name: Name of the executed tool success: Whether execution was successful @@ -96,7 +102,7 @@ def record_tool_execution(self, """ with self._lock: timestamp = time.time() - + # Create metric record metric = ToolExecutionMetric( tool_name=tool_name, @@ -105,81 +111,92 @@ def record_tool_execution(self, timestamp=timestamp, user_id=user_id, error_type=error_type, - metadata=metadata or {} + metadata=metadata or {}, ) - + # Add to history self.metrics_history.append(metric) - + # Update tool-specific stats tool_stat = self.tool_stats[tool_name] - tool_stat['count'] += 1 - tool_stat['total_time'] += execution_time - tool_stat['recent_executions'].append({ - 'timestamp': timestamp, - 'success': success, - 'execution_time': execution_time - }) - + tool_stat["count"] += 1 + tool_stat["total_time"] += execution_time + tool_stat["recent_executions"].append( + { + "timestamp": timestamp, + "success": success, + "execution_time": execution_time, + } + ) + if success: - tool_stat['success_count'] += 1 - + tool_stat["success_count"] += 1 + # Update min/max times - tool_stat['min_time'] = min(tool_stat['min_time'], execution_time) - tool_stat['max_time'] = max(tool_stat['max_time'], execution_time) - - logger.debug("Tool execution metric recorded", - tool_name=tool_name, - success=success, - execution_time_ms=execution_time, - user_id=user_id) - + tool_stat["min_time"] = min(tool_stat["min_time"], execution_time) + tool_stat["max_time"] = max(tool_stat["max_time"], execution_time) + + logger.debug( + "Tool execution metric recorded", + tool_name=tool_name, + success=success, + execution_time_ms=execution_time, + user_id=user_id, + ) + def get_system_metrics(self, time_window_minutes: int = 60) -> SystemMetrics: """ Get aggregated system metrics. - + Args: time_window_minutes: Time window for calculations in minutes - + Returns: SystemMetrics object with aggregated data """ with self._lock: current_time = time.time() cutoff_time = current_time - (time_window_minutes * 60) - + # Filter metrics within time window recent_metrics = [ - metric for metric in self.metrics_history + metric + for metric in self.metrics_history if metric.timestamp >= cutoff_time ] - + if not recent_metrics: return SystemMetrics() - + # Calculate basic stats total_executions = len(recent_metrics) successful_executions = sum(1 for m in recent_metrics if m.success) failed_executions = total_executions - successful_executions - + execution_times = [m.execution_time_ms for m in recent_metrics] average_execution_time = sum(execution_times) / len(execution_times) min_execution_time = min(execution_times) max_execution_time = max(execution_times) - - error_rate = (failed_executions / total_executions) * 100 if total_executions > 0 else 0 + + error_rate = ( + (failed_executions / total_executions) * 100 + if total_executions > 0 + else 0 + ) executions_per_minute = total_executions / time_window_minutes - + # Get top tools by usage tool_counts = defaultdict(int) for metric in recent_metrics: tool_counts[metric.tool_name] += 1 - + top_tools = [ {"tool_name": tool, "count": count} - for tool, count in sorted(tool_counts.items(), key=lambda x: x[1], reverse=True)[:10] + for tool, count in sorted( + tool_counts.items(), key=lambda x: x[1], reverse=True + )[:10] ] - + # Get recent errors recent_errors = [ { @@ -187,11 +204,14 @@ def get_system_metrics(self, time_window_minutes: int = 60) -> SystemMetrics: "error_type": m.error_type, "timestamp": m.timestamp, "execution_time_ms": m.execution_time_ms, - "user_id": m.user_id + "user_id": m.user_id, } - for m in recent_metrics if not m.success - ][-10:] # Last 10 errors - + for m in recent_metrics + if not m.success + ][ + -10: + ] # Last 10 errors + return SystemMetrics( total_executions=total_executions, successful_executions=successful_executions, @@ -202,16 +222,16 @@ def get_system_metrics(self, time_window_minutes: int = 60) -> SystemMetrics: error_rate=error_rate, executions_per_minute=executions_per_minute, top_tools=top_tools, - recent_errors=recent_errors + recent_errors=recent_errors, ) - + def get_tool_metrics(self, tool_name: str) -> Dict[str, Any]: """ Get detailed metrics for a specific tool. - + Args: tool_name: Name of the tool - + Returns: Tool-specific metrics """ @@ -223,96 +243,114 @@ def get_tool_metrics(self, tool_name: str) -> Dict[str, Any]: "success_rate": 0.0, "average_execution_time": 0.0, "min_execution_time": 0.0, - "max_execution_time": 0.0 + "max_execution_time": 0.0, } - + stats = self.tool_stats[tool_name] - + # Calculate derived metrics - success_rate = (stats['success_count'] / stats['count']) * 100 if stats['count'] > 0 else 0 - average_time = stats['total_time'] / stats['count'] if stats['count'] > 0 else 0 - + success_rate = ( + (stats["success_count"] / stats["count"]) * 100 + if stats["count"] > 0 + else 0 + ) + average_time = ( + stats["total_time"] / stats["count"] if stats["count"] > 0 else 0 + ) + # Recent performance (last 24 hours) current_time = time.time() cutoff_time = current_time - (24 * 60 * 60) # 24 hours ago - + recent_executions = [ - exec for exec in stats['recent_executions'] - if exec['timestamp'] >= cutoff_time + exec + for exec in stats["recent_executions"] + if exec["timestamp"] >= cutoff_time ] - + recent_success_rate = 0.0 recent_avg_time = 0.0 - + if recent_executions: - recent_successes = sum(1 for exec in recent_executions if exec['success']) + recent_successes = sum( + 1 for exec in recent_executions if exec["success"] + ) recent_success_rate = (recent_successes / len(recent_executions)) * 100 - recent_avg_time = sum(exec['execution_time'] for exec in recent_executions) / len(recent_executions) - + recent_avg_time = sum( + exec["execution_time"] for exec in recent_executions + ) / len(recent_executions) + return { "tool_name": tool_name, - "total_executions": stats['count'], - "successful_executions": stats['success_count'], - "failed_executions": stats['count'] - stats['success_count'], + "total_executions": stats["count"], + "successful_executions": stats["success_count"], + "failed_executions": stats["count"] - stats["success_count"], "success_rate": success_rate, "average_execution_time": average_time, - "min_execution_time": stats['min_time'] if stats['min_time'] != float('inf') else 0, - "max_execution_time": stats['max_time'], + "min_execution_time": ( + stats["min_time"] if stats["min_time"] != float("inf") else 0 + ), + "max_execution_time": stats["max_time"], "recent_24h": { "executions": len(recent_executions), "success_rate": recent_success_rate, - "average_execution_time": recent_avg_time - } + "average_execution_time": recent_avg_time, + }, } - - def get_user_metrics(self, user_id: str, time_window_minutes: int = 60) -> Dict[str, Any]: + + def get_user_metrics( + self, user_id: str, time_window_minutes: int = 60 + ) -> Dict[str, Any]: """ Get metrics for a specific user. - + Args: user_id: User identifier time_window_minutes: Time window in minutes - + Returns: User-specific metrics """ with self._lock: current_time = time.time() cutoff_time = current_time - (time_window_minutes * 60) - + # Filter metrics for this user within time window user_metrics = [ - metric for metric in self.metrics_history + metric + for metric in self.metrics_history if metric.user_id == user_id and metric.timestamp >= cutoff_time ] - + if not user_metrics: return { "user_id": user_id, "total_executions": 0, "success_rate": 0.0, "tools_used": [], - "average_execution_time": 0.0 + "average_execution_time": 0.0, } - + # Calculate user stats total_executions = len(user_metrics) successful_executions = sum(1 for m in user_metrics if m.success) success_rate = (successful_executions / total_executions) * 100 - + execution_times = [m.execution_time_ms for m in user_metrics] average_execution_time = sum(execution_times) / len(execution_times) - + # Tools used by this user tool_usage = defaultdict(int) for metric in user_metrics: tool_usage[metric.tool_name] += 1 - + tools_used = [ {"tool_name": tool, "count": count} - for tool, count in sorted(tool_usage.items(), key=lambda x: x[1], reverse=True) + for tool, count in sorted( + tool_usage.items(), key=lambda x: x[1], reverse=True + ) ] - + return { "user_id": user_id, "time_window_minutes": time_window_minutes, @@ -321,17 +359,19 @@ def get_user_metrics(self, user_id: str, time_window_minutes: int = 60) -> Dict[ "failed_executions": total_executions - successful_executions, "success_rate": success_rate, "average_execution_time": average_execution_time, - "tools_used": tools_used + "tools_used": tools_used, } - - def get_performance_trends(self, time_window_hours: int = 24, bucket_minutes: int = 60) -> Dict[str, Any]: + + def get_performance_trends( + self, time_window_hours: int = 24, bucket_minutes: int = 60 + ) -> Dict[str, Any]: """ Get performance trends over time. - + Args: time_window_hours: Time window in hours bucket_minutes: Time bucket size in minutes - + Returns: Performance trend data """ @@ -339,41 +379,42 @@ def get_performance_trends(self, time_window_hours: int = 24, bucket_minutes: in current_time = time.time() cutoff_time = current_time - (time_window_hours * 60 * 60) bucket_size = bucket_minutes * 60 - + # Filter metrics within time window recent_metrics = [ - metric for metric in self.metrics_history + metric + for metric in self.metrics_history if metric.timestamp >= cutoff_time ] - + if not recent_metrics: return {"buckets": [], "summary": {}} - + # Create time buckets start_time = cutoff_time buckets = [] - + while start_time < current_time: bucket_end = start_time + bucket_size bucket_metrics = [ - m for m in recent_metrics - if start_time <= m.timestamp < bucket_end + m for m in recent_metrics if start_time <= m.timestamp < bucket_end ] - + if bucket_metrics: total = len(bucket_metrics) successful = sum(1 for m in bucket_metrics if m.success) execution_times = [m.execution_time_ms for m in bucket_metrics] - + bucket_data = { "timestamp": start_time, "total_executions": total, "successful_executions": successful, "failed_executions": total - successful, "success_rate": (successful / total) * 100, - "average_execution_time": sum(execution_times) / len(execution_times), + "average_execution_time": sum(execution_times) + / len(execution_times), "min_execution_time": min(execution_times), - "max_execution_time": max(execution_times) + "max_execution_time": max(execution_times), } else: bucket_data = { @@ -384,49 +425,47 @@ def get_performance_trends(self, time_window_hours: int = 24, bucket_minutes: in "success_rate": 0.0, "average_execution_time": 0.0, "min_execution_time": 0.0, - "max_execution_time": 0.0 + "max_execution_time": 0.0, } - + buckets.append(bucket_data) start_time = bucket_end - + # Calculate summary if recent_metrics: total_executions = len(recent_metrics) successful_executions = sum(1 for m in recent_metrics if m.success) all_execution_times = [m.execution_time_ms for m in recent_metrics] - + summary = { "time_window_hours": time_window_hours, "bucket_minutes": bucket_minutes, "total_executions": total_executions, "success_rate": (successful_executions / total_executions) * 100, - "average_execution_time": sum(all_execution_times) / len(all_execution_times), + "average_execution_time": sum(all_execution_times) + / len(all_execution_times), "min_execution_time": min(all_execution_times), - "max_execution_time": max(all_execution_times) + "max_execution_time": max(all_execution_times), } else: summary = {} - - return { - "buckets": buckets, - "summary": summary - } - + + return {"buckets": buckets, "summary": summary} + def export_metrics(self, format: str = "json") -> str: """ Export metrics in specified format. - + Args: format: Export format ("json" or "csv") - + Returns: Exported metrics as string """ with self._lock: if format.lower() == "json": import json - + export_data = { "system_metrics": self.get_system_metrics().__dict__, "tool_metrics": { @@ -435,47 +474,55 @@ def export_metrics(self, format: str = "json") -> str: }, "performance_trends": self.get_performance_trends(), "export_timestamp": time.time(), - "total_metrics_count": len(self.metrics_history) + "total_metrics_count": len(self.metrics_history), } - + return json.dumps(export_data, indent=2) - + elif format.lower() == "csv": import csv import io - + output = io.StringIO() writer = csv.writer(output) - + # Write header - writer.writerow([ - "timestamp", "tool_name", "success", "execution_time_ms", - "user_id", "error_type" - ]) - + writer.writerow( + [ + "timestamp", + "tool_name", + "success", + "execution_time_ms", + "user_id", + "error_type", + ] + ) + # Write metrics for metric in self.metrics_history: - writer.writerow([ - metric.timestamp, - metric.tool_name, - metric.success, - metric.execution_time_ms, - metric.user_id or "", - metric.error_type or "" - ]) - + writer.writerow( + [ + metric.timestamp, + metric.tool_name, + metric.success, + metric.execution_time_ms, + metric.user_id or "", + metric.error_type or "", + ] + ) + return output.getvalue() - + else: raise ValueError(f"Unsupported export format: {format}") - + def clear_metrics(self, older_than_hours: Optional[int] = None) -> int: """ Clear metrics from memory. - + Args: older_than_hours: Only clear metrics older than this many hours - + Returns: Number of metrics cleared """ @@ -490,17 +537,23 @@ def clear_metrics(self, older_than_hours: Optional[int] = None) -> int: else: # Clear only old metrics cutoff_time = time.time() - (older_than_hours * 60 * 60) - + original_count = len(self.metrics_history) self.metrics_history = deque( - (metric for metric in self.metrics_history if metric.timestamp >= cutoff_time), - maxlen=self.max_history + ( + metric + for metric in self.metrics_history + if metric.timestamp >= cutoff_time + ), + maxlen=self.max_history, ) - + cleared_count = original_count - len(self.metrics_history) - logger.info("Old metrics cleared", - cleared_count=cleared_count, - older_than_hours=older_than_hours) + logger.info( + "Old metrics cleared", + cleared_count=cleared_count, + older_than_hours=older_than_hours, + ) return cleared_count @@ -510,4 +563,4 @@ def clear_metrics(self, older_than_hours: Optional[int] = None) -> int: def get_metrics_collector() -> MetricsCollector: """Get the global metrics collector instance.""" - return _metrics_collector \ No newline at end of file + return _metrics_collector diff --git a/src/monitoring/performance_optimizer.py b/src/monitoring/performance_optimizer.py index a98da57..3ac7f96 100644 --- a/src/monitoring/performance_optimizer.py +++ b/src/monitoring/performance_optimizer.py @@ -22,10 +22,11 @@ except ImportError: import sys from pathlib import Path + src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from cache.manager import CacheManager, get_cache_manager from cache.warming import CacheWarmer, get_cache_warmer from cache.metrics import MetricsCollector, get_metrics_collector @@ -36,6 +37,7 @@ class OptimizationStrategy(Enum): """Performance optimization strategies.""" + AGGRESSIVE_WARMING = "aggressive_warming" CONSERVATIVE_MEMORY = "conservative_memory" BALANCED = "balanced" @@ -46,6 +48,7 @@ class OptimizationStrategy(Enum): @dataclass class PerformanceTarget: """Performance targets and thresholds.""" + cache_hit_latency_ms: float = 48.0 cache_miss_latency_ms: float = 140.0 cache_hit_rate_percent: float = 60.0 @@ -57,6 +60,7 @@ class PerformanceTarget: @dataclass class OptimizationAction: """Represents an optimization action to be taken.""" + action_type: str parameters: Dict[str, Any] priority: int @@ -67,17 +71,19 @@ class OptimizationAction: class PerformanceOptimizer: """ Real-time performance optimizer for the FACT system. - + Monitors system performance and automatically applies optimizations to meet performance targets. """ - - def __init__(self, - cache_manager: Optional[CacheManager] = None, - targets: Optional[PerformanceTarget] = None): + + def __init__( + self, + cache_manager: Optional[CacheManager] = None, + targets: Optional[PerformanceTarget] = None, + ): """ Initialize performance optimizer. - + Args: cache_manager: Cache manager instance targets: Performance targets @@ -86,48 +92,50 @@ def __init__(self, self.cache_warmer = get_cache_warmer(self.cache_manager) self.metrics_collector = get_metrics_collector() self.benchmark_framework = BenchmarkFramework() - + self.targets = targets or PerformanceTarget() self.strategy = OptimizationStrategy.BALANCED - + # Optimization state self.optimization_enabled = True self.optimization_interval = 300 # 5 minutes self.last_optimization = 0.0 self.optimization_history: List[OptimizationAction] = [] - + # Performance tracking self.performance_windows = { - 'short': [], # Last 10 measurements - 'medium': [], # Last 50 measurements - 'long': [] # Last 200 measurements + "short": [], # Last 10 measurements + "medium": [], # Last 50 measurements + "long": [], # Last 200 measurements } - + # Optimization thresholds self.thresholds = { - 'hit_rate_critical': 40.0, - 'latency_critical': 200.0, - 'memory_critical': 95.0, - 'cost_efficiency_critical': 50.0 + "hit_rate_critical": 40.0, + "latency_critical": 200.0, + "memory_critical": 95.0, + "cost_efficiency_critical": 50.0, } - + # Thread safety self._lock = threading.RLock() self._optimization_task: Optional[asyncio.Task] = None - - logger.info("Performance optimizer initialized", - strategy=self.strategy.value, - targets=self.targets) - + + logger.info( + "Performance optimizer initialized", + strategy=self.strategy.value, + targets=self.targets, + ) + async def start_optimization_loop(self): """Start the continuous optimization loop.""" if self._optimization_task and not self._optimization_task.done(): logger.warning("Optimization loop already running") return - + self._optimization_task = asyncio.create_task(self._optimization_loop()) logger.info("Performance optimization loop started") - + async def stop_optimization_loop(self): """Stop the continuous optimization loop.""" if self._optimization_task: @@ -137,360 +145,421 @@ async def stop_optimization_loop(self): except asyncio.CancelledError: pass self._optimization_task = None - + logger.info("Performance optimization loop stopped") - + async def _optimization_loop(self): """Main optimization loop.""" while True: try: await asyncio.sleep(self.optimization_interval) - + if self.optimization_enabled: await self.optimize_performance() - + except asyncio.CancelledError: logger.info("Optimization loop cancelled") break except Exception as e: logger.error("Optimization loop error", error=str(e)) await asyncio.sleep(30) # Brief pause before retry - + async def optimize_performance(self) -> List[OptimizationAction]: """ Analyze current performance and apply optimizations. - + Returns: List of optimization actions taken """ try: with self._lock: logger.info("Starting performance optimization cycle") - + # Collect current metrics current_metrics = self._collect_performance_metrics() - + # Update performance windows self._update_performance_windows(current_metrics) - + # Analyze performance issues issues = self._analyze_performance_issues(current_metrics) - + # Generate optimization actions actions = self._generate_optimization_actions(issues, current_metrics) - + # Execute high-priority actions executed_actions = [] for action in sorted(actions, key=lambda x: x.priority, reverse=True): if await self._execute_optimization_action(action): executed_actions.append(action) self.optimization_history.append(action) - + # Log optimization results self._log_optimization_results(executed_actions, current_metrics) - + self.last_optimization = time.time() - + return executed_actions - + except Exception as e: logger.error("Performance optimization failed", error=str(e)) return [] - + def _collect_performance_metrics(self) -> Dict[str, Any]: """Collect comprehensive performance metrics.""" try: # Cache metrics cache_metrics = self.cache_manager.get_metrics() cache_perf_stats = self.cache_manager.get_performance_stats() - + # Cache health - health_metrics = self.metrics_collector.get_cache_health_score(self.cache_manager) - + health_metrics = self.metrics_collector.get_cache_health_score( + self.cache_manager + ) + # Optimization metrics - opt_metrics = self.metrics_collector.track_optimization_metrics(self.cache_manager) - + opt_metrics = self.metrics_collector.track_optimization_metrics( + self.cache_manager + ) + # System metrics current_time = time.time() - memory_utilization = (cache_metrics.total_size / self.cache_manager.max_size_bytes * 100) - + memory_utilization = ( + cache_metrics.total_size / self.cache_manager.max_size_bytes * 100 + ) + return { - 'timestamp': current_time, - 'cache_metrics': cache_metrics, - 'performance_stats': cache_perf_stats, - 'health_metrics': health_metrics, - 'optimization_metrics': opt_metrics, - 'memory_utilization': memory_utilization, - 'hit_rate_percent': cache_metrics.hit_rate, - 'avg_hit_latency_ms': cache_perf_stats.get('avg_hit_latency_ms', 0), - 'avg_miss_latency_ms': cache_perf_stats.get('avg_miss_latency_ms', 0) + "timestamp": current_time, + "cache_metrics": cache_metrics, + "performance_stats": cache_perf_stats, + "health_metrics": health_metrics, + "optimization_metrics": opt_metrics, + "memory_utilization": memory_utilization, + "hit_rate_percent": cache_metrics.hit_rate, + "avg_hit_latency_ms": cache_perf_stats.get("avg_hit_latency_ms", 0), + "avg_miss_latency_ms": cache_perf_stats.get("avg_miss_latency_ms", 0), } - + except Exception as e: logger.error("Failed to collect performance metrics", error=str(e)) - return {'timestamp': time.time(), 'error': str(e)} - + return {"timestamp": time.time(), "error": str(e)} + def _update_performance_windows(self, metrics: Dict[str, Any]): """Update rolling performance windows.""" try: # Add to all windows for window_name, window in self.performance_windows.items(): window.append(metrics) - + # Trim windows to size - if len(self.performance_windows['short']) > 10: - self.performance_windows['short'].pop(0) - if len(self.performance_windows['medium']) > 50: - self.performance_windows['medium'].pop(0) - if len(self.performance_windows['long']) > 200: - self.performance_windows['long'].pop(0) - + if len(self.performance_windows["short"]) > 10: + self.performance_windows["short"].pop(0) + if len(self.performance_windows["medium"]) > 50: + self.performance_windows["medium"].pop(0) + if len(self.performance_windows["long"]) > 200: + self.performance_windows["long"].pop(0) + except Exception as e: logger.debug("Failed to update performance windows", error=str(e)) - + def _analyze_performance_issues(self, current_metrics: Dict[str, Any]) -> List[str]: """Analyze current performance and identify issues.""" issues = [] - + try: # Cache hit rate analysis - hit_rate = current_metrics.get('hit_rate_percent', 0) + hit_rate = current_metrics.get("hit_rate_percent", 0) if hit_rate < self.targets.cache_hit_rate_percent: - severity = "critical" if hit_rate < self.thresholds['hit_rate_critical'] else "moderate" + severity = ( + "critical" + if hit_rate < self.thresholds["hit_rate_critical"] + else "moderate" + ) issues.append(f"low_hit_rate_{severity}") - + # Latency analysis - hit_latency = current_metrics.get('avg_hit_latency_ms', 0) - miss_latency = current_metrics.get('avg_miss_latency_ms', 0) - + hit_latency = current_metrics.get("avg_hit_latency_ms", 0) + miss_latency = current_metrics.get("avg_miss_latency_ms", 0) + if hit_latency > self.targets.cache_hit_latency_ms: issues.append("high_hit_latency") if miss_latency > self.targets.cache_miss_latency_ms: issues.append("high_miss_latency") - + # Memory utilization - memory_util = current_metrics.get('memory_utilization', 0) + memory_util = current_metrics.get("memory_utilization", 0) if memory_util > self.targets.memory_utilization_max: - severity = "critical" if memory_util > self.thresholds['memory_critical'] else "moderate" + severity = ( + "critical" + if memory_util > self.thresholds["memory_critical"] + else "moderate" + ) issues.append(f"high_memory_usage_{severity}") - + # Cost efficiency - opt_metrics = current_metrics.get('optimization_metrics', {}) - warming_efficiency = opt_metrics.get('cache_warming_efficiency', 0) + opt_metrics = current_metrics.get("optimization_metrics", {}) + warming_efficiency = opt_metrics.get("cache_warming_efficiency", 0) if warming_efficiency < 50: issues.append("low_warming_efficiency") - + # Trend analysis - if len(self.performance_windows['medium']) >= 10: - recent_hit_rates = [m.get('hit_rate_percent', 0) for m in self.performance_windows['medium'][-10:]] + if len(self.performance_windows["medium"]) >= 10: + recent_hit_rates = [ + m.get("hit_rate_percent", 0) + for m in self.performance_windows["medium"][-10:] + ] if len(recent_hit_rates) >= 5: - trend = (recent_hit_rates[-1] - recent_hit_rates[0]) / len(recent_hit_rates) + trend = (recent_hit_rates[-1] - recent_hit_rates[0]) / len( + recent_hit_rates + ) if trend < -2: # Declining hit rate issues.append("declining_hit_rate") - + except Exception as e: logger.error("Performance analysis failed", error=str(e)) issues.append("analysis_error") - + return issues - - def _generate_optimization_actions(self, issues: List[str], metrics: Dict[str, Any]) -> List[OptimizationAction]: + + def _generate_optimization_actions( + self, issues: List[str], metrics: Dict[str, Any] + ) -> List[OptimizationAction]: """Generate optimization actions based on identified issues.""" actions = [] - + try: # Cache warming optimizations if "low_hit_rate_critical" in issues or "declining_hit_rate" in issues: - actions.append(OptimizationAction( - action_type="aggressive_cache_warming", - parameters={"max_queries": 50, "concurrent": True}, - priority=9, - estimated_impact=15.0 - )) + actions.append( + OptimizationAction( + action_type="aggressive_cache_warming", + parameters={"max_queries": 50, "concurrent": True}, + priority=9, + estimated_impact=15.0, + ) + ) elif "low_hit_rate_moderate" in issues: - actions.append(OptimizationAction( - action_type="moderate_cache_warming", - parameters={"max_queries": 30, "concurrent": True}, - priority=7, - estimated_impact=10.0 - )) - + actions.append( + OptimizationAction( + action_type="moderate_cache_warming", + parameters={"max_queries": 30, "concurrent": True}, + priority=7, + estimated_impact=10.0, + ) + ) + # Memory optimization if "high_memory_usage_critical" in issues: - actions.append(OptimizationAction( - action_type="emergency_memory_cleanup", - parameters={"target_utilization": 70.0}, - priority=10, - estimated_impact=20.0 - )) + actions.append( + OptimizationAction( + action_type="emergency_memory_cleanup", + parameters={"target_utilization": 70.0}, + priority=10, + estimated_impact=20.0, + ) + ) elif "high_memory_usage_moderate" in issues: - actions.append(OptimizationAction( - action_type="preemptive_cleanup", - parameters={"target_utilization": 80.0}, - priority=6, - estimated_impact=8.0 - )) - + actions.append( + OptimizationAction( + action_type="preemptive_cleanup", + parameters={"target_utilization": 80.0}, + priority=6, + estimated_impact=8.0, + ) + ) + # Latency optimizations if "high_hit_latency" in issues: - actions.append(OptimizationAction( - action_type="optimize_cache_access", - parameters={"enable_fast_lookup": True}, - priority=8, - estimated_impact=12.0 - )) - + actions.append( + OptimizationAction( + action_type="optimize_cache_access", + parameters={"enable_fast_lookup": True}, + priority=8, + estimated_impact=12.0, + ) + ) + # Warming efficiency improvements if "low_warming_efficiency" in issues: - actions.append(OptimizationAction( - action_type="optimize_warming_strategy", - parameters={"strategy": "intelligent_prioritization"}, - priority=5, - estimated_impact=7.0 - )) - + actions.append( + OptimizationAction( + action_type="optimize_warming_strategy", + parameters={"strategy": "intelligent_prioritization"}, + priority=5, + estimated_impact=7.0, + ) + ) + # Strategy adjustments - current_hit_rate = metrics.get('hit_rate_percent', 0) - if current_hit_rate < 40 and self.strategy != OptimizationStrategy.AGGRESSIVE_WARMING: - actions.append(OptimizationAction( - action_type="change_strategy", - parameters={"new_strategy": OptimizationStrategy.AGGRESSIVE_WARMING.value}, - priority=4, - estimated_impact=5.0 - )) - + current_hit_rate = metrics.get("hit_rate_percent", 0) + if ( + current_hit_rate < 40 + and self.strategy != OptimizationStrategy.AGGRESSIVE_WARMING + ): + actions.append( + OptimizationAction( + action_type="change_strategy", + parameters={ + "new_strategy": OptimizationStrategy.AGGRESSIVE_WARMING.value + }, + priority=4, + estimated_impact=5.0, + ) + ) + except Exception as e: logger.error("Failed to generate optimization actions", error=str(e)) - + return actions - + async def _execute_optimization_action(self, action: OptimizationAction) -> bool: """Execute a specific optimization action.""" try: - logger.info("Executing optimization action", - action_type=action.action_type, - priority=action.priority, - parameters=action.parameters) - + logger.info( + "Executing optimization action", + action_type=action.action_type, + priority=action.priority, + parameters=action.parameters, + ) + if action.action_type == "aggressive_cache_warming": await self.cache_warmer.warm_cache_intelligently( max_queries=action.parameters.get("max_queries", 50) ) return True - + elif action.action_type == "moderate_cache_warming": await self.cache_warmer.warm_cache_intelligently( max_queries=action.parameters.get("max_queries", 30) ) return True - + elif action.action_type == "emergency_memory_cleanup": target_util = action.parameters.get("target_utilization", 70.0) current_size = self.cache_manager._calculate_current_size() target_size = int(self.cache_manager.max_size_bytes * target_util / 100) - + if current_size > target_size: space_to_free = current_size - target_size self.cache_manager._intelligent_eviction(space_to_free) return True - + elif action.action_type == "preemptive_cleanup": self.cache_manager._maybe_preemptive_cleanup() return True - + elif action.action_type == "optimize_cache_access": # Enable performance optimizations self.cache_manager.optimization_enabled = True - self.cache_manager.fast_lookup_enabled = action.parameters.get("enable_fast_lookup", True) + self.cache_manager.fast_lookup_enabled = action.parameters.get( + "enable_fast_lookup", True + ) return True - + elif action.action_type == "optimize_warming_strategy": # Update warming strategy - strategy = action.parameters.get("strategy", "intelligent_prioritization") - self.cache_warmer.optimization_config['adaptive_priorities'] = True - self.cache_warmer.optimization_config['concurrent_warming'] = True + strategy = action.parameters.get( + "strategy", "intelligent_prioritization" + ) + self.cache_warmer.optimization_config["adaptive_priorities"] = True + self.cache_warmer.optimization_config["concurrent_warming"] = True return True - + elif action.action_type == "change_strategy": new_strategy = action.parameters.get("new_strategy") if new_strategy: self.strategy = OptimizationStrategy(new_strategy) logger.info("Strategy changed", new_strategy=new_strategy) return True - + else: - logger.warning("Unknown optimization action", action_type=action.action_type) + logger.warning( + "Unknown optimization action", action_type=action.action_type + ) return False - + except Exception as e: - logger.error("Failed to execute optimization action", - action_type=action.action_type, - error=str(e)) + logger.error( + "Failed to execute optimization action", + action_type=action.action_type, + error=str(e), + ) return False - - def _log_optimization_results(self, actions: List[OptimizationAction], metrics: Dict[str, Any]): + + def _log_optimization_results( + self, actions: List[OptimizationAction], metrics: Dict[str, Any] + ): """Log optimization results and impact.""" try: if actions: action_types = [a.action_type for a in actions] total_impact = sum(a.estimated_impact for a in actions) - - logger.info("Optimization cycle completed", - actions_taken=len(actions), - action_types=action_types, - estimated_total_impact=total_impact, - current_hit_rate=metrics.get('hit_rate_percent', 0), - memory_utilization=metrics.get('memory_utilization', 0)) + + logger.info( + "Optimization cycle completed", + actions_taken=len(actions), + action_types=action_types, + estimated_total_impact=total_impact, + current_hit_rate=metrics.get("hit_rate_percent", 0), + memory_utilization=metrics.get("memory_utilization", 0), + ) else: - logger.info("Optimization cycle completed - no actions needed", - current_hit_rate=metrics.get('hit_rate_percent', 0), - memory_utilization=metrics.get('memory_utilization', 0)) - + logger.info( + "Optimization cycle completed - no actions needed", + current_hit_rate=metrics.get("hit_rate_percent", 0), + memory_utilization=metrics.get("memory_utilization", 0), + ) + except Exception as e: logger.debug("Failed to log optimization results", error=str(e)) - + def get_optimization_status(self) -> Dict[str, Any]: """Get current optimization status and statistics.""" try: - recent_actions = self.optimization_history[-10:] if self.optimization_history else [] - + recent_actions = ( + self.optimization_history[-10:] if self.optimization_history else [] + ) + return { - 'optimization_enabled': self.optimization_enabled, - 'current_strategy': self.strategy.value, - 'last_optimization': self.last_optimization, - 'optimization_interval': self.optimization_interval, - 'total_optimizations': len(self.optimization_history), - 'recent_actions': [ + "optimization_enabled": self.optimization_enabled, + "current_strategy": self.strategy.value, + "last_optimization": self.last_optimization, + "optimization_interval": self.optimization_interval, + "total_optimizations": len(self.optimization_history), + "recent_actions": [ { - 'action_type': a.action_type, - 'execution_time': a.execution_time, - 'estimated_impact': a.estimated_impact - } for a in recent_actions + "action_type": a.action_type, + "execution_time": a.execution_time, + "estimated_impact": a.estimated_impact, + } + for a in recent_actions ], - 'performance_targets': { - 'cache_hit_latency_ms': self.targets.cache_hit_latency_ms, - 'cache_miss_latency_ms': self.targets.cache_miss_latency_ms, - 'cache_hit_rate_percent': self.targets.cache_hit_rate_percent, - 'cost_reduction_percent': self.targets.cost_reduction_percent - } + "performance_targets": { + "cache_hit_latency_ms": self.targets.cache_hit_latency_ms, + "cache_miss_latency_ms": self.targets.cache_miss_latency_ms, + "cache_hit_rate_percent": self.targets.cache_hit_rate_percent, + "cost_reduction_percent": self.targets.cost_reduction_percent, + }, } - + except Exception as e: logger.error("Failed to get optimization status", error=str(e)) - return {'error': str(e)} + return {"error": str(e)} # Global optimizer instance _performance_optimizer: Optional[PerformanceOptimizer] = None -def get_performance_optimizer(cache_manager: Optional[CacheManager] = None) -> PerformanceOptimizer: +def get_performance_optimizer( + cache_manager: Optional[CacheManager] = None, +) -> PerformanceOptimizer: """Get or create global performance optimizer instance.""" global _performance_optimizer - + if _performance_optimizer is None: _performance_optimizer = PerformanceOptimizer(cache_manager) - + return _performance_optimizer @@ -505,4 +574,4 @@ async def stop_performance_optimization(): """Stop performance optimization.""" optimizer = get_performance_optimizer() await optimizer.stop_optimization_loop() - logger.info("Performance optimization stopped") \ No newline at end of file + logger.info("Performance optimization stopped") diff --git a/src/security/__init__.py b/src/security/__init__.py index 1427c7b..b67e033 100644 --- a/src/security/__init__.py +++ b/src/security/__init__.py @@ -12,8 +12,4 @@ # Fallback to absolute imports when called from scripts from security.auth import AuthorizationManager, Authorization, AuthFlow -__all__ = [ - 'AuthorizationManager', - 'Authorization', - 'AuthFlow' -] \ No newline at end of file +__all__ = ["AuthorizationManager", "Authorization", "AuthFlow"] diff --git a/src/security/auth.py b/src/security/auth.py index 08cf06c..0b53356 100644 --- a/src/security/auth.py +++ b/src/security/auth.py @@ -13,25 +13,18 @@ try: # Try relative imports first (when used as package) - from ..core.errors import ( - AuthenticationError, - UnauthorizedError, - ValidationError - ) + from ..core.errors import AuthenticationError, UnauthorizedError, ValidationError except ImportError: # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - - from core.errors import ( - AuthenticationError, - UnauthorizedError, - ValidationError - ) + + from core.errors import AuthenticationError, UnauthorizedError, ValidationError logger = structlog.get_logger(__name__) @@ -40,6 +33,7 @@ @dataclass class Authorization: """Represents an authorization grant for a user and tool.""" + user_id: str tool_name: str scopes: List[str] @@ -47,21 +41,21 @@ class Authorization: refresh_token: Optional[str] = None expires_at: Optional[float] = None created_at: float = None - + def __post_init__(self): if self.created_at is None: self.created_at = time.time() - + def is_expired(self) -> bool: """Check if authorization is expired.""" if self.expires_at is None: return False return time.time() > self.expires_at - + def has_scope(self, scope: str) -> bool: """Check if authorization includes a specific scope.""" return scope in self.scopes - + def has_all_scopes(self, required_scopes: List[str]) -> bool: """Check if authorization includes all required scopes.""" return all(scope in self.scopes for scope in required_scopes) @@ -70,6 +64,7 @@ def has_all_scopes(self, required_scopes: List[str]) -> bool: @dataclass class AuthFlow: """Represents an ongoing OAuth authorization flow.""" + flow_id: str user_id: str tool_name: str @@ -78,13 +73,13 @@ class AuthFlow: status: str = "pending" created_at: float = None expires_at: float = None - + def __post_init__(self): if self.created_at is None: self.created_at = time.time() if self.expires_at is None: self.expires_at = self.created_at + 600 # 10 minutes - + def is_expired(self) -> bool: """Check if authorization flow is expired.""" return time.time() > self.expires_at @@ -93,18 +88,20 @@ def is_expired(self) -> bool: class AuthorizationManager: """ Manages user authorization for tool access. - + Provides OAuth integration, scope validation, and session management for secure tool execution. """ - - def __init__(self, - oauth_client_id: Optional[str] = None, - oauth_client_secret: Optional[str] = None, - oauth_redirect_uri: Optional[str] = None): + + def __init__( + self, + oauth_client_id: Optional[str] = None, + oauth_client_secret: Optional[str] = None, + oauth_redirect_uri: Optional[str] = None, + ): """ Initialize authorization manager. - + Args: oauth_client_id: OAuth client ID oauth_client_secret: OAuth client secret @@ -113,11 +110,11 @@ def __init__(self, self.oauth_client_id = oauth_client_id self.oauth_client_secret = oauth_client_secret self.oauth_redirect_uri = oauth_redirect_uri - + # In-memory storage for demo (use persistent storage in production) self.active_authorizations: Dict[str, Authorization] = {} self.pending_flows: Dict[str, AuthFlow] = {} - + # Scope definitions self.available_scopes = { "read": "Read access to data and tools", @@ -125,109 +122,130 @@ def __init__(self, "admin": "Administrative access to system functions", "sql": "Execute SQL queries", "file": "Access file system operations", - "web": "Make external web requests" + "web": "Make external web requests", } - - logger.info("AuthorizationManager initialized", - oauth_enabled=bool(oauth_client_id)) - - async def validate_authorization(self, - user_id: str, - tool_name: str, - required_scopes: Optional[List[str]] = None) -> Authorization: + + logger.info( + "AuthorizationManager initialized", oauth_enabled=bool(oauth_client_id) + ) + + async def validate_authorization( + self, user_id: str, tool_name: str, required_scopes: Optional[List[str]] = None + ) -> Authorization: """ Validate user authorization for tool access. - + Args: user_id: User identifier tool_name: Tool name requiring authorization required_scopes: Optional list of required scopes - + Returns: Valid Authorization object - + Raises: UnauthorizedError: If authorization is invalid or insufficient """ try: auth_key = f"{user_id}:{tool_name}" - + # Check if authorization exists if auth_key not in self.active_authorizations: - logger.warning("No authorization found", - user_id=user_id, - tool_name=tool_name) - raise UnauthorizedError(f"No authorization found for user {user_id} and tool {tool_name}") - + logger.warning( + "No authorization found", user_id=user_id, tool_name=tool_name + ) + raise UnauthorizedError( + f"No authorization found for user {user_id} and tool {tool_name}" + ) + authorization = self.active_authorizations[auth_key] - + # Check if authorization is expired if authorization.is_expired(): - logger.warning("Authorization expired", - user_id=user_id, - tool_name=tool_name, - expires_at=authorization.expires_at) - + logger.warning( + "Authorization expired", + user_id=user_id, + tool_name=tool_name, + expires_at=authorization.expires_at, + ) + # Try to refresh token if available if authorization.refresh_token: try: - refreshed_auth = await self._refresh_authorization(authorization) + refreshed_auth = await self._refresh_authorization( + authorization + ) self.active_authorizations[auth_key] = refreshed_auth authorization = refreshed_auth except Exception as e: - logger.error("Failed to refresh authorization", - user_id=user_id, - tool_name=tool_name, - error=str(e)) + logger.error( + "Failed to refresh authorization", + user_id=user_id, + tool_name=tool_name, + error=str(e), + ) del self.active_authorizations[auth_key] - raise UnauthorizedError("Authorization expired and refresh failed") + raise UnauthorizedError( + "Authorization expired and refresh failed" + ) else: del self.active_authorizations[auth_key] raise UnauthorizedError("Authorization expired") - + # Check required scopes if required_scopes: if not authorization.has_all_scopes(required_scopes): - missing_scopes = [scope for scope in required_scopes if not authorization.has_scope(scope)] - logger.warning("Insufficient scopes", - user_id=user_id, - tool_name=tool_name, - required_scopes=required_scopes, - user_scopes=authorization.scopes, - missing_scopes=missing_scopes) - raise UnauthorizedError(f"Insufficient permissions. Missing scopes: {missing_scopes}") - - logger.debug("Authorization validated successfully", + missing_scopes = [ + scope + for scope in required_scopes + if not authorization.has_scope(scope) + ] + logger.warning( + "Insufficient scopes", user_id=user_id, tool_name=tool_name, - scopes=authorization.scopes) - + required_scopes=required_scopes, + user_scopes=authorization.scopes, + missing_scopes=missing_scopes, + ) + raise UnauthorizedError( + f"Insufficient permissions. Missing scopes: {missing_scopes}" + ) + + logger.debug( + "Authorization validated successfully", + user_id=user_id, + tool_name=tool_name, + scopes=authorization.scopes, + ) + return authorization - + except UnauthorizedError: raise except Exception as e: - logger.error("Authorization validation failed", - user_id=user_id, - tool_name=tool_name, - error=str(e)) + logger.error( + "Authorization validation failed", + user_id=user_id, + tool_name=tool_name, + error=str(e), + ) raise UnauthorizedError(f"Authorization validation failed: {str(e)}") - - def initiate_authorization(self, - user_id: str, - tool_name: str, - scopes: List[str]) -> AuthFlow: + + def initiate_authorization( + self, user_id: str, tool_name: str, scopes: List[str] + ) -> AuthFlow: """ Initiate OAuth authorization flow. - + Args: user_id: User identifier tool_name: Tool name requiring authorization scopes: List of requested scopes - + Returns: AuthFlow instance with authorization URL - + Raises: ValidationError: If parameters are invalid """ @@ -235,64 +253,71 @@ def initiate_authorization(self, # Validate inputs if not user_id or not tool_name: raise ValidationError("User ID and tool name are required") - + if not scopes: raise ValidationError("At least one scope is required") - + # Validate scopes - invalid_scopes = [scope for scope in scopes if scope not in self.available_scopes] + invalid_scopes = [ + scope for scope in scopes if scope not in self.available_scopes + ] if invalid_scopes: raise ValidationError(f"Invalid scopes: {invalid_scopes}") - + # Generate flow ID import uuid + flow_id = str(uuid.uuid4()) - + # Create authorization URL auth_url = self._create_auth_url(flow_id, scopes) - + # Create auth flow auth_flow = AuthFlow( flow_id=flow_id, user_id=user_id, tool_name=tool_name, scopes=scopes, - auth_url=auth_url + auth_url=auth_url, ) - + # Store pending flow self.pending_flows[flow_id] = auth_flow - - logger.info("Authorization flow initiated", - flow_id=flow_id, - user_id=user_id, - tool_name=tool_name, - scopes=scopes) - + + logger.info( + "Authorization flow initiated", + flow_id=flow_id, + user_id=user_id, + tool_name=tool_name, + scopes=scopes, + ) + return auth_flow - + except ValidationError: raise except Exception as e: - logger.error("Failed to initiate authorization", - user_id=user_id, - tool_name=tool_name, - error=str(e)) + logger.error( + "Failed to initiate authorization", + user_id=user_id, + tool_name=tool_name, + error=str(e), + ) raise ValidationError(f"Failed to initiate authorization: {str(e)}") - - async def complete_authorization(self, - flow_id: str, - authorization_code: Optional[str] = None) -> Authorization: + + async def complete_authorization( + self, flow_id: str, authorization_code: Optional[str] = None + ) -> Authorization: """ Complete OAuth authorization flow. - + Args: flow_id: Authorization flow identifier authorization_code: OAuth authorization code - + Returns: Authorization object - + Raises: UnauthorizedError: If flow is invalid or expired """ @@ -300,14 +325,14 @@ async def complete_authorization(self, # Get pending flow if flow_id not in self.pending_flows: raise UnauthorizedError("Invalid or expired authorization flow") - + auth_flow = self.pending_flows[flow_id] - + # Check if flow is expired if auth_flow.is_expired(): del self.pending_flows[flow_id] raise UnauthorizedError("Authorization flow expired") - + # For demo purposes, create authorization without actual OAuth # In production, exchange authorization_code for access token authorization = Authorization( @@ -315,123 +340,127 @@ async def complete_authorization(self, tool_name=auth_flow.tool_name, scopes=auth_flow.scopes, access_token=f"demo_token_{flow_id}", - expires_at=time.time() + 3600 # 1 hour + expires_at=time.time() + 3600, # 1 hour ) - + # Store authorization auth_key = f"{auth_flow.user_id}:{auth_flow.tool_name}" self.active_authorizations[auth_key] = authorization - + # Clean up pending flow del self.pending_flows[flow_id] - - logger.info("Authorization completed successfully", - flow_id=flow_id, - user_id=auth_flow.user_id, - tool_name=auth_flow.tool_name) - + + logger.info( + "Authorization completed successfully", + flow_id=flow_id, + user_id=auth_flow.user_id, + tool_name=auth_flow.tool_name, + ) + return authorization - + except UnauthorizedError: raise except Exception as e: - logger.error("Failed to complete authorization", - flow_id=flow_id, - error=str(e)) + logger.error( + "Failed to complete authorization", flow_id=flow_id, error=str(e) + ) raise UnauthorizedError(f"Failed to complete authorization: {str(e)}") - + def revoke_authorization(self, user_id: str, tool_name: str) -> bool: """ Revoke user authorization for a tool. - + Args: user_id: User identifier tool_name: Tool name - + Returns: True if authorization was revoked """ auth_key = f"{user_id}:{tool_name}" - + if auth_key in self.active_authorizations: del self.active_authorizations[auth_key] - logger.info("Authorization revoked", - user_id=user_id, - tool_name=tool_name) + logger.info("Authorization revoked", user_id=user_id, tool_name=tool_name) return True - + return False - + def get_user_authorizations(self, user_id: str) -> List[Authorization]: """ Get all authorizations for a user. - + Args: user_id: User identifier - + Returns: List of user authorizations """ user_auths = [] - + for auth_key, authorization in self.active_authorizations.items(): if authorization.user_id == user_id: user_auths.append(authorization) - + return user_auths - + def cleanup_expired_authorizations(self) -> int: """ Clean up expired authorizations and flows. - + Returns: Number of items cleaned up """ cleaned_count = 0 current_time = time.time() - + # Clean up expired authorizations expired_auths = [] for auth_key, authorization in self.active_authorizations.items(): if authorization.is_expired(): expired_auths.append(auth_key) - + for auth_key in expired_auths: del self.active_authorizations[auth_key] cleaned_count += 1 - + # Clean up expired flows expired_flows = [] for flow_id, auth_flow in self.pending_flows.items(): if auth_flow.is_expired(): expired_flows.append(flow_id) - + for flow_id in expired_flows: del self.pending_flows[flow_id] cleaned_count += 1 - + if cleaned_count > 0: - logger.info("Cleaned up expired items", - count=cleaned_count, - authorizations=len(expired_auths), - flows=len(expired_flows)) - + logger.info( + "Cleaned up expired items", + count=cleaned_count, + authorizations=len(expired_auths), + flows=len(expired_flows), + ) + return cleaned_count - + def get_available_scopes(self) -> Dict[str, str]: """Get available authorization scopes.""" return self.available_scopes.copy() - - async def _refresh_authorization(self, authorization: Authorization) -> Authorization: + + async def _refresh_authorization( + self, authorization: Authorization + ) -> Authorization: """ Refresh an expired authorization using refresh token. - + Args: authorization: Authorization to refresh - + Returns: Refreshed authorization - + Raises: UnauthorizedError: If refresh fails """ @@ -445,28 +474,32 @@ async def _refresh_authorization(self, authorization: Authorization) -> Authoriz access_token=f"refreshed_{authorization.access_token}", refresh_token=authorization.refresh_token, expires_at=time.time() + 3600, # 1 hour - created_at=authorization.created_at + created_at=authorization.created_at, + ) + + logger.info( + "Authorization refreshed", + user_id=authorization.user_id, + tool_name=authorization.tool_name, ) - - logger.info("Authorization refreshed", - user_id=authorization.user_id, - tool_name=authorization.tool_name) - + return refreshed_auth - + except Exception as e: - logger.error("Failed to refresh authorization", - user_id=authorization.user_id, - tool_name=authorization.tool_name, - error=str(e)) + logger.error( + "Failed to refresh authorization", + user_id=authorization.user_id, + tool_name=authorization.tool_name, + error=str(e), + ) raise UnauthorizedError(f"Failed to refresh authorization: {str(e)}") - + def _create_auth_url(self, flow_id: str, scopes: List[str]) -> str: """Create OAuth authorization URL.""" if not self.oauth_client_id or not self.oauth_redirect_uri: # Return demo URL if OAuth not configured return f"https://demo-oauth.example.com/authorize?flow_id={flow_id}&scopes={','.join(scopes)}" - + # In production, create actual OAuth URL scope_string = " ".join(scopes) auth_url = ( @@ -477,41 +510,44 @@ def _create_auth_url(self, flow_id: str, scopes: List[str]) -> str: f"&state={flow_id}" f"&response_type=code" ) - + return auth_url # Utility functions for authorization -def create_authorization_manager(oauth_client_id: Optional[str] = None, - oauth_client_secret: Optional[str] = None, - oauth_redirect_uri: Optional[str] = None) -> AuthorizationManager: + +def create_authorization_manager( + oauth_client_id: Optional[str] = None, + oauth_client_secret: Optional[str] = None, + oauth_redirect_uri: Optional[str] = None, +) -> AuthorizationManager: """ Create and configure an authorization manager. - + Args: oauth_client_id: OAuth client ID oauth_client_secret: OAuth client secret oauth_redirect_uri: OAuth redirect URI - + Returns: Configured AuthorizationManager instance """ return AuthorizationManager( oauth_client_id=oauth_client_id, oauth_client_secret=oauth_client_secret, - oauth_redirect_uri=oauth_redirect_uri + oauth_redirect_uri=oauth_redirect_uri, ) def extract_required_scopes(tool_definition: Dict[str, Any]) -> List[str]: """ Extract required scopes from tool definition. - + Args: tool_definition: Tool definition dictionary - + Returns: List of required scopes """ - return tool_definition.get("required_scopes", []) \ No newline at end of file + return tool_definition.get("required_scopes", []) diff --git a/src/security/cache_encryption.py b/src/security/cache_encryption.py index c7079db..810491e 100644 --- a/src/security/cache_encryption.py +++ b/src/security/cache_encryption.py @@ -20,52 +20,53 @@ from core.errors import CacheError, SecurityError - logger = structlog.get_logger(__name__) class CacheEncryption: """ Secure encryption for cache data with integrity protection. - + Provides encryption, decryption, and integrity validation for cached content to prevent data tampering and unauthorized access. """ - + def __init__(self, encryption_key: Optional[bytes] = None): """ Initialize cache encryption. - + Args: encryption_key: Optional encryption key (generated if not provided) """ self.encryption_key = encryption_key or self._generate_encryption_key() self.cipher_suite = Fernet(self.encryption_key) - + # HMAC key for integrity verification (derived from encryption key) self.hmac_key = self._derive_hmac_key(self.encryption_key) - + # Metadata for cache entries self.version = "1.0" self.algorithm = "Fernet" - - logger.info("Cache encryption initialized", - algorithm=self.algorithm, - version=self.version) - - def encrypt_cache_entry(self, - content: str, - metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + + logger.info( + "Cache encryption initialized", + algorithm=self.algorithm, + version=self.version, + ) + + def encrypt_cache_entry( + self, content: str, metadata: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """ Encrypt cache content with metadata and integrity protection. - + Args: content: Content to encrypt metadata: Optional metadata to include - + Returns: Encrypted cache entry with metadata - + Raises: CacheError: If encryption fails """ @@ -73,181 +74,200 @@ def encrypt_cache_entry(self, # Validate input if not isinstance(content, str): raise CacheError("Content must be a string") - + if len(content) > 10 * 1024 * 1024: # 10MB limit raise CacheError("Content too large for encryption") - + # Prepare cache entry data cache_data = { "content": content, "metadata": metadata or {}, "encrypted_at": time.time(), - "version": self.version + "version": self.version, } - + # Serialize data - data_bytes = json.dumps(cache_data).encode('utf-8') - + data_bytes = json.dumps(cache_data).encode("utf-8") + # Encrypt data encrypted_data = self.cipher_suite.encrypt(data_bytes) - + # Create integrity hash integrity_hash = self._create_integrity_hash(encrypted_data) - + # Prepare final encrypted entry encrypted_entry = { - "encrypted_content": base64.urlsafe_b64encode(encrypted_data).decode('utf-8'), + "encrypted_content": base64.urlsafe_b64encode(encrypted_data).decode( + "utf-8" + ), "integrity_hash": integrity_hash, "algorithm": self.algorithm, "version": self.version, - "created_at": time.time() + "created_at": time.time(), } - - logger.debug("Cache entry encrypted successfully", - content_length=len(content), - metadata_keys=list(metadata.keys()) if metadata else []) - + + logger.debug( + "Cache entry encrypted successfully", + content_length=len(content), + metadata_keys=list(metadata.keys()) if metadata else [], + ) + return encrypted_entry - + except Exception as e: logger.error("Cache encryption failed", error=str(e)) raise CacheError(f"Failed to encrypt cache entry: {str(e)}") - + def decrypt_cache_entry(self, encrypted_entry: Dict[str, Any]) -> Dict[str, Any]: """ Decrypt and validate cache entry. - + Args: encrypted_entry: Encrypted cache entry - + Returns: Decrypted cache data with content and metadata - + Raises: CacheError: If decryption or validation fails """ try: # Validate encrypted entry structure - required_fields = ["encrypted_content", "integrity_hash", "algorithm", "version"] + required_fields = [ + "encrypted_content", + "integrity_hash", + "algorithm", + "version", + ] for field in required_fields: if field not in encrypted_entry: raise CacheError(f"Missing required field: {field}") - + # Check version compatibility if encrypted_entry["version"] != self.version: - raise CacheError(f"Incompatible cache entry version: {encrypted_entry['version']}") - + raise CacheError( + f"Incompatible cache entry version: {encrypted_entry['version']}" + ) + # Check algorithm if encrypted_entry["algorithm"] != self.algorithm: - raise CacheError(f"Unsupported encryption algorithm: {encrypted_entry['algorithm']}") - + raise CacheError( + f"Unsupported encryption algorithm: {encrypted_entry['algorithm']}" + ) + # Decode encrypted content try: encrypted_data = base64.urlsafe_b64decode( - encrypted_entry["encrypted_content"].encode('utf-8') + encrypted_entry["encrypted_content"].encode("utf-8") ) except Exception: raise CacheError("Invalid encrypted content encoding") - + # Verify integrity expected_hash = encrypted_entry["integrity_hash"] actual_hash = self._create_integrity_hash(encrypted_data) - + if not hmac.compare_digest(expected_hash, actual_hash): - raise CacheError("Cache entry integrity check failed - possible tampering") - + raise CacheError( + "Cache entry integrity check failed - possible tampering" + ) + # Decrypt data try: decrypted_data = self.cipher_suite.decrypt(encrypted_data) except Exception: - raise CacheError("Failed to decrypt cache entry - invalid key or corrupted data") - + raise CacheError( + "Failed to decrypt cache entry - invalid key or corrupted data" + ) + # Parse decrypted content try: - cache_data = json.loads(decrypted_data.decode('utf-8')) + cache_data = json.loads(decrypted_data.decode("utf-8")) except Exception: raise CacheError("Failed to parse decrypted cache data") - + # Validate decrypted data structure if not isinstance(cache_data, dict) or "content" not in cache_data: raise CacheError("Invalid decrypted cache data structure") - - logger.debug("Cache entry decrypted successfully", - content_length=len(cache_data.get("content", "")), - has_metadata=bool(cache_data.get("metadata"))) - + + logger.debug( + "Cache entry decrypted successfully", + content_length=len(cache_data.get("content", "")), + has_metadata=bool(cache_data.get("metadata")), + ) + return cache_data - + except CacheError: raise except Exception as e: logger.error("Cache decryption failed", error=str(e)) raise CacheError(f"Failed to decrypt cache entry: {str(e)}") - - def encrypt_sensitive_fields(self, - data: Dict[str, Any], - sensitive_fields: set) -> Dict[str, Any]: + + def encrypt_sensitive_fields( + self, data: Dict[str, Any], sensitive_fields: set + ) -> Dict[str, Any]: """ Encrypt specific sensitive fields in a dictionary. - + Args: data: Dictionary containing data sensitive_fields: Set of field names to encrypt - + Returns: Dictionary with sensitive fields encrypted """ if not isinstance(data, dict): raise CacheError("Data must be a dictionary") - + encrypted_data = data.copy() - + for field_name in sensitive_fields: if field_name in data: field_value = data[field_name] - + if isinstance(field_value, str): # Encrypt string field encrypted_field = self._encrypt_field(field_value) encrypted_data[f"_encrypted_{field_name}"] = encrypted_field del encrypted_data[field_name] - + elif isinstance(field_value, dict): # Recursively encrypt nested dictionaries encrypted_field = self._encrypt_field(json.dumps(field_value)) encrypted_data[f"_encrypted_{field_name}"] = encrypted_field del encrypted_data[field_name] - + return encrypted_data - - def decrypt_sensitive_fields(self, - data: Dict[str, Any], - sensitive_fields: set) -> Dict[str, Any]: + + def decrypt_sensitive_fields( + self, data: Dict[str, Any], sensitive_fields: set + ) -> Dict[str, Any]: """ Decrypt sensitive fields in a dictionary. - + Args: data: Dictionary with encrypted fields sensitive_fields: Set of original field names - + Returns: Dictionary with sensitive fields decrypted """ if not isinstance(data, dict): raise CacheError("Data must be a dictionary") - + decrypted_data = data.copy() - + for field_name in sensitive_fields: encrypted_field_name = f"_encrypted_{field_name}" - + if encrypted_field_name in data: encrypted_field = data[encrypted_field_name] - + try: # Decrypt field decrypted_value = self._decrypt_field(encrypted_field) - + # Try to parse as JSON for complex types try: parsed_value = json.loads(decrypted_value) @@ -255,68 +275,77 @@ def decrypt_sensitive_fields(self, except json.JSONDecodeError: # Keep as string if not valid JSON decrypted_data[field_name] = decrypted_value - + # Remove encrypted field del decrypted_data[encrypted_field_name] - + except Exception as e: - logger.warning("Failed to decrypt sensitive field", - field_name=field_name, - error=str(e)) + logger.warning( + "Failed to decrypt sensitive field", + field_name=field_name, + error=str(e), + ) # Keep encrypted field if decryption fails - + return decrypted_data - + def is_cache_entry_valid(self, encrypted_entry: Dict[str, Any]) -> bool: """ Check if a cache entry is valid without decrypting it. - + Args: encrypted_entry: Encrypted cache entry to validate - + Returns: True if entry appears valid """ try: # Check required fields - required_fields = ["encrypted_content", "integrity_hash", "algorithm", "version"] + required_fields = [ + "encrypted_content", + "integrity_hash", + "algorithm", + "version", + ] for field in required_fields: if field not in encrypted_entry: return False - + # Check version and algorithm - if (encrypted_entry["version"] != self.version or - encrypted_entry["algorithm"] != self.algorithm): + if ( + encrypted_entry["version"] != self.version + or encrypted_entry["algorithm"] != self.algorithm + ): return False - + # Validate base64 encoding try: base64.urlsafe_b64decode(encrypted_entry["encrypted_content"]) except Exception: return False - + return True - + except Exception: return False - + def _generate_encryption_key(self) -> bytes: """Generate a secure encryption key.""" # In production, this should be loaded from secure key management password = os.urandom(32) salt = os.urandom(16) - + kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, salt=salt, iterations=100000, - backend=default_backend() + backend=default_backend(), ) - + key = base64.urlsafe_b64encode(kdf.derive(password)) return key - + def _derive_hmac_key(self, encryption_key: bytes) -> bytes: """Derive HMAC key from encryption key.""" # Use PBKDF2 to derive a separate HMAC key @@ -325,40 +354,36 @@ def _derive_hmac_key(self, encryption_key: bytes) -> bytes: length=32, salt=b"hmac_derivation_salt", iterations=10000, - backend=default_backend() + backend=default_backend(), ) - + return kdf.derive(encryption_key) - + def _create_integrity_hash(self, data: bytes) -> str: """Create HMAC hash for integrity verification.""" - signature = hmac.new( - self.hmac_key, - data, - hashlib.sha256 - ).hexdigest() - + signature = hmac.new(self.hmac_key, data, hashlib.sha256).hexdigest() + return signature - + def _encrypt_field(self, value: str) -> str: """Encrypt a single field value.""" - encrypted_data = self.cipher_suite.encrypt(value.encode('utf-8')) - return base64.urlsafe_b64encode(encrypted_data).decode('utf-8') - + encrypted_data = self.cipher_suite.encrypt(value.encode("utf-8")) + return base64.urlsafe_b64encode(encrypted_data).decode("utf-8") + def _decrypt_field(self, encrypted_value: str) -> str: """Decrypt a single field value.""" - encrypted_data = base64.urlsafe_b64decode(encrypted_value.encode('utf-8')) + encrypted_data = base64.urlsafe_b64decode(encrypted_value.encode("utf-8")) decrypted_data = self.cipher_suite.decrypt(encrypted_data) - return decrypted_data.decode('utf-8') + return decrypted_data.decode("utf-8") def create_cache_encryption(encryption_key: Optional[bytes] = None) -> CacheEncryption: """ Create a configured cache encryption instance. - + Args: encryption_key: Optional encryption key - + Returns: CacheEncryption instance """ @@ -372,22 +397,23 @@ def create_cache_encryption(encryption_key: Optional[bytes] = None) -> CacheEncr def get_cache_encryption() -> CacheEncryption: """Get global cache encryption instance.""" global _cache_encryption_instance - + if _cache_encryption_instance is None: _cache_encryption_instance = create_cache_encryption() - + return _cache_encryption_instance -def encrypt_cache_data(content: str, - metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: +def encrypt_cache_data( + content: str, metadata: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: """ Convenience function to encrypt cache data. - + Args: content: Content to encrypt metadata: Optional metadata - + Returns: Encrypted cache entry """ @@ -398,12 +424,12 @@ def encrypt_cache_data(content: str, def decrypt_cache_data(encrypted_entry: Dict[str, Any]) -> Dict[str, Any]: """ Convenience function to decrypt cache data. - + Args: encrypted_entry: Encrypted cache entry - + Returns: Decrypted cache data """ cache_encryption = get_cache_encryption() - return cache_encryption.decrypt_cache_entry(encrypted_entry) \ No newline at end of file + return cache_encryption.decrypt_cache_entry(encrypted_entry) diff --git a/src/security/config.py b/src/security/config.py index b31613b..0bba6f4 100644 --- a/src/security/config.py +++ b/src/security/config.py @@ -14,29 +14,28 @@ from core.errors import ConfigurationError, SecurityError - logger = structlog.get_logger(__name__) @dataclass class AuthConfig: """Authentication and authorization configuration.""" - + # Token settings token_lifetime_seconds: int = 3600 # 1 hour refresh_token_lifetime_seconds: int = 86400 # 24 hours token_entropy_bytes: int = 32 # 256 bits max_token_per_user: int = 10 - + # OAuth settings oauth_client_id: Optional[str] = None oauth_client_secret: Optional[str] = None oauth_redirect_uri: Optional[str] = None - + # Session settings session_timeout_seconds: int = 1800 # 30 minutes max_sessions_per_user: int = 5 - + # Security settings require_https: bool = True enforce_csrf_protection: bool = True @@ -46,24 +45,24 @@ class AuthConfig: @dataclass class ValidationConfig: """Input validation and sanitization configuration.""" - + # String limits max_string_length: int = 10000 max_url_length: int = 2000 max_email_length: int = 254 max_filename_length: int = 255 max_path_length: int = 4096 - + # List and object limits max_list_items: int = 1000 max_object_properties: int = 100 max_nesting_depth: int = 10 - + # Content validation allow_html: bool = False allow_javascript: bool = False allow_file_uploads: bool = False - + # Encoding settings default_encoding: str = "utf-8" sanitize_unicode: bool = True @@ -72,47 +71,54 @@ class ValidationConfig: @dataclass class EncryptionConfig: """Encryption configuration for data protection.""" - + # Algorithm settings encryption_algorithm: str = "Fernet" key_derivation_iterations: int = 100000 hmac_algorithm: str = "sha256" - + # Key management encryption_key: Optional[bytes] = None key_rotation_interval_days: int = 90 - + # Cache encryption encrypt_cache_content: bool = True cache_encryption_key: Optional[bytes] = None - + # Database encryption encrypt_sensitive_fields: bool = True - sensitive_field_names: List[str] = field(default_factory=lambda: [ - 'password', 'token', 'secret', 'key', 'auth', 'credential' - ]) + sensitive_field_names: List[str] = field( + default_factory=lambda: [ + "password", + "token", + "secret", + "key", + "auth", + "credential", + ] + ) @dataclass class RateLimitConfig: """Rate limiting configuration.""" - + # Global rate limits global_requests_per_minute: int = 1000 global_requests_per_hour: int = 10000 - + # Per-user rate limits user_requests_per_minute: int = 60 user_requests_per_hour: int = 1000 - + # Tool-specific rate limits tool_requests_per_minute: int = 30 tool_requests_per_hour: int = 500 - + # Authentication rate limits auth_attempts_per_minute: int = 5 auth_attempts_per_hour: int = 20 - + # Rate limit enforcement enable_rate_limiting: bool = True rate_limit_storage: str = "memory" # "memory" or "redis" @@ -122,23 +128,23 @@ class RateLimitConfig: @dataclass class SecurityMonitoringConfig: """Security monitoring and alerting configuration.""" - + # Logging settings log_security_events: bool = True log_level: str = "INFO" log_sensitive_data: bool = False - + # Alert thresholds failed_auth_threshold: int = 10 rate_limit_violation_threshold: int = 5 suspicious_activity_threshold: int = 3 - + # Monitoring features enable_intrusion_detection: bool = True enable_anomaly_detection: bool = False monitor_file_access: bool = True monitor_network_requests: bool = True - + # Data retention security_log_retention_days: int = 90 audit_log_retention_days: int = 365 @@ -147,22 +153,26 @@ class SecurityMonitoringConfig: @dataclass class NetworkSecurityConfig: """Network security configuration.""" - + # HTTPS settings enforce_https: bool = True hsts_max_age: int = 31536000 # 1 year - + # CORS settings cors_enabled: bool = True - cors_allowed_origins: List[str] = field(default_factory=lambda: ["https://*.example.com"]) - cors_allowed_methods: List[str] = field(default_factory=lambda: ["GET", "POST", "PUT", "DELETE"]) + cors_allowed_origins: List[str] = field( + default_factory=lambda: ["https://*.example.com"] + ) + cors_allowed_methods: List[str] = field( + default_factory=lambda: ["GET", "POST", "PUT", "DELETE"] + ) cors_max_age: int = 86400 # 24 hours - + # Request filtering block_private_ips: bool = True block_localhost: bool = True allowed_schemes: List[str] = field(default_factory=lambda: ["https"]) - + # Timeout settings request_timeout: int = 30 connection_timeout: int = 10 @@ -172,57 +182,62 @@ class NetworkSecurityConfig: @dataclass class SecurityConfig: """Main security configuration container.""" - + auth: AuthConfig = field(default_factory=AuthConfig) validation: ValidationConfig = field(default_factory=ValidationConfig) encryption: EncryptionConfig = field(default_factory=EncryptionConfig) rate_limiting: RateLimitConfig = field(default_factory=RateLimitConfig) - monitoring: SecurityMonitoringConfig = field(default_factory=SecurityMonitoringConfig) + monitoring: SecurityMonitoringConfig = field( + default_factory=SecurityMonitoringConfig + ) network: NetworkSecurityConfig = field(default_factory=NetworkSecurityConfig) - + # Global security settings debug_mode: bool = False strict_mode: bool = True security_headers_enabled: bool = True - + def __post_init__(self): """Validate configuration after initialization.""" self._validate_config() self._setup_encryption_keys() - + def _validate_config(self): """Validate security configuration.""" # Validate auth config if self.auth.token_lifetime_seconds < 300: # 5 minutes minimum raise ConfigurationError("Token lifetime must be at least 5 minutes") - + if self.auth.token_lifetime_seconds > 86400: # 24 hours maximum raise ConfigurationError("Token lifetime cannot exceed 24 hours") - + # Validate validation config if self.validation.max_string_length > 100 * 1024 * 1024: # 100MB raise ConfigurationError("Maximum string length too large") - + # Validate rate limiting - if self.rate_limiting.user_requests_per_minute > self.rate_limiting.global_requests_per_minute: + if ( + self.rate_limiting.user_requests_per_minute + > self.rate_limiting.global_requests_per_minute + ): raise ConfigurationError("User rate limit cannot exceed global rate limit") - + # Validate encryption if self.encryption.key_derivation_iterations < 10000: raise ConfigurationError("Key derivation iterations too low") - + logger.info("Security configuration validated successfully") - + def _setup_encryption_keys(self): """Setup encryption keys if not provided.""" if self.encryption.encryption_key is None: self.encryption.encryption_key = Fernet.generate_key() logger.info("Generated new encryption key") - + if self.encryption.cache_encryption_key is None: self.encryption.cache_encryption_key = Fernet.generate_key() logger.info("Generated new cache encryption key") - + def to_dict(self) -> Dict[str, Any]: """Convert configuration to dictionary (excluding sensitive data).""" config_dict = { @@ -232,51 +247,51 @@ def to_dict(self) -> Dict[str, Any]: "require_https": self.auth.require_https, "enforce_csrf_protection": self.auth.enforce_csrf_protection, "min_password_length": self.auth.min_password_length, - "oauth_configured": bool(self.auth.oauth_client_id) + "oauth_configured": bool(self.auth.oauth_client_id), }, "validation": { "max_string_length": self.validation.max_string_length, "max_list_items": self.validation.max_list_items, "max_nesting_depth": self.validation.max_nesting_depth, "allow_html": self.validation.allow_html, - "allow_javascript": self.validation.allow_javascript + "allow_javascript": self.validation.allow_javascript, }, "encryption": { "encryption_algorithm": self.encryption.encryption_algorithm, "encrypt_cache_content": self.encryption.encrypt_cache_content, "encrypt_sensitive_fields": self.encryption.encrypt_sensitive_fields, - "key_rotation_interval_days": self.encryption.key_rotation_interval_days + "key_rotation_interval_days": self.encryption.key_rotation_interval_days, }, "rate_limiting": { "enable_rate_limiting": self.rate_limiting.enable_rate_limiting, "user_requests_per_minute": self.rate_limiting.user_requests_per_minute, - "tool_requests_per_minute": self.rate_limiting.tool_requests_per_minute + "tool_requests_per_minute": self.rate_limiting.tool_requests_per_minute, }, "monitoring": { "log_security_events": self.monitoring.log_security_events, "enable_intrusion_detection": self.monitoring.enable_intrusion_detection, - "security_log_retention_days": self.monitoring.security_log_retention_days + "security_log_retention_days": self.monitoring.security_log_retention_days, }, "network": { "enforce_https": self.network.enforce_https, "cors_enabled": self.network.cors_enabled, "block_private_ips": self.network.block_private_ips, - "request_timeout": self.network.request_timeout + "request_timeout": self.network.request_timeout, }, "global": { "debug_mode": self.debug_mode, "strict_mode": self.strict_mode, - "security_headers_enabled": self.security_headers_enabled - } + "security_headers_enabled": self.security_headers_enabled, + }, } - + return config_dict def load_security_config_from_env() -> SecurityConfig: """Load security configuration from environment variables.""" config = SecurityConfig() - + # Auth configuration config.auth.token_lifetime_seconds = int( os.getenv("AUTH_TOKEN_LIFETIME", config.auth.token_lifetime_seconds) @@ -284,45 +299,54 @@ def load_security_config_from_env() -> SecurityConfig: config.auth.oauth_client_id = os.getenv("OAUTH_CLIENT_ID") config.auth.oauth_client_secret = os.getenv("OAUTH_CLIENT_SECRET") config.auth.oauth_redirect_uri = os.getenv("OAUTH_REDIRECT_URI") - + # Validation configuration config.validation.max_string_length = int( os.getenv("VALIDATION_MAX_STRING_LENGTH", config.validation.max_string_length) ) - config.validation.allow_html = os.getenv("VALIDATION_ALLOW_HTML", "false").lower() == "true" - + config.validation.allow_html = ( + os.getenv("VALIDATION_ALLOW_HTML", "false").lower() == "true" + ) + # Encryption configuration encryption_key_b64 = os.getenv("ENCRYPTION_KEY") if encryption_key_b64: try: import base64 - config.encryption.encryption_key = base64.urlsafe_b64decode(encryption_key_b64) + + config.encryption.encryption_key = base64.urlsafe_b64decode( + encryption_key_b64 + ) except Exception as e: - logger.warning("Failed to decode encryption key from environment", error=str(e)) - + logger.warning( + "Failed to decode encryption key from environment", error=str(e) + ) + # Rate limiting configuration config.rate_limiting.enable_rate_limiting = ( os.getenv("RATE_LIMITING_ENABLED", "true").lower() == "true" ) config.rate_limiting.user_requests_per_minute = int( - os.getenv("RATE_LIMIT_USER_PER_MINUTE", config.rate_limiting.user_requests_per_minute) + os.getenv( + "RATE_LIMIT_USER_PER_MINUTE", config.rate_limiting.user_requests_per_minute + ) ) - + # Monitoring configuration config.monitoring.log_security_events = ( os.getenv("LOG_SECURITY_EVENTS", "true").lower() == "true" ) - config.monitoring.log_level = os.getenv("SECURITY_LOG_LEVEL", config.monitoring.log_level) - - # Network configuration - config.network.enforce_https = ( - os.getenv("ENFORCE_HTTPS", "true").lower() == "true" + config.monitoring.log_level = os.getenv( + "SECURITY_LOG_LEVEL", config.monitoring.log_level ) - + + # Network configuration + config.network.enforce_https = os.getenv("ENFORCE_HTTPS", "true").lower() == "true" + # Global settings config.debug_mode = os.getenv("DEBUG_MODE", "false").lower() == "true" config.strict_mode = os.getenv("STRICT_MODE", "true").lower() == "true" - + logger.info("Security configuration loaded from environment") return config @@ -330,23 +354,23 @@ def load_security_config_from_env() -> SecurityConfig: def create_security_config(**overrides) -> SecurityConfig: """ Create security configuration with optional overrides. - + Args: **overrides: Configuration overrides - + Returns: SecurityConfig instance """ # Start with environment-based config config = load_security_config_from_env() - + # Apply overrides for key, value in overrides.items(): if hasattr(config, key): setattr(config, key, value) else: logger.warning("Unknown security config override", key=key) - + return config @@ -357,10 +381,10 @@ def create_security_config(**overrides) -> SecurityConfig: def get_security_config() -> SecurityConfig: """Get global security configuration instance.""" global _security_config_instance - + if _security_config_instance is None: _security_config_instance = load_security_config_from_env() - + return _security_config_instance @@ -374,7 +398,7 @@ def update_security_config(config: SecurityConfig) -> None: def get_security_headers() -> Dict[str, str]: """Get recommended security headers.""" config = get_security_config() - + headers = { "X-Content-Type-Options": "nosniff", "X-Frame-Options": "DENY", @@ -390,54 +414,56 @@ def get_security_headers() -> Dict[str, str]: "object-src 'none'; " "media-src 'self'; " "frame-src 'none';" - ) + ), } - + if config.network.enforce_https: - headers["Strict-Transport-Security"] = f"max-age={config.network.hsts_max_age}; includeSubDomains" - + headers["Strict-Transport-Security"] = ( + f"max-age={config.network.hsts_max_age}; includeSubDomains" + ) + return headers def validate_security_config(config: SecurityConfig) -> List[str]: """ Validate security configuration and return warnings. - + Args: config: Security configuration to validate - + Returns: List of validation warnings """ warnings = [] - + # Check for insecure settings if config.debug_mode: warnings.append("Debug mode is enabled - disable in production") - + if not config.auth.require_https: warnings.append("HTTPS is not required - enable for production") - + if not config.rate_limiting.enable_rate_limiting: warnings.append("Rate limiting is disabled - enable for production") - + if not config.encryption.encrypt_cache_content: warnings.append("Cache encryption is disabled") - + if config.validation.allow_html: warnings.append("HTML input is allowed - potential XSS risk") - + if config.validation.allow_javascript: warnings.append("JavaScript input is allowed - high XSS risk") - + # Check for weak settings if config.auth.token_lifetime_seconds > 7200: # 2 hours warnings.append("Token lifetime is longer than recommended (>2 hours)") - + if config.rate_limiting.user_requests_per_minute > 120: warnings.append("User rate limit is higher than recommended (>120/min)") - + if config.validation.max_string_length > 1024 * 1024: # 1MB warnings.append("Maximum string length is very large (>1MB)") - - return warnings \ No newline at end of file + + return warnings diff --git a/src/security/error_handler.py b/src/security/error_handler.py index ab7b701..dc047f1 100644 --- a/src/security/error_handler.py +++ b/src/security/error_handler.py @@ -12,22 +12,21 @@ import structlog from core.errors import ( - SecurityError, - AuthenticationError, + SecurityError, + AuthenticationError, ValidationError, DatabaseError, ToolExecutionError, - CacheError + CacheError, ) - logger = structlog.get_logger(__name__) @dataclass class SecureErrorResponse: """Represents a sanitized error response.""" - + error_code: str message: str status_code: int @@ -40,16 +39,16 @@ class SecureErrorHandler: Secure error handler that sanitizes error responses to prevent information disclosure while maintaining debugging capabilities. """ - + def __init__(self, debug_mode: bool = False): """ Initialize secure error handler. - + Args: debug_mode: Whether to include debug information in responses """ self.debug_mode = debug_mode - + # Error code mappings for different exception types self.error_mappings = { SecurityError: ("SECURITY_ERROR", 403), @@ -66,7 +65,7 @@ def __init__(self, debug_mode: bool = False): ConnectionError: ("CONNECTION_ERROR", 503), TimeoutError: ("TIMEOUT_ERROR", 504), } - + # Safe error messages that don't reveal system internals self.safe_messages = { "SECURITY_ERROR": "Access denied due to security policy", @@ -82,105 +81,107 @@ def __init__(self, debug_mode: bool = False): "MISSING_KEY": "Required parameter missing", "CONNECTION_ERROR": "Service temporarily unavailable", "TIMEOUT_ERROR": "Request timed out", - "INTERNAL_ERROR": "An internal error occurred" + "INTERNAL_ERROR": "An internal error occurred", } - + # Patterns to sanitize from error messages self.sensitive_patterns = [ - r'password[=:]\s*\S+', - r'api[_-]?key[=:]\s*\S+', - r'token[=:]\s*\S+', - r'secret[=:]\s*\S+', - r'auth[=:]\s*\S+', - r'/[a-zA-Z]:/.*', # Windows paths - r'/home/[^/\s]+', # Unix home paths - r'/etc/[^/\s]+', # System config paths - r'localhost:\d+', # Local endpoints - r'127\.0\.0\.1:\d+', # Local IP endpoints - r'(\d{1,3}\.){3}\d{1,3}:\d+', # IP:port combinations + r"password[=:]\s*\S+", + r"api[_-]?key[=:]\s*\S+", + r"token[=:]\s*\S+", + r"secret[=:]\s*\S+", + r"auth[=:]\s*\S+", + r"/[a-zA-Z]:/.*", # Windows paths + r"/home/[^/\s]+", # Unix home paths + r"/etc/[^/\s]+", # System config paths + r"localhost:\d+", # Local endpoints + r"127\.0\.0\.1:\d+", # Local IP endpoints + r"(\d{1,3}\.){3}\d{1,3}:\d+", # IP:port combinations ] - - def handle_exception(self, - exception: Exception, - context: Optional[Dict[str, Any]] = None) -> SecureErrorResponse: + + def handle_exception( + self, exception: Exception, context: Optional[Dict[str, Any]] = None + ) -> SecureErrorResponse: """ Handle an exception and return a secure error response. - + Args: exception: The exception to handle context: Optional context information - + Returns: SecureErrorResponse with sanitized error information """ # Generate unique error ID for tracking import uuid + error_id = str(uuid.uuid4()) - + # Get error code and status code error_code, status_code = self._get_error_info(exception) - + # Get safe message safe_message = self._get_safe_message(exception, error_code) - + # Log full error details for debugging self._log_error_details(exception, error_id, context) - + # Prepare response details details = None if self.debug_mode: details = self._get_debug_details(exception, context) - + return SecureErrorResponse( error_code=error_code, message=safe_message, status_code=status_code, details=details, - error_id=error_id + error_id=error_id, ) - + def sanitize_error_message(self, message: str) -> str: """ Sanitize an error message to remove sensitive information. - + Args: message: Original error message - + Returns: Sanitized error message """ import re - + sanitized = message - + # Remove sensitive patterns for pattern in self.sensitive_patterns: - sanitized = re.sub(pattern, '[REDACTED]', sanitized, flags=re.IGNORECASE) - + sanitized = re.sub(pattern, "[REDACTED]", sanitized, flags=re.IGNORECASE) + # Remove file paths - sanitized = re.sub(r'/[\w/.]+\.py', '[FILE_PATH]', sanitized) - sanitized = re.sub(r'[A-Z]:\\[\w\\/.]+\.py', '[FILE_PATH]', sanitized) - + sanitized = re.sub(r"/[\w/.]+\.py", "[FILE_PATH]", sanitized) + sanitized = re.sub(r"[A-Z]:\\[\w\\/.]+\.py", "[FILE_PATH]", sanitized) + # Remove line numbers from stack traces - sanitized = re.sub(r'line \d+', 'line [NUM]', sanitized) - + sanitized = re.sub(r"line \d+", "line [NUM]", sanitized) + # Remove memory addresses - sanitized = re.sub(r'0x[0-9a-fA-F]+', '[MEMORY_ADDR]', sanitized) - + sanitized = re.sub(r"0x[0-9a-fA-F]+", "[MEMORY_ADDR]", sanitized) + # Truncate very long messages if len(sanitized) > 500: - sanitized = sanitized[:497] + '...' - + sanitized = sanitized[:497] + "..." + return sanitized - - def format_error_response(self, - error_response: SecureErrorResponse) -> Dict[str, Any]: + + def format_error_response( + self, error_response: SecureErrorResponse + ) -> Dict[str, Any]: """ Format error response for API consumption. - + Args: error_response: SecureErrorResponse to format - + Returns: Formatted error response dictionary """ @@ -188,66 +189,66 @@ def format_error_response(self, "error": True, "error_code": error_response.error_code, "message": error_response.message, - "error_id": error_response.error_id + "error_id": error_response.error_id, } - + if error_response.details and self.debug_mode: response["details"] = error_response.details - + return response - + def _get_error_info(self, exception: Exception) -> Tuple[str, int]: """Get error code and HTTP status code for exception.""" exception_type = type(exception) - + # Check for exact type match if exception_type in self.error_mappings: error_code, status_code = self.error_mappings[exception_type] return error_code, status_code - + # Check for inheritance for exc_type, (error_code, status_code) in self.error_mappings.items(): if isinstance(exception, exc_type): return error_code, status_code - + # Default for unknown exceptions return "INTERNAL_ERROR", 500 - + def _get_safe_message(self, exception: Exception, error_code: str) -> str: """Get a safe error message that doesn't reveal system internals.""" # First try to get the safe message from our mappings safe_message = self.safe_messages.get(error_code, "An error occurred") - + # For known security-related exceptions, we can include some details if isinstance(exception, (ValidationError, SecurityError)): # These exceptions are designed to be user-facing original_message = str(exception) sanitized_message = self.sanitize_error_message(original_message) - + # Only use the original message if it's reasonably safe - if (len(sanitized_message) < 200 and - not any(keyword in sanitized_message.lower() - for keyword in ['password', 'token', 'key', 'secret', 'internal'])): + if len(sanitized_message) < 200 and not any( + keyword in sanitized_message.lower() + for keyword in ["password", "token", "key", "secret", "internal"] + ): return sanitized_message - + return safe_message - - def _log_error_details(self, - exception: Exception, - error_id: str, - context: Optional[Dict[str, Any]]) -> None: + + def _log_error_details( + self, exception: Exception, error_id: str, context: Optional[Dict[str, Any]] + ) -> None: """Log full error details for debugging purposes.""" error_details = { "error_id": error_id, "exception_type": type(exception).__name__, "exception_message": str(exception), - "context": context or {} + "context": context or {}, } - + # Add stack trace in debug mode if self.debug_mode: error_details["stack_trace"] = traceback.format_exc() - + # Log based on severity if isinstance(exception, (SecurityError, AuthenticationError)): logger.warning("Security-related error occurred", **error_details) @@ -255,19 +256,19 @@ def _log_error_details(self, logger.info("Input validation error occurred", **error_details) else: logger.error("System error occurred", **error_details) - - def _get_debug_details(self, - exception: Exception, - context: Optional[Dict[str, Any]]) -> Dict[str, Any]: + + def _get_debug_details( + self, exception: Exception, context: Optional[Dict[str, Any]] + ) -> Dict[str, Any]: """Get debug details for development environment.""" if not self.debug_mode: return {} - + details = { "exception_type": type(exception).__name__, - "exception_message": self.sanitize_error_message(str(exception)) + "exception_message": self.sanitize_error_message(str(exception)), } - + if context: # Sanitize context information sanitized_context = {} @@ -278,9 +279,9 @@ def _get_debug_details(self, sanitized_context[key] = value else: sanitized_context[key] = str(type(value)) - + details["context"] = sanitized_context - + return details @@ -291,32 +292,34 @@ def _get_debug_details(self, def get_error_handler(debug_mode: bool = False) -> SecureErrorHandler: """ Get or create global error handler instance. - + Args: debug_mode: Whether to enable debug mode - + Returns: SecureErrorHandler instance """ global _error_handler_instance - + if _error_handler_instance is None: _error_handler_instance = SecureErrorHandler(debug_mode) - + return _error_handler_instance -def handle_error(exception: Exception, - context: Optional[Dict[str, Any]] = None, - debug_mode: bool = False) -> Dict[str, Any]: +def handle_error( + exception: Exception, + context: Optional[Dict[str, Any]] = None, + debug_mode: bool = False, +) -> Dict[str, Any]: """ Convenience function to handle an error and return a formatted response. - + Args: exception: Exception to handle context: Optional context information debug_mode: Whether to include debug information - + Returns: Formatted error response dictionary """ @@ -328,12 +331,12 @@ def handle_error(exception: Exception, def sanitize_message(message: str) -> str: """ Convenience function to sanitize an error message. - + Args: message: Message to sanitize - + Returns: Sanitized message """ error_handler = get_error_handler() - return error_handler.sanitize_error_message(message) \ No newline at end of file + return error_handler.sanitize_error_message(message) diff --git a/src/security/input_sanitizer.py b/src/security/input_sanitizer.py index 934b3a5..78c6fa6 100644 --- a/src/security/input_sanitizer.py +++ b/src/security/input_sanitizer.py @@ -18,11 +18,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from core.errors import SecurityError, ValidationError @@ -32,264 +33,274 @@ class InputSanitizer: """ Comprehensive input sanitization for security protection. - + Provides sanitization against XSS, SQL injection, command injection, and other common attack vectors. """ - + def __init__(self): """Initialize input sanitizer with security patterns.""" - + # XSS prevention patterns self.xss_patterns = [ - r')<[^<]*)*<\/script>', - r'javascript:', - r'data:text/html', - r'vbscript:', - r'onload\s*=', - r'onerror\s*=', - r'onclick\s*=', - r'onmouseover\s*=', - r')<[^<]*)*<\/script>", + r"javascript:", + r"data:text/html", + r"vbscript:", + r"onload\s*=", + r"onerror\s*=", + r"onclick\s*=", + r"onmouseover\s*=", + r"\s*/(?:etc|bin|usr|var|sys|proc)\b', # Redirect to system paths - r'<\s*/(?:etc|bin|usr|var|sys|proc)\b', # Input from system paths - r'\\\\x[0-9a-f]{2}', # Hex escape sequences + r";\s*(rm|del|cat|ls|dir|mkdir|rmdir|cp|mv|chmod|chown)\b", # Specific dangerous commands after semicolon + r"\|\s*(rm|del|cat|ls|dir|mkdir|rmdir|cp|mv|chmod|chown)\b", # Pipe to dangerous commands + r"`[^`]*`", # Backticks for command execution + r"\$\([^)]*\)", # Command substitution + r"&&\s*(rm|del|cat|ls|dir|mkdir|rmdir|cp|mv|chmod|chown)\b", # AND command execution with dangerous commands + r"\|\|\s*(rm|del|cat|ls|dir|mkdir|rmdir|cp|mv|chmod|chown)\b", # OR command execution with dangerous commands + r">\s*/(?:etc|bin|usr|var|sys|proc)\b", # Redirect to system paths + r"<\s*/(?:etc|bin|usr|var|sys|proc)\b", # Input from system paths + r"\\\\x[0-9a-f]{2}", # Hex escape sequences ] - + # Path traversal patterns self.path_traversal_patterns = [ - r'\.\./+', - r'\.\.\\+', - r'%2e%2e%2f', - r'%2e%2e%5c', - r'..%2f', - r'..%5c', + r"\.\./+", + r"\.\.\\+", + r"%2e%2e%2f", + r"%2e%2e%5c", + r"..%2f", + r"..%5c", ] - + # Maximum safe lengths self.max_lengths = { - 'string': 10000, - 'url': 2000, - 'email': 254, - 'username': 100, - 'password': 256, - 'filename': 255, - 'path': 4096 + "string": 10000, + "url": 2000, + "email": 254, + "username": 100, + "password": 256, + "filename": 255, + "path": 4096, } - - def sanitize_string(self, - value: str, - max_length: Optional[int] = None, - allow_html: bool = False, - field_type: str = 'string') -> str: + + def sanitize_string( + self, + value: str, + max_length: Optional[int] = None, + allow_html: bool = False, + field_type: str = "string", + ) -> str: """ Sanitize string input for security. - + Args: value: Input string to sanitize max_length: Maximum allowed length allow_html: Whether to allow HTML content field_type: Type of field for specific validation - + Returns: Sanitized string - + Raises: SecurityError: If dangerous content is detected ValidationError: If validation fails """ if not isinstance(value, str): raise ValidationError("Input must be a string") - + # Check length limits - max_len = max_length or self.max_lengths.get(field_type, self.max_lengths['string']) + max_len = max_length or self.max_lengths.get( + field_type, self.max_lengths["string"] + ) if len(value) > max_len: raise ValidationError(f"Input too long: {len(value)} > {max_len}") - + # Check for null bytes - if '\x00' in value: + if "\x00" in value: raise SecurityError("Null bytes not allowed in input") - + # Check for control characters (except common whitespace) - control_chars = re.findall(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', value) + control_chars = re.findall(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", value) if control_chars: raise SecurityError("Control characters not allowed in input") - + # XSS protection if not allow_html: self._check_xss_patterns(value) # HTML encode for safety value = html.escape(value, quote=True) - + # Check for injection patterns self._check_injection_patterns(value) - + # Field-specific sanitization - if field_type == 'url': + if field_type == "url": value = self._sanitize_url(value) - elif field_type == 'email': + elif field_type == "email": value = self._sanitize_email(value) - elif field_type == 'filename': + elif field_type == "filename": value = self._sanitize_filename(value) - elif field_type == 'path': + elif field_type == "path": value = self._sanitize_path(value) - - logger.debug("String sanitized successfully", - field_type=field_type, - original_length=len(value), - allow_html=allow_html) - + + logger.debug( + "String sanitized successfully", + field_type=field_type, + original_length=len(value), + allow_html=allow_html, + ) + return value - - def sanitize_dict(self, - data: Dict[str, Any], - field_rules: Optional[Dict[str, Dict[str, Any]]] = None) -> Dict[str, Any]: + + def sanitize_dict( + self, + data: Dict[str, Any], + field_rules: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> Dict[str, Any]: """ Sanitize dictionary data recursively. - + Args: data: Dictionary to sanitize field_rules: Optional field-specific sanitization rules - + Returns: Sanitized dictionary """ if not isinstance(data, dict): raise ValidationError("Input must be a dictionary") - + sanitized = {} field_rules = field_rules or {} - + for key, value in data.items(): # Sanitize key sanitized_key = self._sanitize_dict_key(key) - + # Get field rules rules = field_rules.get(key, {}) - + # Sanitize value based on type if isinstance(value, str): sanitized_value = self.sanitize_string( value, - max_length=rules.get('max_length'), - allow_html=rules.get('allow_html', False), - field_type=rules.get('type', 'string') + max_length=rules.get("max_length"), + allow_html=rules.get("allow_html", False), + field_type=rules.get("type", "string"), ) elif isinstance(value, dict): - nested_rules = rules.get('nested_rules', {}) + nested_rules = rules.get("nested_rules", {}) sanitized_value = self.sanitize_dict(value, nested_rules) elif isinstance(value, list): - sanitized_value = self.sanitize_list(value, rules.get('item_rules', {})) + sanitized_value = self.sanitize_list(value, rules.get("item_rules", {})) else: # For other types (int, float, bool), validate but don't modify sanitized_value = self._validate_primitive(value, rules) - + sanitized[sanitized_key] = sanitized_value - + return sanitized - - def sanitize_list(self, - data: List[Any], - item_rules: Optional[Dict[str, Any]] = None) -> List[Any]: + + def sanitize_list( + self, data: List[Any], item_rules: Optional[Dict[str, Any]] = None + ) -> List[Any]: """ Sanitize list data. - + Args: data: List to sanitize item_rules: Rules for list items - + Returns: Sanitized list """ if not isinstance(data, list): raise ValidationError("Input must be a list") - + # Check list length - max_items = item_rules.get('max_items', 1000) if item_rules else 1000 + max_items = item_rules.get("max_items", 1000) if item_rules else 1000 if len(data) > max_items: raise ValidationError(f"List too long: {len(data)} > {max_items}") - + sanitized = [] item_rules = item_rules or {} - + for item in data: if isinstance(item, str): sanitized_item = self.sanitize_string( item, - max_length=item_rules.get('max_length'), - allow_html=item_rules.get('allow_html', False), - field_type=item_rules.get('type', 'string') + max_length=item_rules.get("max_length"), + allow_html=item_rules.get("allow_html", False), + field_type=item_rules.get("type", "string"), ) elif isinstance(item, dict): - nested_rules = item_rules.get('nested_rules', {}) + nested_rules = item_rules.get("nested_rules", {}) sanitized_item = self.sanitize_dict(item, nested_rules) elif isinstance(item, list): - sanitized_item = self.sanitize_list(item, item_rules.get('item_rules', {})) + sanitized_item = self.sanitize_list( + item, item_rules.get("item_rules", {}) + ) else: sanitized_item = self._validate_primitive(item, item_rules) - + sanitized.append(sanitized_item) - + return sanitized - + def _check_xss_patterns(self, value: str) -> None: """Check for XSS attack patterns.""" value_lower = value.lower() - + for pattern in self.xss_patterns: if re.search(pattern, value_lower, re.IGNORECASE | re.DOTALL): raise SecurityError(f"Potential XSS attack detected") - + def _check_injection_patterns(self, value: str) -> None: """Check for various injection attack patterns.""" value_lower = value.lower() - + # SQL injection check for pattern in self.sql_injection_patterns: if re.search(pattern, value_lower, re.IGNORECASE): raise SecurityError("Potential SQL injection detected") - + # Command injection check for pattern in self.command_injection_patterns: if re.search(pattern, value, re.IGNORECASE): raise SecurityError("Potential command injection detected") - + # Path traversal check for pattern in self.path_traversal_patterns: if re.search(pattern, value, re.IGNORECASE): raise SecurityError("Potential path traversal attack detected") - + def _sanitize_url(self, url: str) -> str: """Sanitize URL input.""" # Parse URL to validate structure @@ -297,104 +308,123 @@ def _sanitize_url(self, url: str) -> str: parsed = urllib.parse.urlparse(url) except Exception: raise ValidationError("Invalid URL format") - + # Check scheme - if parsed.scheme not in ['http', 'https']: + if parsed.scheme not in ["http", "https"]: raise SecurityError("Only HTTP and HTTPS URLs allowed") - + # Check for dangerous characters - dangerous_chars = ['<', '>', '"', '\'', '&', '\x00'] + dangerous_chars = ["<", ">", '"', "'", "&", "\x00"] for char in dangerous_chars: if char in url: raise SecurityError(f"Dangerous character in URL: {char}") - + return url - + def _sanitize_email(self, email: str) -> str: """Sanitize email input.""" # Basic email format validation - email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" if not re.match(email_pattern, email): raise ValidationError("Invalid email format") - + return email.lower().strip() - + def _sanitize_filename(self, filename: str) -> str: """Sanitize filename input.""" # Remove path separators and dangerous characters - dangerous_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', '\x00'] + dangerous_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\x00"] for char in dangerous_chars: if char in filename: raise SecurityError(f"Dangerous character in filename: {char}") - + # Check for reserved names (Windows) reserved_names = { - 'CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', - 'COM5', 'COM6', 'COM7', 'COM8', 'COM9', 'LPT1', 'LPT2', - 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9' + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", } - + name_upper = filename.upper() - if name_upper in reserved_names or name_upper.split('.')[0] in reserved_names: + if name_upper in reserved_names or name_upper.split(".")[0] in reserved_names: raise SecurityError("Reserved filename not allowed") - + return filename.strip() - + def _sanitize_path(self, path: str) -> str: """Sanitize file path input.""" # Check for path traversal - if '..' in path: + if ".." in path: raise SecurityError("Path traversal not allowed") - + # Check for absolute paths - if path.startswith('/') or (len(path) > 1 and path[1] == ':'): + if path.startswith("/") or (len(path) > 1 and path[1] == ":"): raise SecurityError("Absolute paths not allowed") - + return path - + def _sanitize_dict_key(self, key: str) -> str: """Sanitize dictionary key.""" if not isinstance(key, str): raise ValidationError("Dictionary keys must be strings") - + if len(key) > 100: raise ValidationError("Dictionary key too long") - + # Allow only alphanumeric characters, underscores, and hyphens - if not re.match(r'^[a-zA-Z0-9_-]+$', key): + if not re.match(r"^[a-zA-Z0-9_-]+$", key): raise SecurityError("Invalid characters in dictionary key") - + return key - + def _validate_primitive(self, value: Any, rules: Dict[str, Any]) -> Any: """Validate primitive values (int, float, bool).""" if isinstance(value, (int, float)): - min_val = rules.get('min_value') - max_val = rules.get('max_value') - + min_val = rules.get("min_value") + max_val = rules.get("max_value") + if min_val is not None and value < min_val: raise ValidationError(f"Value too small: {value} < {min_val}") - + if max_val is not None and value > max_val: raise ValidationError(f"Value too large: {value} > {max_val}") - + elif isinstance(value, bool): # Boolean values are safe as-is pass - + else: # For other types, convert to string and validate str_value = str(value) if len(str_value) > 1000: raise ValidationError("Value representation too long") - + return value def create_input_sanitizer() -> InputSanitizer: """ Create a configured input sanitizer instance. - + Returns: InputSanitizer instance """ @@ -408,31 +438,31 @@ def create_input_sanitizer() -> InputSanitizer: def get_sanitizer() -> InputSanitizer: """Get global sanitizer instance.""" global _sanitizer_instance - + if _sanitizer_instance is None: _sanitizer_instance = create_input_sanitizer() - + return _sanitizer_instance def sanitize_input(value: Any, **kwargs) -> Any: """ Convenience function to sanitize input. - + Args: value: Value to sanitize **kwargs: Sanitization options - + Returns: Sanitized value """ sanitizer = get_sanitizer() - + if isinstance(value, str): return sanitizer.sanitize_string(value, **kwargs) elif isinstance(value, dict): - return sanitizer.sanitize_dict(value, kwargs.get('field_rules')) + return sanitizer.sanitize_dict(value, kwargs.get("field_rules")) elif isinstance(value, list): - return sanitizer.sanitize_list(value, kwargs.get('item_rules')) + return sanitizer.sanitize_list(value, kwargs.get("item_rules")) else: - return sanitizer._validate_primitive(value, kwargs) \ No newline at end of file + return sanitizer._validate_primitive(value, kwargs) diff --git a/src/security/token_manager.py b/src/security/token_manager.py index 0020868..6a8307b 100644 --- a/src/security/token_manager.py +++ b/src/security/token_manager.py @@ -22,14 +22,13 @@ from core.errors import SecurityError, AuthenticationError - logger = structlog.get_logger(__name__) @dataclass class SecureToken: """Represents a securely encrypted token with metadata.""" - + token_id: str user_id: str tool_name: str @@ -41,15 +40,15 @@ class SecureToken: is_revoked: bool = False last_used: Optional[float] = None use_count: int = 0 - + def is_expired(self) -> bool: """Check if token is expired.""" return time.time() > self.expires_at - + def is_valid(self) -> bool: """Check if token is valid and usable.""" return not self.is_revoked and not self.is_expired() - + def record_use(self) -> None: """Record token usage.""" self.last_used = time.time() @@ -59,65 +58,69 @@ def record_use(self) -> None: class SecureTokenManager: """ Secure token manager with encryption, proper storage, and validation. - + Provides cryptographically secure token generation, storage, and management with protection against token manipulation and replay attacks. """ - + def __init__(self, encryption_key: Optional[bytes] = None): """ Initialize secure token manager. - + Args: encryption_key: Optional encryption key (generates if not provided) """ self.encryption_key = encryption_key or self._generate_encryption_key() self.cipher_suite = Fernet(self.encryption_key) - + # Secure token storage (in production, use encrypted database) self.tokens: Dict[str, SecureToken] = {} - + # Token validation settings self.max_token_lifetime = 3600 # 1 hour self.max_refresh_lifetime = 86400 # 24 hours self.token_entropy_bytes = 32 # 256 bits - + logger.info("SecureTokenManager initialized with encryption") - - def generate_access_token(self, - user_id: str, - tool_name: str, - scopes: List[str], - lifetime_seconds: int = None) -> str: + + def generate_access_token( + self, + user_id: str, + tool_name: str, + scopes: List[str], + lifetime_seconds: int = None, + ) -> str: """ Generate a cryptographically secure access token. - + Args: user_id: User identifier tool_name: Tool name for authorization scopes: List of authorized scopes lifetime_seconds: Token lifetime (default: 1 hour) - + Returns: Secure access token string - + Raises: SecurityError: If token generation fails """ try: # Validate inputs self._validate_token_inputs(user_id, tool_name, scopes) - + # Set token lifetime if lifetime_seconds is None: lifetime_seconds = self.max_token_lifetime elif lifetime_seconds > self.max_token_lifetime: - raise SecurityError(f"Token lifetime cannot exceed {self.max_token_lifetime} seconds") - + raise SecurityError( + f"Token lifetime cannot exceed {self.max_token_lifetime} seconds" + ) + # Generate secure token components token_id = self._generate_token_id() raw_token = self._generate_raw_token() - + # Create token data current_time = time.time() token_data = { @@ -127,15 +130,15 @@ def generate_access_token(self, "scopes": scopes, "created_at": current_time, "expires_at": current_time + lifetime_seconds, - "raw_token": raw_token + "raw_token": raw_token, } - + # Encrypt token data encrypted_data = self._encrypt_token_data(token_data) - + # Create token hash for validation token_hash = self._create_token_hash(raw_token, user_id) - + # Store secure token secure_token = SecureToken( token_id=token_id, @@ -145,107 +148,113 @@ def generate_access_token(self, created_at=current_time, expires_at=current_time + lifetime_seconds, token_hash=token_hash, - encrypted_data=encrypted_data + encrypted_data=encrypted_data, ) - + self.tokens[token_id] = secure_token - + # Create final token (format: tokenId.encryptedData.hash) final_token = f"{token_id}.{encrypted_data}.{token_hash}" - - logger.info("Access token generated", - token_id=token_id, - user_id=user_id, - tool_name=tool_name, - scopes=scopes, - expires_at=secure_token.expires_at) - + + logger.info( + "Access token generated", + token_id=token_id, + user_id=user_id, + tool_name=tool_name, + scopes=scopes, + expires_at=secure_token.expires_at, + ) + return final_token - + except Exception as e: - logger.error("Token generation failed", - user_id=user_id, - tool_name=tool_name, - error=str(e)) + logger.error( + "Token generation failed", + user_id=user_id, + tool_name=tool_name, + error=str(e), + ) raise SecurityError(f"Failed to generate access token: {str(e)}") - + def validate_token(self, token: str) -> Dict[str, Any]: """ Validate and decrypt an access token. - + Args: token: Access token to validate - + Returns: Token data if valid - + Raises: AuthenticationError: If token is invalid """ try: # Parse token components - token_parts = token.split('.') + token_parts = token.split(".") if len(token_parts) != 3: raise AuthenticationError("Invalid token format") - + token_id, encrypted_data, provided_hash = token_parts - + # Check if token exists if token_id not in self.tokens: raise AuthenticationError("Token not found") - + secure_token = self.tokens[token_id] - + # Check token validity if not secure_token.is_valid(): if secure_token.is_revoked: raise AuthenticationError("Token has been revoked") else: raise AuthenticationError("Token has expired") - + # Validate token hash if not self._verify_token_hash(provided_hash, secure_token.token_hash): raise AuthenticationError("Token integrity check failed") - + # Decrypt token data token_data = self._decrypt_token_data(encrypted_data) - + # Additional validation if token_data["token_id"] != token_id: raise AuthenticationError("Token ID mismatch") - + if token_data["expires_at"] < time.time(): raise AuthenticationError("Token expired") - + # Record token usage secure_token.record_use() - - logger.debug("Token validated successfully", - token_id=token_id, - user_id=token_data["user_id"], - use_count=secure_token.use_count) - + + logger.debug( + "Token validated successfully", + token_id=token_id, + user_id=token_data["user_id"], + use_count=secure_token.use_count, + ) + return { "user_id": token_data["user_id"], "tool_name": token_data["tool_name"], "scopes": token_data["scopes"], "created_at": token_data["created_at"], - "expires_at": token_data["expires_at"] + "expires_at": token_data["expires_at"], } - + except AuthenticationError: raise except Exception as e: logger.error("Token validation failed", error=str(e)) raise AuthenticationError(f"Token validation failed: {str(e)}") - + def revoke_token(self, token_id: str) -> bool: """ Revoke a specific token. - + Args: token_id: Token ID to revoke - + Returns: True if token was revoked """ @@ -254,152 +263,157 @@ def revoke_token(self, token_id: str) -> bool: logger.info("Token revoked", token_id=token_id) return True return False - + def revoke_user_tokens(self, user_id: str, tool_name: Optional[str] = None) -> int: """ Revoke all tokens for a user or user-tool combination. - + Args: user_id: User identifier tool_name: Optional tool name filter - + Returns: Number of tokens revoked """ revoked_count = 0 - + for token in self.tokens.values(): - if (token.user_id == user_id and - (tool_name is None or token.tool_name == tool_name) and - not token.is_revoked): + if ( + token.user_id == user_id + and (tool_name is None or token.tool_name == tool_name) + and not token.is_revoked + ): token.is_revoked = True revoked_count += 1 - - logger.info("User tokens revoked", - user_id=user_id, - tool_name=tool_name, - count=revoked_count) - + + logger.info( + "User tokens revoked", + user_id=user_id, + tool_name=tool_name, + count=revoked_count, + ) + return revoked_count - + def cleanup_expired_tokens(self) -> int: """ Remove expired tokens from storage. - + Returns: Number of tokens cleaned up """ current_time = time.time() expired_tokens = [ - token_id for token_id, token in self.tokens.items() + token_id + for token_id, token in self.tokens.items() if token.expires_at < current_time ] - + for token_id in expired_tokens: del self.tokens[token_id] - + if expired_tokens: logger.info("Expired tokens cleaned up", count=len(expired_tokens)) - + return len(expired_tokens) - + def get_token_stats(self) -> Dict[str, Any]: """Get token statistics for monitoring.""" current_time = time.time() - + active_tokens = sum(1 for token in self.tokens.values() if token.is_valid()) expired_tokens = sum(1 for token in self.tokens.values() if token.is_expired()) revoked_tokens = sum(1 for token in self.tokens.values() if token.is_revoked) - + return { "total_tokens": len(self.tokens), "active_tokens": active_tokens, "expired_tokens": expired_tokens, "revoked_tokens": revoked_tokens, - "cleanup_needed": expired_tokens > 0 + "cleanup_needed": expired_tokens > 0, } - + def _generate_encryption_key(self) -> bytes: """Generate a secure encryption key.""" # In production, this should be loaded from secure key management password = os.urandom(32) salt = os.urandom(16) - + kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, salt=salt, iterations=100000, - backend=default_backend() + backend=default_backend(), ) - + key = base64.urlsafe_b64encode(kdf.derive(password)) return key - + def _generate_token_id(self) -> str: """Generate a unique token identifier.""" return secrets.token_urlsafe(16) - + def _generate_raw_token(self) -> str: """Generate a cryptographically secure raw token.""" return secrets.token_urlsafe(self.token_entropy_bytes) - + def _create_token_hash(self, raw_token: str, user_id: str) -> str: """Create a secure hash of the token for validation.""" # Use HMAC for secure hashing - message = f"{raw_token}:{user_id}".encode('utf-8') - signature = hmac.new( - self.encryption_key, - message, - hashlib.sha256 - ).hexdigest() + message = f"{raw_token}:{user_id}".encode("utf-8") + signature = hmac.new(self.encryption_key, message, hashlib.sha256).hexdigest() return signature - + def _verify_token_hash(self, provided_hash: str, stored_hash: str) -> bool: """Verify token hash using constant-time comparison.""" return hmac.compare_digest(provided_hash, stored_hash) - + def _encrypt_token_data(self, token_data: Dict[str, Any]) -> str: """Encrypt token data using Fernet encryption.""" - data_bytes = json.dumps(token_data).encode('utf-8') + data_bytes = json.dumps(token_data).encode("utf-8") encrypted_data = self.cipher_suite.encrypt(data_bytes) - return base64.urlsafe_b64encode(encrypted_data).decode('utf-8') - + return base64.urlsafe_b64encode(encrypted_data).decode("utf-8") + def _decrypt_token_data(self, encrypted_data: str) -> Dict[str, Any]: """Decrypt token data using Fernet encryption.""" - encrypted_bytes = base64.urlsafe_b64decode(encrypted_data.encode('utf-8')) + encrypted_bytes = base64.urlsafe_b64decode(encrypted_data.encode("utf-8")) decrypted_data = self.cipher_suite.decrypt(encrypted_bytes) - return json.loads(decrypted_data.decode('utf-8')) - - def _validate_token_inputs(self, user_id: str, tool_name: str, scopes: List[str]) -> None: + return json.loads(decrypted_data.decode("utf-8")) + + def _validate_token_inputs( + self, user_id: str, tool_name: str, scopes: List[str] + ) -> None: """Validate token generation inputs.""" if not user_id or not isinstance(user_id, str): raise SecurityError("Invalid user_id") - + if not tool_name or not isinstance(tool_name, str): raise SecurityError("Invalid tool_name") - + if not scopes or not isinstance(scopes, list): raise SecurityError("Invalid scopes") - + if len(user_id) > 100: raise SecurityError("user_id too long") - + if len(tool_name) > 100: raise SecurityError("tool_name too long") - + if len(scopes) > 20: raise SecurityError("Too many scopes") -def create_secure_token_manager(encryption_key: Optional[bytes] = None) -> SecureTokenManager: +def create_secure_token_manager( + encryption_key: Optional[bytes] = None, +) -> SecureTokenManager: """ Create a configured secure token manager. - + Args: encryption_key: Optional encryption key - + Returns: SecureTokenManager instance """ - return SecureTokenManager(encryption_key) \ No newline at end of file + return SecureTokenManager(encryption_key) diff --git a/src/tools/__init__.py b/src/tools/__init__.py index 062fa36..f8488c1 100644 --- a/src/tools/__init__.py +++ b/src/tools/__init__.py @@ -8,24 +8,36 @@ # Use try/except to handle both relative and absolute imports try: from .decorators import Tool, get_tool_registry, ToolDefinition, ToolRegistry - from .executor import ToolExecutor, ToolCall, ToolResult, create_tool_call, format_tool_result_for_llm + from .executor import ( + ToolExecutor, + ToolCall, + ToolResult, + create_tool_call, + format_tool_result_for_llm, + ) from .validation import ParameterValidator, SecurityValidator except ImportError: # Fallback to absolute imports when called from scripts from tools.decorators import Tool, get_tool_registry, ToolDefinition, ToolRegistry - from tools.executor import ToolExecutor, ToolCall, ToolResult, create_tool_call, format_tool_result_for_llm + from tools.executor import ( + ToolExecutor, + ToolCall, + ToolResult, + create_tool_call, + format_tool_result_for_llm, + ) from tools.validation import ParameterValidator, SecurityValidator __all__ = [ - 'Tool', - 'get_tool_registry', - 'ToolDefinition', - 'ToolRegistry', - 'ToolExecutor', - 'ToolCall', - 'ToolResult', - 'create_tool_call', - 'format_tool_result_for_llm', - 'ParameterValidator', - 'SecurityValidator' -] \ No newline at end of file + "Tool", + "get_tool_registry", + "ToolDefinition", + "ToolRegistry", + "ToolExecutor", + "ToolCall", + "ToolResult", + "create_tool_call", + "format_tool_result_for_llm", + "ParameterValidator", + "SecurityValidator", +] diff --git a/src/tools/connectors/file.py b/src/tools/connectors/file.py index a8a6e5a..916323f 100644 --- a/src/tools/connectors/file.py +++ b/src/tools/connectors/file.py @@ -15,18 +15,17 @@ from tools.decorators import Tool from ...core.errors import ToolExecutionError, SecurityError, ValidationError - logger = structlog.get_logger(__name__) # Configuration for allowed directories ALLOWED_BASE_PATHS = [ "/workspace/data", - "/workspace/docs", + "/workspace/docs", "/workspace/output", "./data", "./docs", - "./output" + "./output", ] MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB @@ -41,39 +40,39 @@ "type": "string", "description": "Path to file to read (relative to allowed directories)", "pattern": r"^[a-zA-Z0-9/._-]+$", - "maxLength": 500 + "maxLength": 500, }, "encoding": { "type": "string", "description": "File encoding", "enum": ["utf-8", "ascii", "latin-1", "utf-16"], - "default": "utf-8" + "default": "utf-8", }, "max_lines": { "type": "integer", "description": "Maximum number of lines to read (0 = all)", "default": 0, "minimum": 0, - "maximum": 10000 - } + "maximum": 10000, + }, }, requires_auth=False, - timeout_seconds=30 + timeout_seconds=30, ) -def read_file_tool(file_path: str, - encoding: str = "utf-8", - max_lines: int = 0) -> Dict[str, Any]: +def read_file_tool( + file_path: str, encoding: str = "utf-8", max_lines: int = 0 +) -> Dict[str, Any]: """ Read contents of a file with security validation. - + Args: file_path: Path to file to read encoding: File encoding max_lines: Maximum lines to read (0 = all) - + Returns: File content and metadata - + Raises: SecurityError: If file path is not allowed ToolExecutionError: If file operation fails @@ -81,22 +80,22 @@ def read_file_tool(file_path: str, try: # Security validation validated_path = _validate_file_path_security(file_path, operation="read") - + # Check if file exists if not validated_path.exists(): return { "success": False, "error": f"File does not exist: {file_path}", - "file_path": file_path + "file_path": file_path, } - + if not validated_path.is_file(): return { "success": False, "error": f"Path is not a file: {file_path}", - "file_path": file_path + "file_path": file_path, } - + # Check file size file_size = validated_path.stat().st_size if file_size > MAX_FILE_SIZE: @@ -104,71 +103,70 @@ def read_file_tool(file_path: str, "success": False, "error": f"File too large: {file_size} bytes (max: {MAX_FILE_SIZE})", "file_path": file_path, - "file_size": file_size + "file_size": file_size, } - + # Read file content start_time = time.time() - - with open(validated_path, 'r', encoding=encoding) as file: + + with open(validated_path, "r", encoding=encoding) as file: if max_lines > 0: lines = [] for i, line in enumerate(file): if i >= max_lines: break - lines.append(line.rstrip('\n\r')) - content = '\n'.join(lines) + lines.append(line.rstrip("\n\r")) + content = "\n".join(lines) truncated = True else: content = file.read() truncated = False - + read_time = (time.time() - start_time) * 1000 - + # Get file metadata stat = validated_path.stat() - + result = { "success": True, "file_path": file_path, "content": content, "size_bytes": len(content.encode(encoding)), "original_size_bytes": file_size, - "line_count": content.count('\n') + 1 if content else 0, + "line_count": content.count("\n") + 1 if content else 0, "encoding": encoding, "truncated": truncated, "max_lines_applied": max_lines if max_lines > 0 else None, "last_modified": stat.st_mtime, - "read_time_ms": read_time + "read_time_ms": read_time, } - - logger.info("File read successfully", - file_path=file_path, - size_bytes=result["size_bytes"], - read_time_ms=read_time) - + + logger.info( + "File read successfully", + file_path=file_path, + size_bytes=result["size_bytes"], + read_time_ms=read_time, + ) + return result - + except UnicodeDecodeError as e: - logger.error("File encoding error", - file_path=file_path, - encoding=encoding, - error=str(e)) + logger.error( + "File encoding error", file_path=file_path, encoding=encoding, error=str(e) + ) return { "success": False, "error": f"File encoding error: {str(e)}", "file_path": file_path, - "encoding": encoding + "encoding": encoding, } - + except Exception as e: - logger.error("File read failed", - file_path=file_path, - error=str(e)) + logger.error("File read failed", file_path=file_path, error=str(e)) return { "success": False, "error": f"Failed to read file: {str(e)}", - "file_path": file_path + "file_path": file_path, } @@ -180,144 +178,155 @@ def read_file_tool(file_path: str, "type": "string", "description": "Path to directory to list", "pattern": r"^[a-zA-Z0-9/._-]*$", - "maxLength": 500 + "maxLength": 500, }, "include_hidden": { "type": "boolean", "description": "Include hidden files and directories", - "default": False + "default": False, }, "recursive": { - "type": "boolean", + "type": "boolean", "description": "List subdirectories recursively", - "default": False + "default": False, }, "file_types": { "type": "array", "description": "Filter by file extensions (e.g., ['.txt', '.json'])", "items": {"type": "string"}, - "default": [] - } + "default": [], + }, }, requires_auth=False, - timeout_seconds=30 + timeout_seconds=30, ) -def list_directory_tool(directory_path: str = "", - include_hidden: bool = False, - recursive: bool = False, - file_types: List[str] = None) -> Dict[str, Any]: +def list_directory_tool( + directory_path: str = "", + include_hidden: bool = False, + recursive: bool = False, + file_types: List[str] = None, +) -> Dict[str, Any]: """ List contents of a directory with security validation. - + Args: directory_path: Path to directory to list include_hidden: Include hidden files recursive: List recursively file_types: Filter by file extensions - + Returns: Directory listing with metadata """ try: # Security validation validated_path = _validate_directory_path_security(directory_path) - + if not validated_path.exists(): return { "success": False, "error": f"Directory does not exist: {directory_path}", - "directory_path": directory_path + "directory_path": directory_path, } - + if not validated_path.is_dir(): return { "success": False, "error": f"Path is not a directory: {directory_path}", - "directory_path": directory_path + "directory_path": directory_path, } - + # List directory contents start_time = time.time() entries = [] - + if file_types is None: file_types = [] - + # Normalize file extensions - file_types = [ext.lower() if ext.startswith('.') else f'.{ext.lower()}' for ext in file_types] - + file_types = [ + ext.lower() if ext.startswith(".") else f".{ext.lower()}" + for ext in file_types + ] + try: if recursive: for root, dirs, files in os.walk(validated_path): # Skip hidden directories if not requested if not include_hidden: - dirs[:] = [d for d in dirs if not d.startswith('.')] - + dirs[:] = [d for d in dirs if not d.startswith(".")] + root_path = Path(root) - + # Process directories for dir_name in dirs: - if not include_hidden and dir_name.startswith('.'): + if not include_hidden and dir_name.startswith("."): continue - + dir_path = root_path / dir_name relative_path = dir_path.relative_to(validated_path) - - entry = _create_directory_entry(dir_path, str(relative_path), "directory") + + entry = _create_directory_entry( + dir_path, str(relative_path), "directory" + ) if entry: entries.append(entry) - + # Process files for file_name in files: - if not include_hidden and file_name.startswith('.'): + if not include_hidden and file_name.startswith("."): continue - + file_path = root_path / file_name relative_path = file_path.relative_to(validated_path) - + # Filter by file types if specified if file_types: file_ext = file_path.suffix.lower() if file_ext not in file_types: continue - - entry = _create_directory_entry(file_path, str(relative_path), "file") + + entry = _create_directory_entry( + file_path, str(relative_path), "file" + ) if entry: entries.append(entry) else: # Non-recursive listing for item in validated_path.iterdir(): - if not include_hidden and item.name.startswith('.'): + if not include_hidden and item.name.startswith("."): continue - + # Filter by file types if specified if file_types and item.is_file(): file_ext = item.suffix.lower() if file_ext not in file_types: continue - + entry_type = "directory" if item.is_dir() else "file" entry = _create_directory_entry(item, item.name, entry_type) if entry: entries.append(entry) - + except PermissionError as e: return { "success": False, "error": f"Permission denied: {str(e)}", - "directory_path": directory_path + "directory_path": directory_path, } - + list_time = (time.time() - start_time) * 1000 - + # Sort entries by type (directories first) then by name entries.sort(key=lambda x: (x["type"] != "directory", x["name"].lower())) - + # Calculate statistics file_count = sum(1 for entry in entries if entry["type"] == "file") dir_count = sum(1 for entry in entries if entry["type"] == "directory") - total_size = sum(entry.get("size_bytes", 0) for entry in entries if entry["type"] == "file") - + total_size = sum( + entry.get("size_bytes", 0) for entry in entries if entry["type"] == "file" + ) + result = { "success": True, "directory_path": directory_path, @@ -329,24 +338,26 @@ def list_directory_tool(directory_path: str = "", "include_hidden": include_hidden, "recursive": recursive, "file_types_filter": file_types, - "list_time_ms": list_time + "list_time_ms": list_time, } - - logger.info("Directory listed successfully", - directory_path=directory_path, - entry_count=len(entries), - list_time_ms=list_time) - + + logger.info( + "Directory listed successfully", + directory_path=directory_path, + entry_count=len(entries), + list_time_ms=list_time, + ) + return result - + except Exception as e: - logger.error("Directory listing failed", - directory_path=directory_path, - error=str(e)) + logger.error( + "Directory listing failed", directory_path=directory_path, error=str(e) + ) return { "success": False, "error": f"Failed to list directory: {str(e)}", - "directory_path": directory_path + "directory_path": directory_path, } @@ -358,36 +369,36 @@ def list_directory_tool(directory_path: str = "", "type": "string", "description": "Path to file or directory", "pattern": r"^[a-zA-Z0-9/._-]+$", - "maxLength": 500 + "maxLength": 500, } }, requires_auth=False, - timeout_seconds=10 + timeout_seconds=10, ) def get_file_info_tool(path: str) -> Dict[str, Any]: """ Get detailed metadata about a file or directory. - + Args: path: Path to file or directory - + Returns: File/directory metadata """ try: # Security validation validated_path = _validate_file_path_security(path, operation="info") - + if not validated_path.exists(): return { "success": False, "error": f"Path does not exist: {path}", - "path": path + "path": path, } - + # Get file/directory stats stat = validated_path.stat() - + result = { "success": True, "path": path, @@ -400,24 +411,46 @@ def get_file_info_tool(path: str) -> Dict[str, Any]: "permissions": oct(stat.st_mode)[-3:], "is_readable": os.access(validated_path, os.R_OK), "is_writable": os.access(validated_path, os.W_OK), - "is_executable": os.access(validated_path, os.X_OK) + "is_executable": os.access(validated_path, os.X_OK), } - + # Add file-specific metadata if validated_path.is_file(): result["extension"] = validated_path.suffix result["stem"] = validated_path.stem - + # Try to detect file type - if validated_path.suffix.lower() in ['.txt', '.md', '.py', '.js', '.json', '.xml', '.html', '.css']: + if validated_path.suffix.lower() in [ + ".txt", + ".md", + ".py", + ".js", + ".json", + ".xml", + ".html", + ".css", + ]: result["file_category"] = "text" - elif validated_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.svg']: + elif validated_path.suffix.lower() in [ + ".jpg", + ".jpeg", + ".png", + ".gif", + ".bmp", + ".svg", + ]: result["file_category"] = "image" - elif validated_path.suffix.lower() in ['.pdf', '.doc', '.docx', '.xls', '.xlsx']: + elif validated_path.suffix.lower() in [ + ".pdf", + ".doc", + ".docx", + ".xls", + ".xlsx", + ]: result["file_category"] = "document" else: result["file_category"] = "other" - + # Add directory-specific metadata elif validated_path.is_dir(): try: @@ -429,36 +462,36 @@ def get_file_info_tool(path: str) -> Dict[str, Any]: except PermissionError: result["item_count"] = None result["access_denied"] = True - - logger.debug("File info retrieved successfully", - path=path, - type=result["type"], - size_bytes=result["size_bytes"]) - + + logger.debug( + "File info retrieved successfully", + path=path, + type=result["type"], + size_bytes=result["size_bytes"], + ) + return result - + except Exception as e: - logger.error("Failed to get file info", - path=path, - error=str(e)) + logger.error("Failed to get file info", path=path, error=str(e)) return { "success": False, "error": f"Failed to get file info: {str(e)}", - "path": path + "path": path, } def _validate_file_path_security(file_path: str, operation: str = "read") -> Path: """ Validate file path for security constraints. - + Args: file_path: File path to validate operation: Operation being performed - + Returns: Validated Path object - + Raises: SecurityError: If path is not allowed """ @@ -466,14 +499,14 @@ def _validate_file_path_security(file_path: str, operation: str = "read") -> Pat # Normalize path if not file_path: file_path = "." - + path = Path(file_path).resolve() - + # Check for directory traversal attempts path_str = str(path) if ".." in file_path or "~" in file_path: raise SecurityError("Directory traversal not allowed") - + # Check against allowed base paths is_allowed = False for base_path in ALLOWED_BASE_PATHS: @@ -485,7 +518,7 @@ def _validate_file_path_security(file_path: str, operation: str = "read") -> Pat except (ValueError, OSError): # Handle cases where base path doesn't exist or is invalid continue - + if not is_allowed: # For relative paths, check if they're within current working directory's allowed subdirectories cwd = Path.cwd() @@ -497,25 +530,31 @@ def _validate_file_path_security(file_path: str, operation: str = "read") -> Pat break except ValueError: continue - + if not is_allowed: - raise SecurityError(f"File access not allowed outside permitted directories: {file_path}") - + raise SecurityError( + f"File access not allowed outside permitted directories: {file_path}" + ) + # Check for dangerous file extensions dangerous_extensions = [".exe", ".bat", ".sh", ".ps1", ".cmd", ".scr", ".msi"] if path.suffix.lower() in dangerous_extensions: - raise SecurityError(f"Access to executable files not allowed: {path.suffix}") - + raise SecurityError( + f"Access to executable files not allowed: {path.suffix}" + ) + # Check path depth to prevent excessive nesting try: parts = path.parts if len(parts) > MAX_DIRECTORY_DEPTH: - raise SecurityError(f"Path too deep: {len(parts)} levels (max: {MAX_DIRECTORY_DEPTH})") + raise SecurityError( + f"Path too deep: {len(parts)} levels (max: {MAX_DIRECTORY_DEPTH})" + ) except Exception: pass - + return path - + except SecurityError: raise except Exception as e: @@ -527,42 +566,42 @@ def _validate_directory_path_security(directory_path: str) -> Path: return _validate_file_path_security(directory_path, operation="list") -def _create_directory_entry(path: Path, relative_path: str, entry_type: str) -> Optional[Dict[str, Any]]: +def _create_directory_entry( + path: Path, relative_path: str, entry_type: str +) -> Optional[Dict[str, Any]]: """ Create directory entry with metadata. - + Args: path: Full path to the item relative_path: Relative path for display entry_type: Type of entry ("file" or "directory") - + Returns: Directory entry dictionary or None if error """ try: stat = path.stat() - + entry = { "name": path.name, "path": relative_path, "type": entry_type, "modified": stat.st_mtime, - "permissions": oct(stat.st_mode)[-3:] + "permissions": oct(stat.st_mode)[-3:], } - + if entry_type == "file": entry["size_bytes"] = stat.st_size entry["extension"] = path.suffix - + return entry - + except (OSError, PermissionError) as e: - logger.warning("Failed to get entry metadata", - path=str(path), - error=str(e)) + logger.warning("Failed to get entry metadata", path=str(path), error=str(e)) return { "name": path.name, "path": relative_path, "type": entry_type, - "error": "Access denied" - } \ No newline at end of file + "error": "Access denied", + } diff --git a/src/tools/connectors/http.py b/src/tools/connectors/http.py index a03f26f..f712e36 100644 --- a/src/tools/connectors/http.py +++ b/src/tools/connectors/http.py @@ -15,6 +15,7 @@ # Import aiohttp for async HTTP requests try: import aiohttp + HTTP_AVAILABLE = True except ImportError: HTTP_AVAILABLE = False @@ -23,7 +24,6 @@ from tools.decorators import Tool from ...core.errors import ToolExecutionError, SecurityError, ValidationError - logger = structlog.get_logger(__name__) @@ -35,50 +35,52 @@ "type": "string", "description": "URL to request (must be HTTPS for external APIs)", "pattern": r"^https?://.*", - "maxLength": 2000 + "maxLength": 2000, }, "method": { "type": "string", "description": "HTTP method", "enum": ["GET", "POST", "PUT", "DELETE"], - "default": "GET" + "default": "GET", }, "headers": { "type": "object", "description": "Request headers", "default": {}, - "additionalProperties": {"type": "string"} + "additionalProperties": {"type": "string"}, }, "data": { "type": "object", "description": "Request body data for POST/PUT requests", - "default": {} + "default": {}, }, "timeout": { "type": "number", "description": "Request timeout in seconds", "default": 30, "minimum": 1, - "maximum": 60 + "maximum": 60, }, "follow_redirects": { "type": "boolean", "description": "Whether to follow HTTP redirects", - "default": True - } + "default": True, + }, }, requires_auth=False, - timeout_seconds=60 + timeout_seconds=60, ) -async def http_request_tool(url: str, - method: str = "GET", - headers: Dict[str, str] = None, - data: Dict[str, Any] = None, - timeout: int = 30, - follow_redirects: bool = True) -> Dict[str, Any]: +async def http_request_tool( + url: str, + method: str = "GET", + headers: Dict[str, str] = None, + data: Dict[str, Any] = None, + timeout: int = 30, + follow_redirects: bool = True, +) -> Dict[str, Any]: """ Make HTTP requests to external APIs. - + Args: url: URL to request method: HTTP method @@ -86,75 +88,79 @@ async def http_request_tool(url: str, data: Request body data timeout: Request timeout follow_redirects: Whether to follow redirects - + Returns: HTTP response data - + Raises: ToolExecutionError: If request fails SecurityError: If URL is not allowed """ if not HTTP_AVAILABLE: - raise ToolExecutionError("aiohttp library not available. Install with: pip install aiohttp") - + raise ToolExecutionError( + "aiohttp library not available. Install with: pip install aiohttp" + ) + start_time = time.time() request_id = f"req_{int(start_time * 1000)}" - + # Security validation _validate_url_security(url) - + # Prepare headers if headers is None: headers = {} - + # Add default headers default_headers = { "User-Agent": "FACT-System/1.0", "Accept": "application/json, text/plain, */*", - "Content-Type": "application/json" + "Content-Type": "application/json", } - + # Merge headers (user headers override defaults) request_headers = {**default_headers, **headers} - + try: - logger.info("Making HTTP request", - request_id=request_id, - url=url, - method=method, - timeout=timeout) - + logger.info( + "Making HTTP request", + request_id=request_id, + url=url, + method=method, + timeout=timeout, + ) + # Create HTTP session connector = aiohttp.TCPConnector( - limit=10, - limit_per_host=5, - ttl_dns_cache=300, - use_dns_cache=True + limit=10, limit_per_host=5, ttl_dns_cache=300, use_dns_cache=True ) - + async with aiohttp.ClientSession( connector=connector, timeout=aiohttp.ClientTimeout(total=timeout), - headers={"User-Agent": request_headers["User-Agent"]} + headers={"User-Agent": request_headers["User-Agent"]}, ) as session: - + # Prepare request kwargs request_kwargs = { "headers": request_headers, - "allow_redirects": follow_redirects + "allow_redirects": follow_redirects, } - + # Add data for POST/PUT requests if method in ["POST", "PUT"] and data: request_kwargs["json"] = data - + # Make request async with session.request(method, url, **request_kwargs) as response: response_time = (time.time() - start_time) * 1000 - + # Read response content try: - if response.content_type and "application/json" in response.content_type: + if ( + response.content_type + and "application/json" in response.content_type + ): response_data = await response.json() else: response_text = await response.text() @@ -166,7 +172,7 @@ async def http_request_tool(url: str, except Exception as e: logger.warning("Failed to parse response content", error=str(e)) response_data = await response.text() - + # Prepare result result = { "request_id": request_id, @@ -177,36 +183,42 @@ async def http_request_tool(url: str, "response_time_ms": response_time, "content_length": len(str(response_data)), "content_type": response.content_type, - "data": response_data + "data": response_data, } - + # Check for HTTP errors if response.status >= 400: result["success"] = False result["error"] = f"HTTP {response.status} error" - logger.warning("HTTP request returned error status", - request_id=request_id, - status_code=response.status, - url=url) + logger.warning( + "HTTP request returned error status", + request_id=request_id, + status_code=response.status, + url=url, + ) else: result["success"] = True - - logger.info("HTTP request completed", - request_id=request_id, - status_code=response.status, - response_time_ms=response_time, - content_length=result["content_length"]) - + + logger.info( + "HTTP request completed", + request_id=request_id, + status_code=response.status, + response_time_ms=response_time, + content_length=result["content_length"], + ) + return result - + except asyncio.TimeoutError: response_time = (time.time() - start_time) * 1000 - logger.error("HTTP request timed out", - request_id=request_id, - url=url, - timeout=timeout, - response_time_ms=response_time) - + logger.error( + "HTTP request timed out", + request_id=request_id, + url=url, + timeout=timeout, + response_time_ms=response_time, + ) + return { "request_id": request_id, "url": url, @@ -214,17 +226,19 @@ async def http_request_tool(url: str, "success": False, "error": f"Request timed out after {timeout} seconds", "response_time_ms": response_time, - "status": "timeout" + "status": "timeout", } - + except Exception as e: response_time = (time.time() - start_time) * 1000 - logger.error("HTTP request failed", - request_id=request_id, - url=url, - error=str(e), - response_time_ms=response_time) - + logger.error( + "HTTP request failed", + request_id=request_id, + url=url, + error=str(e), + response_time_ms=response_time, + ) + return { "request_id": request_id, "url": url, @@ -232,7 +246,7 @@ async def http_request_tool(url: str, "success": False, "error": f"Request failed: {str(e)}", "response_time_ms": response_time, - "status": "error" + "status": "error", } @@ -243,46 +257,46 @@ async def http_request_tool(url: str, "url": { "type": "string", "description": "URL to check", - "pattern": r"^https?://.*" + "pattern": r"^https?://.*", }, "timeout": { "type": "number", "description": "Timeout in seconds", "default": 10, "minimum": 1, - "maximum": 30 - } + "maximum": 30, + }, }, requires_auth=False, - timeout_seconds=30 + timeout_seconds=30, ) async def url_health_check_tool(url: str, timeout: int = 10) -> Dict[str, Any]: """ Check the health and availability of a URL endpoint. - + Args: url: URL to check timeout: Request timeout in seconds - + Returns: Health check results """ start_time = time.time() - + # Security validation _validate_url_security(url) - + try: # Make HEAD request for efficiency result = await http_request_tool( url=url, method="GET", # Some servers don't support HEAD timeout=timeout, - headers={"Accept": "text/html,application/json,*/*"} + headers={"Accept": "text/html,application/json,*/*"}, ) - + response_time = (time.time() - start_time) * 1000 - + health_result = { "url": url, "available": result.get("success", False), @@ -291,86 +305,99 @@ async def url_health_check_tool(url: str, timeout: int = 10) -> Dict[str, Any]: "content_type": result.get("content_type"), "server": result.get("headers", {}).get("server"), "last_modified": result.get("headers", {}).get("last-modified"), - "content_length": result.get("content_length") + "content_length": result.get("content_length"), } - + if not result.get("success", False): health_result["error"] = result.get("error") - - logger.info("URL health check completed", - url=url, - available=health_result["available"], - response_time_ms=response_time) - + + logger.info( + "URL health check completed", + url=url, + available=health_result["available"], + response_time_ms=response_time, + ) + return health_result - + except Exception as e: response_time = (time.time() - start_time) * 1000 - logger.error("URL health check failed", - url=url, - error=str(e), - response_time_ms=response_time) - + logger.error( + "URL health check failed", + url=url, + error=str(e), + response_time_ms=response_time, + ) + return { "url": url, "available": False, "error": str(e), - "response_time_ms": response_time + "response_time_ms": response_time, } def _validate_url_security(url: str) -> None: """ Validate URL for security constraints. - + Args: url: URL to validate - + Raises: SecurityError: If URL is not allowed """ try: parsed_url = urlparse(url) - + # Only allow HTTP and HTTPS if parsed_url.scheme not in ["http", "https"]: raise SecurityError("Only HTTP and HTTPS URLs are allowed") - + # Block localhost and private IP ranges hostname = parsed_url.hostname if not hostname: raise SecurityError("Invalid hostname in URL") - + hostname_lower = hostname.lower() - + # Block localhost variations localhost_patterns = [ - "localhost", "127.0.0.1", "0.0.0.0", "::1", - "metadata.google.internal", "169.254.169.254" + "localhost", + "127.0.0.1", + "0.0.0.0", + "::1", + "metadata.google.internal", + "169.254.169.254", ] - + if hostname_lower in localhost_patterns: raise SecurityError("Access to localhost/internal URLs not allowed") - + # Block private IP ranges (basic check) if hostname_lower.startswith(("10.", "172.", "192.168.")): raise SecurityError("Access to private IP ranges not allowed") - + # Check for suspicious patterns suspicious_patterns = [ - "file://", "ftp://", "javascript:", "data:", - "gopher://", "ldap://", "dict://" + "file://", + "ftp://", + "javascript:", + "data:", + "gopher://", + "ldap://", + "dict://", ] - + url_lower = url.lower() for pattern in suspicious_patterns: if pattern in url_lower: raise SecurityError(f"URL contains suspicious pattern: {pattern}") - + # Length check if len(url) > 2000: raise SecurityError("URL too long") - + except SecurityError: raise except Exception as e: @@ -379,6 +406,7 @@ def _validate_url_security(url: str) -> None: # Additional HTTP utility tools + @Tool( name="Web.ParseJSON", description="Parse and validate JSON data from text", @@ -386,39 +414,39 @@ def _validate_url_security(url: str) -> None: "json_text": { "type": "string", "description": "JSON text to parse", - "maxLength": 100000 + "maxLength": 100000, }, "validate_schema": { "type": "boolean", "description": "Whether to validate against a schema", - "default": False - } + "default": False, + }, }, requires_auth=False, - timeout_seconds=10 + timeout_seconds=10, ) def parse_json_tool(json_text: str, validate_schema: bool = False) -> Dict[str, Any]: """ Parse and validate JSON data from text. - + Args: json_text: JSON text to parse validate_schema: Whether to validate schema - + Returns: Parsed JSON data with validation results """ try: # Parse JSON parsed_data = json.loads(json_text) - + result = { "success": True, "data": parsed_data, "type": type(parsed_data).__name__, - "size_bytes": len(json_text) + "size_bytes": len(json_text), } - + # Basic structure analysis if isinstance(parsed_data, dict): result["keys"] = list(parsed_data.keys()) @@ -427,34 +455,38 @@ def parse_json_tool(json_text: str, validate_schema: bool = False) -> Dict[str, result["item_count"] = len(parsed_data) if parsed_data and isinstance(parsed_data[0], dict): result["sample_keys"] = list(parsed_data[0].keys()) - - logger.debug("JSON parsed successfully", - type=result["type"], - size_bytes=result["size_bytes"]) - + + logger.debug( + "JSON parsed successfully", + type=result["type"], + size_bytes=result["size_bytes"], + ) + return result - + except json.JSONDecodeError as e: - logger.warning("JSON parsing failed", - error=str(e), - line=getattr(e, 'lineno', None), - column=getattr(e, 'colno', None)) - + logger.warning( + "JSON parsing failed", + error=str(e), + line=getattr(e, "lineno", None), + column=getattr(e, "colno", None), + ) + return { "success": False, "error": f"JSON parsing failed: {str(e)}", - "line": getattr(e, 'lineno', None), - "column": getattr(e, 'colno', None), - "size_bytes": len(json_text) + "line": getattr(e, "lineno", None), + "column": getattr(e, "colno", None), + "size_bytes": len(json_text), } - + except Exception as e: logger.error("Unexpected error parsing JSON", error=str(e)) - + return { "success": False, "error": f"Unexpected error: {str(e)}", - "size_bytes": len(json_text) + "size_bytes": len(json_text), } @@ -465,42 +497,42 @@ def parse_json_tool(json_text: str, validate_schema: bool = False) -> Dict[str, "text": { "type": "string", "description": "Text content to extract URLs from", - "maxLength": 50000 + "maxLength": 50000, }, "schemes": { "type": "array", "description": "URL schemes to extract", "items": {"type": "string"}, - "default": ["http", "https"] - } + "default": ["http", "https"], + }, }, requires_auth=False, - timeout_seconds=10 + timeout_seconds=10, ) def extract_urls_tool(text: str, schemes: List[str] = None) -> Dict[str, Any]: """ Extract URLs from text content. - + Args: text: Text content to extract URLs from schemes: URL schemes to look for - + Returns: Extracted URLs with metadata """ import re - + if schemes is None: schemes = ["http", "https"] - + try: # Create regex pattern for URL extraction scheme_pattern = "|".join(re.escape(scheme) for scheme in schemes) url_pattern = rf'\b(?:{scheme_pattern})://[^\s<>"{{}}|\\^`\[\]]*' - + # Find all URLs urls = re.findall(url_pattern, text, re.IGNORECASE) - + # Remove duplicates while preserving order unique_urls = [] seen = set() @@ -508,7 +540,7 @@ def extract_urls_tool(text: str, schemes: List[str] = None) -> Dict[str, Any]: if url not in seen: unique_urls.append(url) seen.add(url) - + # Analyze URLs url_analysis = [] for url in unique_urls: @@ -520,37 +552,42 @@ def extract_urls_tool(text: str, schemes: List[str] = None) -> Dict[str, Any]: "hostname": parsed.hostname, "path": parsed.path, "has_query": bool(parsed.query), - "has_fragment": bool(parsed.fragment) + "has_fragment": bool(parsed.fragment), } url_analysis.append(analysis) except Exception as e: logger.warning("Failed to parse extracted URL", url=url, error=str(e)) - url_analysis.append({ - "url": url, - "error": f"Parse error: {str(e)}" - }) - + url_analysis.append({"url": url, "error": f"Parse error: {str(e)}"}) + result = { "success": True, "urls": unique_urls, "url_count": len(unique_urls), "total_matches": len(urls), - "schemes_found": list(set(analysis.get("scheme") for analysis in url_analysis if "scheme" in analysis)), + "schemes_found": list( + set( + analysis.get("scheme") + for analysis in url_analysis + if "scheme" in analysis + ) + ), "analysis": url_analysis, - "text_length": len(text) + "text_length": len(text), } - - logger.debug("URLs extracted successfully", - url_count=result["url_count"], - text_length=len(text)) - + + logger.debug( + "URLs extracted successfully", + url_count=result["url_count"], + text_length=len(text), + ) + return result - + except Exception as e: logger.error("URL extraction failed", error=str(e)) - + return { "success": False, "error": f"URL extraction failed: {str(e)}", - "text_length": len(text) - } \ No newline at end of file + "text_length": len(text), + } diff --git a/src/tools/connectors/sql.py b/src/tools/connectors/sql.py index ef11dcf..34cd295 100644 --- a/src/tools/connectors/sql.py +++ b/src/tools/connectors/sql.py @@ -18,11 +18,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from core.errors import DatabaseError, SecurityError, InvalidSQLError from db.connection import DatabaseManager from tools.decorators import Tool @@ -34,29 +35,29 @@ class SQLQueryTool: """ SQL query tool implementation with security validation and connection management. - + Provides secure read-only database access through validated SQL queries. """ - + def __init__(self, database_manager: DatabaseManager): """ Initialize SQL query tool. - + Args: database_manager: DatabaseManager instance for query execution """ self.database_manager = database_manager - + async def execute_query(self, statement: str) -> Dict[str, Any]: """ Execute a validated SQL query and return structured results. - + Args: statement: SQL SELECT statement to execute - + Returns: Dictionary containing query results and metadata - + Raises: SecurityError: If statement violates security rules InvalidSQLError: If statement has syntax errors @@ -64,11 +65,11 @@ async def execute_query(self, statement: str) -> Dict[str, Any]: """ query_id = f"query_{int(time.time() * 1000)}" start_time = time.time() - + try: # Execute query through database manager (includes validation) result = await self.database_manager.execute_query(statement) - + # Format response response = { "query_id": query_id, @@ -76,34 +77,42 @@ async def execute_query(self, statement: str) -> Dict[str, Any]: "row_count": result.row_count, "columns": result.columns, "execution_time_ms": result.execution_time_ms, - "statement": statement[:100] + "..." if len(statement) > 100 else statement, - "status": "success" + "statement": ( + statement[:100] + "..." if len(statement) > 100 else statement + ), + "status": "success", } - - logger.info("SQL query executed successfully", - query_id=query_id, - row_count=result.row_count, - execution_time_ms=result.execution_time_ms) - + + logger.info( + "SQL query executed successfully", + query_id=query_id, + row_count=result.row_count, + execution_time_ms=result.execution_time_ms, + ) + return response - + except (SecurityError, InvalidSQLError, DatabaseError) as e: execution_time = (time.time() - start_time) * 1000 - - logger.error("SQL query failed", - query_id=query_id, - error=str(e), - error_type=type(e).__name__, - execution_time_ms=execution_time) - + + logger.error( + "SQL query failed", + query_id=query_id, + error=str(e), + error_type=type(e).__name__, + execution_time_ms=execution_time, + ) + # Return error response return { "query_id": query_id, "error": str(e), "error_type": type(e).__name__, "execution_time_ms": execution_time, - "statement": statement[:100] + "..." if len(statement) > 100 else statement, - "status": "failed" + "statement": ( + statement[:100] + "..." if len(statement) > 100 else statement + ), + "status": "failed", } @@ -114,7 +123,7 @@ async def execute_query(self, statement: str) -> Dict[str, Any]: def initialize_sql_tool(database_manager: DatabaseManager) -> None: """ Initialize the global SQL tool instance. - + Args: database_manager: DatabaseManager instance to use """ @@ -131,27 +140,27 @@ def initialize_sql_tool(database_manager: DatabaseManager) -> None: "type": "string", "description": "SQL SELECT statement to execute. Must start with SELECT and cannot contain data modification operations (INSERT, UPDATE, DELETE, etc.). Example: 'SELECT * FROM companies WHERE sector = \"Technology\"'", "minLength": 10, - "maxLength": 1000 + "maxLength": 1000, } }, requires_auth=False, timeout_seconds=30, - version="1.0.0" + version="1.0.0", ) async def sql_query_readonly(statement: str) -> Dict[str, Any]: """ Execute a read-only SQL query on the finance database. - + This tool provides secure access to the financial database using validated SELECT statements. It prevents data modification and enforces security constraints to ensure safe database access. - + Args: statement: SQL SELECT statement to execute - + Returns: Dictionary containing query results, metadata, and execution statistics - + Example: statement = "SELECT name, revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025" result = await sql_query_readonly(statement) @@ -167,9 +176,9 @@ async def sql_query_readonly(statement: str) -> Dict[str, Any]: return { "error": "SQL tool not initialized", "status": "failed", - "execution_time_ms": 0 + "execution_time_ms": 0, } - + return await _sql_tool_instance.execute_query(statement) @@ -179,12 +188,12 @@ async def sql_query_readonly(statement: str) -> Dict[str, Any]: parameters={}, requires_auth=False, timeout_seconds=10, - version="1.0.0" + version="1.0.0", ) async def sql_get_schema() -> Dict[str, Any]: """ Get database schema information for query construction assistance. - + Returns: Dictionary containing database schema information """ @@ -192,9 +201,9 @@ async def sql_get_schema() -> Dict[str, Any]: return { "error": "SQL tool not initialized", "status": "failed", - "execution_time_ms": 0 + "execution_time_ms": 0, } - + try: # Get table information tables_query = """ @@ -203,63 +212,64 @@ async def sql_get_schema() -> Dict[str, Any]: WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name """ - + tables_result = await _sql_tool_instance.execute_query(tables_query) - + schema_info = { "tables": [], "total_tables": len(tables_result["rows"]), "database_type": "SQLite", - "status": "success" + "status": "success", } - + # Get column information for each table for table_row in tables_result["rows"]: table_name = table_row["table_name"] - + # Validate table name before using in PRAGMA if not _sql_tool_instance.database_manager._is_valid_table_name(table_name): - logger.warning("Invalid table name in schema query", table_name=table_name) + logger.warning( + "Invalid table name in schema query", table_name=table_name + ) continue - + # Use safe table name in PRAGMA (cannot be parameterized) # PRAGMA queries are read-only and table name is validated columns_query = f'PRAGMA table_info("{table_name}")' - + try: columns_result = await _sql_tool_instance.execute_query(columns_query) - - table_info = { - "name": table_name, - "columns": [] - } - + + table_info = {"name": table_name, "columns": []} + for col_row in columns_result["rows"]: column_info = { "name": col_row["name"], "type": col_row["type"], "nullable": not col_row["notnull"], - "primary_key": bool(col_row["pk"]) + "primary_key": bool(col_row["pk"]), } table_info["columns"].append(column_info) - + schema_info["tables"].append(table_info) - + except Exception as e: - logger.error("Failed to get column info for table", - table_name=table_name, - error=str(e)) + logger.error( + "Failed to get column info for table", + table_name=table_name, + error=str(e), + ) # Continue with other tables continue - + return schema_info - + except Exception as e: logger.error("Failed to get schema information", error=str(e)) return { "error": f"Failed to get schema: {e}", "status": "failed", - "execution_time_ms": 0 + "execution_time_ms": 0, } @@ -269,65 +279,65 @@ async def sql_get_schema() -> Dict[str, Any]: parameters={}, requires_auth=False, timeout_seconds=5, - version="1.0.0" + version="1.0.0", ) async def sql_get_sample_queries() -> Dict[str, Any]: """ Get sample SQL queries that can be used to explore the financial database. - + Returns: Dictionary containing sample queries with descriptions """ sample_queries = [ { "description": "Get all companies in the Technology sector", - "query": "SELECT * FROM companies WHERE sector = 'Technology'" + "query": "SELECT * FROM companies WHERE sector = 'Technology'", }, { "description": "Get total revenue by company for 2024", - "query": "SELECT c.name, SUM(f.revenue) as total_revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2024 GROUP BY c.id, c.name ORDER BY total_revenue DESC" + "query": "SELECT c.name, SUM(f.revenue) as total_revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2024 GROUP BY c.id, c.name ORDER BY total_revenue DESC", }, { "description": "Get Q1 2025 financial results", - "query": "SELECT c.name, f.revenue, f.profit, f.expenses FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.quarter = 'Q1' AND f.year = 2025 ORDER BY f.revenue DESC" + "query": "SELECT c.name, f.revenue, f.profit, f.expenses FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.quarter = 'Q1' AND f.year = 2025 ORDER BY f.revenue DESC", }, { "description": "Get company count by sector", - "query": "SELECT sector, COUNT(*) as company_count FROM companies GROUP BY sector ORDER BY company_count DESC" + "query": "SELECT sector, COUNT(*) as company_count FROM companies GROUP BY sector ORDER BY company_count DESC", }, { "description": "Get TechCorp's quarterly performance over time", - "query": "SELECT c.name, f.quarter, f.year, f.revenue, f.profit FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE c.symbol = 'TECH' ORDER BY f.year DESC, f.quarter DESC" + "query": "SELECT c.name, f.quarter, f.year, f.revenue, f.profit FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE c.symbol = 'TECH' ORDER BY f.year DESC, f.quarter DESC", }, { "description": "Get average metrics for 2024", - "query": "SELECT AVG(revenue) as avg_revenue, AVG(profit) as avg_profit, AVG(expenses) as avg_expenses FROM financial_records WHERE year = 2024" + "query": "SELECT AVG(revenue) as avg_revenue, AVG(profit) as avg_profit, AVG(expenses) as avg_expenses FROM financial_records WHERE year = 2024", }, { "description": "Get top companies by market cap with latest revenue", - "query": "SELECT c.name, c.market_cap, f.revenue as q1_2025_revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025 AND f.quarter = 'Q1' ORDER BY c.market_cap DESC" - } + "query": "SELECT c.name, c.market_cap, f.revenue as q1_2025_revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025 AND f.quarter = 'Q1' ORDER BY c.market_cap DESC", + }, ] - + return { "sample_queries": sample_queries, "total_queries": len(sample_queries), "status": "success", - "execution_time_ms": 0 + "execution_time_ms": 0, } def get_sql_tool() -> SQLQueryTool: """ Get the global SQL tool instance. - + Returns: Global SQLQueryTool instance - + Raises: RuntimeError: If SQL tool is not initialized """ if _sql_tool_instance is None: raise RuntimeError("SQL tool not initialized. Call initialize_sql_tool first.") - - return _sql_tool_instance \ No newline at end of file + + return _sql_tool_instance diff --git a/src/tools/decorators.py b/src/tools/decorators.py index 98885cc..f6516b5 100644 --- a/src/tools/decorators.py +++ b/src/tools/decorators.py @@ -18,22 +18,23 @@ ToolValidationError, ToolExecutionError, ValidationError, - InvalidArgumentsError + InvalidArgumentsError, ) except ImportError: # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from core.errors import ( ToolValidationError, ToolExecutionError, ValidationError, - InvalidArgumentsError + InvalidArgumentsError, ) @@ -43,6 +44,7 @@ @dataclass class ToolDefinition: """Represents a tool definition with metadata and validation.""" + name: str description: str parameters: Dict[str, Any] @@ -56,98 +58,102 @@ class ToolDefinition: class ToolRegistry: """ Central registry for managing tool definitions and schemas. - + Maintains a consistent state of registered tools and provides schema export capabilities for Claude integration. """ - + def __init__(self): """Initialize empty tool registry.""" self.tools: Dict[str, ToolDefinition] = {} self.schemas: Dict[str, Dict[str, Any]] = {} self.last_updated = time.time() - + def register_tool(self, tool_definition: ToolDefinition) -> None: """ Register a tool definition in the registry. - + Args: tool_definition: Tool definition to register - + Raises: ToolValidationError: If tool definition is invalid """ tool_name = tool_definition.name - + # Validate tool definition self._validate_tool_definition(tool_definition) - + # Check for duplicates (allow version updates) if tool_name in self.tools: existing_version = self.tools[tool_name].version new_version = tool_definition.version - + if not self._is_newer_version(new_version, existing_version): - logger.warning("Tool registration skipped - same or older version", - tool_name=tool_name, - existing_version=existing_version, - new_version=new_version) + logger.warning( + "Tool registration skipped - same or older version", + tool_name=tool_name, + existing_version=existing_version, + new_version=new_version, + ) return - + # Register tool self.tools[tool_name] = tool_definition self.schemas[tool_name] = self._extract_schema(tool_definition) self.last_updated = time.time() - - logger.info("Tool registered successfully", - tool_name=tool_name, - version=tool_definition.version) - + + logger.info( + "Tool registered successfully", + tool_name=tool_name, + version=tool_definition.version, + ) + def get_tool(self, tool_name: str) -> ToolDefinition: """ Get a tool definition by name. - + Args: tool_name: Name of the tool to retrieve - + Returns: ToolDefinition for the requested tool - + Raises: ToolNotFoundError: If tool is not found """ if tool_name not in self.tools: raise ToolValidationError(f"Tool not found: {tool_name}") - + return self.tools[tool_name] - + def export_all_schemas(self) -> List[Dict[str, Any]]: """ Export all tool schemas in Claude-compatible format. - + Returns: List of tool schemas for Claude API """ schema_list = [] for tool_name in self.tools: schema_list.append(self.schemas[tool_name]) - + logger.debug("Exported tool schemas", count=len(schema_list)) return schema_list - + def list_tools(self) -> List[str]: """ Get list of all registered tool names. - + Returns: List of tool names """ return list(self.tools.keys()) - + def get_tool_info(self) -> Dict[str, Any]: """ Get registry information and statistics. - + Returns: Dictionary containing registry metadata """ @@ -155,47 +161,51 @@ def get_tool_info(self) -> Dict[str, Any]: "total_tools": len(self.tools), "tool_names": list(self.tools.keys()), "last_updated": self.last_updated, - "schema_count": len(self.schemas) + "schema_count": len(self.schemas), } - + def _validate_tool_definition(self, tool_definition: ToolDefinition) -> None: """ Validate tool definition structure and constraints. - + Args: tool_definition: Tool definition to validate - + Raises: ToolValidationError: If validation fails """ # Validate name if not tool_definition.name or not tool_definition.name.strip(): raise ToolValidationError("Tool name cannot be empty") - + if not self._follows_naming_convention(tool_definition.name): - raise ToolValidationError(f"Tool name '{tool_definition.name}' does not follow naming convention") - + raise ToolValidationError( + f"Tool name '{tool_definition.name}' does not follow naming convention" + ) + # Validate description if not tool_definition.description or not tool_definition.description.strip(): raise ToolValidationError("Tool description cannot be empty") - + # Validate parameters schema if not isinstance(tool_definition.parameters, dict): raise ToolValidationError("Tool parameters must be a dictionary") - + # Validate function is callable if not callable(tool_definition.function): raise ToolValidationError("Tool function must be callable") - - logger.debug("Tool definition validation passed", tool_name=tool_definition.name) - + + logger.debug( + "Tool definition validation passed", tool_name=tool_definition.name + ) + def _follows_naming_convention(self, name: str) -> bool: """ Check if tool name follows the expected naming convention. - + Args: name: Tool name to check - + Returns: True if name follows convention, False otherwise """ @@ -203,46 +213,48 @@ def _follows_naming_convention(self, name: str) -> bool: # Anthropic API requires names to match ^[a-zA-Z0-9_-]{1,64}$ if "_" not in name: return False - + parts = name.split("_", 1) # Split on first underscore only if len(parts) != 2: return False - + category, action = parts - + # Validation for Anthropic API compatibility import re - pattern = r'^[a-zA-Z0-9_-]{1,64}$' - + + pattern = r"^[a-zA-Z0-9_-]{1,64}$" + return bool(re.match(pattern, name) and len(name) <= 64) - + def _is_newer_version(self, new_version: str, existing_version: str) -> bool: """ Compare version strings to determine if new version is newer. - + Args: new_version: New version string existing_version: Existing version string - + Returns: True if new version is newer """ + def version_tuple(version: str) -> tuple: """Convert version string to comparable tuple.""" try: - return tuple(map(int, version.split('.'))) + return tuple(map(int, version.split("."))) except ValueError: return (0, 0, 0) # Default for invalid versions - + return version_tuple(new_version) > version_tuple(existing_version) - + def _extract_schema(self, tool_definition: ToolDefinition) -> Dict[str, Any]: """ Extract Claude-compatible schema from tool definition. - + Args: tool_definition: Tool definition to extract schema from - + Returns: Claude-compatible tool schema """ @@ -253,19 +265,19 @@ def _extract_schema(self, tool_definition: ToolDefinition) -> Dict[str, Any]: "input_schema": { "type": "object", "properties": tool_definition.parameters, - "required": self._extract_required_params(tool_definition.parameters) - } + "required": self._extract_required_params(tool_definition.parameters), + }, } - + return schema - + def _extract_required_params(self, parameters: Dict[str, Any]) -> List[str]: """ Extract required parameter names from parameters schema. - + Args: parameters: Parameters schema dictionary - + Returns: List of required parameter names """ @@ -275,7 +287,7 @@ def _extract_required_params(self, parameters: Dict[str, Any]) -> List[str]: # Check if parameter has no default and is not marked as optional if "default" not in param_schema and param_schema.get("required", True): required.append(param_name) - + return required @@ -283,15 +295,17 @@ def _extract_required_params(self, parameters: Dict[str, Any]) -> List[str]: _tool_registry = ToolRegistry() -def Tool(name: str, - description: str, - parameters: Dict[str, Any], - requires_auth: bool = False, - timeout_seconds: int = 30, - version: str = "1.0.0") -> Callable: +def Tool( + name: str, + description: str, + parameters: Dict[str, Any], + requires_auth: bool = False, + timeout_seconds: int = 30, + version: str = "1.0.0", +) -> Callable: """ Decorator for defining FACT system tools. - + Args: name: Tool name following Category.Action format description: Clear description of tool functionality @@ -299,10 +313,10 @@ def Tool(name: str, requires_auth: Whether tool requires user authorization timeout_seconds: Execution timeout in seconds version: Tool version string - + Returns: Decorated function with tool metadata and validation - + Example: @Tool( name="SQL.QueryReadonly", @@ -318,175 +332,197 @@ def execute_sql_query(statement: str) -> Dict[str, Any]: # Tool implementation pass """ + def decorator(tool_function: Callable) -> Callable: - import asyncio - import inspect - - # Determine if we need async or sync wrapper - is_async_tool = asyncio.iscoroutinefunction(tool_function) - - # Create the appropriate wrapper function - if is_async_tool: - @functools.wraps(tool_function) - async def async_wrapped_tool(*args, **kwargs) -> Dict[str, Any]: - """Wrapped async tool function with validation and error handling.""" - start_time = time.time() - - try: - # Validate input parameters against schema - if args or kwargs: - # Convert positional args to kwargs based on function signature - sig = inspect.signature(tool_function) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - validated_kwargs = dict(bound_args.arguments) - - # Validate against parameter schema - validate_tool_parameters(validated_kwargs, parameters) - else: - validated_kwargs = {} - - # Execute async function - result = await tool_function(**validated_kwargs) - - # Validate output format - if not isinstance(result, dict): - result = {"result": result} - - # Ensure result is JSON serializable - json.dumps(result) - - execution_time = (time.time() - start_time) * 1000 - result["execution_time_ms"] = execution_time - result["status"] = "success" - - logger.debug("Tool executed successfully", - tool_name=name, - execution_time_ms=execution_time) - - return result - - except ValidationError as e: - execution_time = (time.time() - start_time) * 1000 - logger.error("Tool parameter validation failed", - tool_name=name, - error=str(e), - execution_time_ms=execution_time) - raise ToolValidationError(f"Parameter validation failed: {e}") - - except Exception as e: - execution_time = (time.time() - start_time) * 1000 - logger.error("Tool execution failed", - tool_name=name, - error=str(e), - execution_time_ms=execution_time) - raise ToolExecutionError(f"Tool execution failed: {e}") - - wrapped_tool = async_wrapped_tool - - else: - @functools.wraps(tool_function) - def sync_wrapped_tool(*args, **kwargs) -> Dict[str, Any]: - """Wrapped sync tool function with validation and error handling.""" - start_time = time.time() - - try: - # Validate input parameters against schema - if args or kwargs: - # Convert positional args to kwargs based on function signature - sig = inspect.signature(tool_function) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - validated_kwargs = dict(bound_args.arguments) - - # Validate against parameter schema - validate_tool_parameters(validated_kwargs, parameters) - else: - validated_kwargs = {} - - # Execute sync function - result = tool_function(**validated_kwargs) - - # Validate output format - if not isinstance(result, dict): - result = {"result": result} - - # Ensure result is JSON serializable - json.dumps(result) - - execution_time = (time.time() - start_time) * 1000 - result["execution_time_ms"] = execution_time - result["status"] = "success" - - logger.debug("Tool executed successfully", - tool_name=name, - execution_time_ms=execution_time) - - return result - - except ValidationError as e: - execution_time = (time.time() - start_time) * 1000 - logger.error("Tool parameter validation failed", - tool_name=name, - error=str(e), - execution_time_ms=execution_time) - raise ToolValidationError(f"Parameter validation failed: {e}") - - except Exception as e: - execution_time = (time.time() - start_time) * 1000 - logger.error("Tool execution failed", - tool_name=name, - error=str(e), - execution_time_ms=execution_time) - raise ToolExecutionError(f"Tool execution failed: {e}") - - wrapped_tool = sync_wrapped_tool - - # Create tool definition - tool_definition = ToolDefinition( - name=name, - description=description, - parameters=parameters, - function=wrapped_tool, - created_at=time.time(), - version=version, - requires_auth=requires_auth, - timeout_seconds=timeout_seconds - ) - - # Attach metadata to function - wrapped_tool.tool_definition = tool_definition - - # Register tool automatically - _tool_registry.register_tool(tool_definition) - - return wrapped_tool + import asyncio + import inspect + + # Determine if we need async or sync wrapper + is_async_tool = asyncio.iscoroutinefunction(tool_function) + + # Create the appropriate wrapper function + if is_async_tool: + + @functools.wraps(tool_function) + async def async_wrapped_tool(*args, **kwargs) -> Dict[str, Any]: + """Wrapped async tool function with validation and error handling.""" + start_time = time.time() + + try: + # Validate input parameters against schema + if args or kwargs: + # Convert positional args to kwargs based on function signature + sig = inspect.signature(tool_function) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + validated_kwargs = dict(bound_args.arguments) + + # Validate against parameter schema + validate_tool_parameters(validated_kwargs, parameters) + else: + validated_kwargs = {} + + # Execute async function + result = await tool_function(**validated_kwargs) + + # Validate output format + if not isinstance(result, dict): + result = {"result": result} + + # Ensure result is JSON serializable + json.dumps(result) + + execution_time = (time.time() - start_time) * 1000 + result["execution_time_ms"] = execution_time + result["status"] = "success" + + logger.debug( + "Tool executed successfully", + tool_name=name, + execution_time_ms=execution_time, + ) + + return result + + except ValidationError as e: + execution_time = (time.time() - start_time) * 1000 + logger.error( + "Tool parameter validation failed", + tool_name=name, + error=str(e), + execution_time_ms=execution_time, + ) + raise ToolValidationError(f"Parameter validation failed: {e}") + + except Exception as e: + execution_time = (time.time() - start_time) * 1000 + logger.error( + "Tool execution failed", + tool_name=name, + error=str(e), + execution_time_ms=execution_time, + ) + raise ToolExecutionError(f"Tool execution failed: {e}") + + wrapped_tool = async_wrapped_tool + + else: + + @functools.wraps(tool_function) + def sync_wrapped_tool(*args, **kwargs) -> Dict[str, Any]: + """Wrapped sync tool function with validation and error handling.""" + start_time = time.time() + + try: + # Validate input parameters against schema + if args or kwargs: + # Convert positional args to kwargs based on function signature + sig = inspect.signature(tool_function) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + validated_kwargs = dict(bound_args.arguments) + + # Validate against parameter schema + validate_tool_parameters(validated_kwargs, parameters) + else: + validated_kwargs = {} + + # Execute sync function + result = tool_function(**validated_kwargs) + + # Validate output format + if not isinstance(result, dict): + result = {"result": result} + + # Ensure result is JSON serializable + json.dumps(result) + + execution_time = (time.time() - start_time) * 1000 + result["execution_time_ms"] = execution_time + result["status"] = "success" + + logger.debug( + "Tool executed successfully", + tool_name=name, + execution_time_ms=execution_time, + ) + + return result + + except ValidationError as e: + execution_time = (time.time() - start_time) * 1000 + logger.error( + "Tool parameter validation failed", + tool_name=name, + error=str(e), + execution_time_ms=execution_time, + ) + raise ToolValidationError(f"Parameter validation failed: {e}") + + except Exception as e: + execution_time = (time.time() - start_time) * 1000 + logger.error( + "Tool execution failed", + tool_name=name, + error=str(e), + execution_time_ms=execution_time, + ) + raise ToolExecutionError(f"Tool execution failed: {e}") + + wrapped_tool = sync_wrapped_tool + + # Create tool definition + tool_definition = ToolDefinition( + name=name, + description=description, + parameters=parameters, + function=wrapped_tool, + created_at=time.time(), + version=version, + requires_auth=requires_auth, + timeout_seconds=timeout_seconds, + ) + + # Attach metadata to function + wrapped_tool.tool_definition = tool_definition + + # Register tool automatically + _tool_registry.register_tool(tool_definition) + + return wrapped_tool + return decorator -def validate_tool_parameters(parameters: Dict[str, Any], schema: Dict[str, Any]) -> None: +def validate_tool_parameters( + parameters: Dict[str, Any], schema: Dict[str, Any] +) -> None: """ Validate tool parameters against schema definition. - + Args: parameters: Parameter values to validate schema: JSON schema for parameters - + Raises: ValidationError: If validation fails """ errors = [] - + # Check required parameters required_params = [] for param_name, param_schema in schema.items(): - if isinstance(param_schema, dict) and param_schema.get("required", True) and "default" not in param_schema: + if ( + isinstance(param_schema, dict) + and param_schema.get("required", True) + and "default" not in param_schema + ): required_params.append(param_name) - + for required_param in required_params: if required_param not in parameters: errors.append(f"Missing required parameter: {required_param}") - + # Validate parameter types and constraints for param_name, param_value in parameters.items(): if param_name in schema: @@ -495,30 +531,39 @@ def validate_tool_parameters(parameters: Dict[str, Any], schema: Dict[str, Any]) # Type validation expected_type = param_schema.get("type") if expected_type and not _validate_type(param_value, expected_type): - errors.append(f"Invalid type for {param_name}: expected {expected_type}") - + errors.append( + f"Invalid type for {param_name}: expected {expected_type}" + ) + # String constraints if expected_type == "string" and isinstance(param_value, str): min_length = param_schema.get("minLength") max_length = param_schema.get("maxLength") pattern = param_schema.get("pattern") - + if min_length and len(param_value) < min_length: - errors.append(f"{param_name} is too short (minimum {min_length} characters)") - + errors.append( + f"{param_name} is too short (minimum {min_length} characters)" + ) + if max_length and len(param_value) > max_length: - errors.append(f"{param_name} is too long (maximum {max_length} characters)") - + errors.append( + f"{param_name} is too long (maximum {max_length} characters)" + ) + if pattern: import re + if not re.match(pattern, param_value): - errors.append(f"{param_name} does not match required pattern") - + errors.append( + f"{param_name} does not match required pattern" + ) + # Enum validation enum_values = param_schema.get("enum") if enum_values and param_value not in enum_values: errors.append(f"{param_name} must be one of: {enum_values}") - + if errors: raise ValidationError("; ".join(errors)) @@ -526,11 +571,11 @@ def validate_tool_parameters(parameters: Dict[str, Any], schema: Dict[str, Any]) def _validate_type(value: Any, expected_type: str) -> bool: """ Validate value type against expected JSON schema type. - + Args: value: Value to validate expected_type: Expected JSON schema type - + Returns: True if type matches, False otherwise """ @@ -540,21 +585,21 @@ def _validate_type(value: Any, expected_type: str) -> bool: "integer": int, "boolean": bool, "object": dict, - "array": list + "array": list, } - + expected_python_type = type_mapping.get(expected_type) if expected_python_type: return isinstance(value, expected_python_type) - + return False def get_tool_registry() -> ToolRegistry: """ Get the global tool registry instance. - + Returns: Global ToolRegistry instance """ - return _tool_registry \ No newline at end of file + return _tool_registry diff --git a/src/tools/executor.py b/src/tools/executor.py index 5df1cd7..be08b6e 100644 --- a/src/tools/executor.py +++ b/src/tools/executor.py @@ -1,4 +1,3 @@ - """ FACT System Tool Execution Engine @@ -22,7 +21,7 @@ ToolNotFoundError, UnauthorizedError, SecurityError, - FinalRetryError + FinalRetryError, ) from .decorators import get_tool_registry, ToolDefinition from .validation import ParameterValidator, SecurityValidator @@ -33,18 +32,19 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from core.errors import ( ToolExecutionError, ToolValidationError, ToolNotFoundError, UnauthorizedError, SecurityError, - FinalRetryError + FinalRetryError, ) from tools.decorators import get_tool_registry, ToolDefinition from tools.validation import ParameterValidator, SecurityValidator @@ -60,13 +60,14 @@ @dataclass class ToolCall: """Represents a tool call request from an LLM.""" + id: str name: str arguments: Dict[str, Any] user_id: Optional[str] = None session_id: Optional[str] = None timestamp: float = None - + def __post_init__(self): if self.timestamp is None: self.timestamp = time.time() @@ -75,6 +76,7 @@ def __post_init__(self): @dataclass class ToolResult: """Represents the result of a tool execution.""" + call_id: str tool_name: str success: bool @@ -83,7 +85,7 @@ class ToolResult: execution_time_ms: float = 0 status_code: int = 200 metadata: Optional[Dict[str, Any]] = None - + def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary format.""" result = { @@ -91,56 +93,56 @@ def to_dict(self) -> Dict[str, Any]: "tool_name": self.tool_name, "success": self.success, "execution_time_ms": self.execution_time_ms, - "status_code": self.status_code + "status_code": self.status_code, } - + if self.success and self.data: result["data"] = self.data elif not self.success and self.error: result["error"] = self.error - + if self.metadata: result["metadata"] = self.metadata - + return result class RateLimiter: """Simple token bucket rate limiter for tool execution.""" - + def __init__(self, max_calls_per_minute: int = 60): """ Initialize rate limiter. - + Args: max_calls_per_minute: Maximum calls allowed per minute """ self.max_calls = max_calls_per_minute self.calls = [] - + def can_execute(self, user_id: Optional[str] = None) -> bool: """ Check if execution is allowed based on rate limits. - + Args: user_id: Optional user identifier for per-user limits - + Returns: True if execution is allowed """ current_time = time.time() - + # Remove calls older than 1 minute cutoff_time = current_time - 60 self.calls = [call_time for call_time in self.calls if call_time > cutoff_time] - + # Check if under limit return len(self.calls) < self.max_calls - + def record_call(self, user_id: Optional[str] = None) -> None: """ Record a tool call for rate limiting. - + Args: user_id: Optional user identifier """ @@ -150,19 +152,21 @@ def record_call(self, user_id: Optional[str] = None) -> None: class ToolExecutor: """ Main tool execution engine that handles LLM tool calls. - + Provides secure tool execution with Arcade.dev integration, parameter validation, authorization checking, and rate limiting. """ - - def __init__(self, - arcade_client: Optional[ArcadeClient] = None, - enable_rate_limiting: bool = True, - max_calls_per_minute: int = 60, - default_timeout: int = 30): + + def __init__( + self, + arcade_client: Optional[ArcadeClient] = None, + enable_rate_limiting: bool = True, + max_calls_per_minute: int = 60, + default_timeout: int = 30, + ): """ Initialize tool executor. - + Args: arcade_client: Arcade.dev client for tool execution enable_rate_limiting: Whether to enable rate limiting @@ -175,55 +179,63 @@ def __init__(self, self.security_validator = SecurityValidator() self.auth_manager = AuthorizationManager() self.metrics_collector = MetricsCollector() - + # Rate limiting self.enable_rate_limiting = enable_rate_limiting - self.rate_limiter = RateLimiter(max_calls_per_minute) if enable_rate_limiting else None - + self.rate_limiter = ( + RateLimiter(max_calls_per_minute) if enable_rate_limiting else None + ) + self.default_timeout = default_timeout - - logger.info("ToolExecutor initialized", - rate_limiting=enable_rate_limiting, - max_calls_per_minute=max_calls_per_minute) - + + logger.info( + "ToolExecutor initialized", + rate_limiting=enable_rate_limiting, + max_calls_per_minute=max_calls_per_minute, + ) + async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult: """ Execute a single tool call with full validation and security. - + Args: tool_call: Tool call to execute - + Returns: ToolResult containing execution results """ start_time = time.time() - + try: # Rate limiting check - if self.enable_rate_limiting and not self.rate_limiter.can_execute(tool_call.user_id): - raise ToolExecutionError("Rate limit exceeded. Too many tool calls per minute.") - + if self.enable_rate_limiting and not self.rate_limiter.can_execute( + tool_call.user_id + ): + raise ToolExecutionError( + "Rate limit exceeded. Too many tool calls per minute." + ) + # Get tool definition tool_definition = self._get_tool_definition(tool_call.name) - + # Security validation await self._validate_security(tool_call, tool_definition) - + # Parameter validation self._validate_parameters(tool_call.arguments, tool_definition) - + # Authorization check await self._check_authorization(tool_call, tool_definition) - + # Record call for rate limiting if self.enable_rate_limiting: self.rate_limiter.record_call(tool_call.user_id) - + # Execute tool result_data = await self._execute_tool(tool_call, tool_definition) - + execution_time = (time.time() - start_time) * 1000 - + # Create successful result result = ToolResult( call_id=tool_call.id, @@ -234,29 +246,29 @@ async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult: metadata={ "user_id": tool_call.user_id, "session_id": tool_call.session_id, - "timestamp": tool_call.timestamp - } + "timestamp": tool_call.timestamp, + }, ) - + # Log successful execution - logger.info("Tool executed successfully", - tool_name=tool_call.name, - call_id=tool_call.id, - execution_time_ms=execution_time, - user_id=tool_call.user_id) - + logger.info( + "Tool executed successfully", + tool_name=tool_call.name, + call_id=tool_call.id, + execution_time_ms=execution_time, + user_id=tool_call.user_id, + ) + # Record metrics self.metrics_collector.record_tool_execution( - tool_name=tool_call.name, - success=True, - execution_time=execution_time + tool_name=tool_call.name, success=True, execution_time=execution_time ) - + return result - + except Exception as e: execution_time = (time.time() - start_time) * 1000 - + # Create error result result = ToolResult( call_id=tool_call.id, @@ -269,50 +281,54 @@ async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult: "user_id": tool_call.user_id, "session_id": tool_call.session_id, "timestamp": tool_call.timestamp, - "error_type": type(e).__name__ - } + "error_type": type(e).__name__, + }, ) - + # Log error - logger.error("Tool execution failed", - tool_name=tool_call.name, - call_id=tool_call.id, - error=str(e), - error_type=type(e).__name__, - execution_time_ms=execution_time, - user_id=tool_call.user_id) - + logger.error( + "Tool execution failed", + tool_name=tool_call.name, + call_id=tool_call.id, + error=str(e), + error_type=type(e).__name__, + execution_time_ms=execution_time, + user_id=tool_call.user_id, + ) + # Record metrics self.metrics_collector.record_tool_execution( tool_name=tool_call.name, success=False, execution_time=execution_time, - error_type=type(e).__name__ + error_type=type(e).__name__, ) - + return result - + async def execute_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResult]: """ Execute multiple tool calls concurrently. - + Args: tool_calls: List of tool calls to execute - + Returns: List of ToolResults in the same order as input """ if not tool_calls: return [] - - logger.info("Executing tool calls batch", - count=len(tool_calls), - tool_names=[call.name for call in tool_calls]) - + + logger.info( + "Executing tool calls batch", + count=len(tool_calls), + tool_names=[call.name for call in tool_calls], + ) + # Execute all tool calls concurrently tasks = [self.execute_tool_call(call) for call in tool_calls] results = await asyncio.gather(*tasks, return_exceptions=True) - + # Convert exceptions to error results processed_results = [] for i, result in enumerate(results): @@ -322,30 +338,30 @@ async def execute_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResul tool_name=tool_calls[i].name, success=False, error=f"Execution failed: {str(result)}", - status_code=500 + status_code=500, ) processed_results.append(error_result) else: processed_results.append(result) - + return processed_results - + def get_available_tools(self) -> List[Dict[str, Any]]: """ Get list of available tools with their schemas. - + Returns: List of tool schemas for LLM consumption """ return self.tool_registry.export_all_schemas() - + def get_tool_info(self, tool_name: str) -> Dict[str, Any]: """ Get detailed information about a specific tool. - + Args: tool_name: Name of the tool - + Returns: Tool information dictionary """ @@ -358,54 +374,61 @@ def get_tool_info(self, tool_name: str) -> Dict[str, Any]: "version": tool_definition.version, "requires_auth": tool_definition.requires_auth, "timeout_seconds": tool_definition.timeout_seconds, - "created_at": tool_definition.created_at + "created_at": tool_definition.created_at, } except Exception as e: raise ToolNotFoundError(f"Tool '{tool_name}' not found: {str(e)}") - + def _get_tool_definition(self, tool_name: str) -> ToolDefinition: """Get tool definition from registry.""" try: return self.tool_registry.get_tool(tool_name) except Exception as e: raise ToolNotFoundError(f"Tool '{tool_name}' not found in registry") - - async def _validate_security(self, tool_call: ToolCall, tool_definition: ToolDefinition) -> None: + + async def _validate_security( + self, tool_call: ToolCall, tool_definition: ToolDefinition + ) -> None: """Validate security constraints for tool call.""" try: # Validate tool name and arguments for security issues await self.security_validator.validate_tool_call( tool_name=tool_call.name, arguments=tool_call.arguments, - user_id=tool_call.user_id + user_id=tool_call.user_id, ) except Exception as e: raise SecurityError(f"Security validation failed: {str(e)}") - - def _validate_parameters(self, arguments: Dict[str, Any], tool_definition: ToolDefinition) -> None: + + def _validate_parameters( + self, arguments: Dict[str, Any], tool_definition: ToolDefinition + ) -> None: """Validate tool parameters against schema.""" try: self.parameter_validator.validate(arguments, tool_definition.parameters) except Exception as e: raise ToolValidationError(f"Parameter validation failed: {str(e)}") - - async def _check_authorization(self, tool_call: ToolCall, tool_definition: ToolDefinition) -> None: + + async def _check_authorization( + self, tool_call: ToolCall, tool_definition: ToolDefinition + ) -> None: """Check user authorization for tool execution.""" if not tool_definition.requires_auth: return - + if not tool_call.user_id: raise UnauthorizedError("User authentication required for this tool") - + try: await self.auth_manager.validate_authorization( - user_id=tool_call.user_id, - tool_name=tool_call.name + user_id=tool_call.user_id, tool_name=tool_call.name ) except Exception as e: raise UnauthorizedError(f"Authorization failed: {str(e)}") - - async def _execute_tool(self, tool_call: ToolCall, tool_definition: ToolDefinition) -> Dict[str, Any]: + + async def _execute_tool( + self, tool_call: ToolCall, tool_definition: ToolDefinition + ) -> Dict[str, Any]: """Execute the actual tool function.""" if self.arcade_client: # Execute via Arcade.dev @@ -413,48 +436,55 @@ async def _execute_tool(self, tool_call: ToolCall, tool_definition: ToolDefiniti else: # Execute locally return await self._execute_locally(tool_call, tool_definition) - - async def _execute_via_arcade(self, tool_call: ToolCall, tool_definition: ToolDefinition) -> Dict[str, Any]: + + async def _execute_via_arcade( + self, tool_call: ToolCall, tool_definition: ToolDefinition + ) -> Dict[str, Any]: """Execute tool via Arcade.dev gateway.""" try: result = await self.arcade_client.execute_tool( tool_name=tool_call.name, arguments=tool_call.arguments, - timeout=tool_definition.timeout_seconds + timeout=tool_definition.timeout_seconds, ) return result except Exception as e: raise ToolExecutionError(f"Arcade execution failed: {str(e)}") - - async def _execute_locally(self, tool_call: ToolCall, tool_definition: ToolDefinition) -> Dict[str, Any]: + + async def _execute_locally( + self, tool_call: ToolCall, tool_definition: ToolDefinition + ) -> Dict[str, Any]: """Execute tool locally with timeout protection.""" try: # Check if function is async import inspect + if inspect.iscoroutinefunction(tool_definition.function): # Execute async function directly result = await asyncio.wait_for( tool_definition.function(**tool_call.arguments), - timeout=tool_definition.timeout_seconds + timeout=tool_definition.timeout_seconds, ) else: # Execute sync function in thread result = await asyncio.wait_for( asyncio.to_thread(tool_definition.function, **tool_call.arguments), - timeout=tool_definition.timeout_seconds + timeout=tool_definition.timeout_seconds, ) - + # Ensure result is a dictionary if not isinstance(result, dict): result = {"result": result} - + return result - + except asyncio.TimeoutError: - raise ToolExecutionError(f"Tool execution timed out after {tool_definition.timeout_seconds} seconds") + raise ToolExecutionError( + f"Tool execution timed out after {tool_definition.timeout_seconds} seconds" + ) except Exception as e: raise ToolExecutionError(f"Local execution failed: {str(e)}") - + def _get_error_status_code(self, error: Exception) -> int: """Get HTTP status code for error type.""" error_codes = { @@ -463,49 +493,48 @@ def _get_error_status_code(self, error: Exception) -> int: UnauthorizedError: 401, SecurityError: 403, ToolExecutionError: 500, - FinalRetryError: 503 + FinalRetryError: 503, } - + return error_codes.get(type(error), 500) # Utility functions for tool execution -def create_tool_call(tool_name: str, - arguments: Dict[str, Any], - call_id: Optional[str] = None, - user_id: Optional[str] = None) -> ToolCall: + +def create_tool_call( + tool_name: str, + arguments: Dict[str, Any], + call_id: Optional[str] = None, + user_id: Optional[str] = None, +) -> ToolCall: """ Create a ToolCall instance with proper validation. - + Args: tool_name: Name of the tool to call arguments: Tool arguments call_id: Optional call identifier user_id: Optional user identifier - + Returns: ToolCall instance """ if call_id is None: import uuid + call_id = str(uuid.uuid4()) - - return ToolCall( - id=call_id, - name=tool_name, - arguments=arguments, - user_id=user_id - ) + + return ToolCall(id=call_id, name=tool_name, arguments=arguments, user_id=user_id) def format_tool_result_for_llm(result: ToolResult) -> Dict[str, Any]: """ Format tool result for LLM consumption. - + Args: result: ToolResult to format - + Returns: Formatted result dictionary """ @@ -514,14 +543,17 @@ def format_tool_result_for_llm(result: ToolResult) -> Dict[str, Any]: "tool_call_id": result.call_id, "name": result.tool_name, } - + if result.success: formatted["content"] = json.dumps(result.data, indent=2) else: - formatted["content"] = json.dumps({ - "error": result.error, - "status": "failed", - "execution_time_ms": result.execution_time_ms - }, indent=2) - - return formatted \ No newline at end of file + formatted["content"] = json.dumps( + { + "error": result.error, + "status": "failed", + "execution_time_ms": result.execution_time_ms, + }, + indent=2, + ) + + return formatted diff --git a/src/tools/validation.py b/src/tools/validation.py index 086b4ab..7f0b7fa 100644 --- a/src/tools/validation.py +++ b/src/tools/validation.py @@ -18,11 +18,12 @@ # Fall back to absolute imports (when run as script) import sys from pathlib import Path + # Add src to path if not already there src_path = str(Path(__file__).parent.parent) if src_path not in sys.path: sys.path.insert(0, src_path) - + from core.errors import ValidationError, SecurityError @@ -32,11 +33,11 @@ class ParameterValidator: """ Validates tool parameters against JSON schema definitions. - + Provides comprehensive validation including type checking, constraint validation, and security sanitization. """ - + def __init__(self): """Initialize parameter validator.""" self.type_validators = { @@ -45,75 +46,84 @@ def __init__(self): "integer": self._validate_integer, "boolean": self._validate_boolean, "object": self._validate_object, - "array": self._validate_array + "array": self._validate_array, } - + def validate(self, parameters: Dict[str, Any], schema: Dict[str, Any]) -> None: """ Validate parameters against schema definition. - + Args: parameters: Parameter values to validate schema: JSON schema for parameters - + Raises: ValidationError: If validation fails """ errors = [] - + try: # Check required parameters required_params = self._extract_required_params(schema) for required_param in required_params: if required_param not in parameters: errors.append(f"Missing required parameter: {required_param}") - + # Validate each parameter for param_name, param_value in parameters.items(): if param_name in schema: param_schema = schema[param_name] - param_errors = self._validate_parameter(param_name, param_value, param_schema) + param_errors = self._validate_parameter( + param_name, param_value, param_schema + ) errors.extend(param_errors) else: # Allow extra parameters but log warning - logger.warning("Unknown parameter provided", - parameter=param_name, - value=str(param_value)[:100]) - + logger.warning( + "Unknown parameter provided", + parameter=param_name, + value=str(param_value)[:100], + ) + if errors: error_message = "; ".join(errors) - logger.error("Parameter validation failed", - errors=errors, - parameter_count=len(parameters)) + logger.error( + "Parameter validation failed", + errors=errors, + parameter_count=len(parameters), + ) raise ValidationError(error_message) - - logger.debug("Parameter validation passed", - parameter_count=len(parameters)) - + + logger.debug("Parameter validation passed", parameter_count=len(parameters)) + except ValidationError: raise except Exception as e: logger.error("Unexpected validation error", error=str(e)) raise ValidationError(f"Validation failed: {str(e)}") - + def _extract_required_params(self, schema: Dict[str, Any]) -> List[str]: """Extract required parameter names from schema.""" required = [] - + for param_name, param_schema in schema.items(): if isinstance(param_schema, dict): # Parameter is required if it has no default and is not marked as optional - if ("default" not in param_schema and - param_schema.get("required", True) and - not param_schema.get("optional", False)): + if ( + "default" not in param_schema + and param_schema.get("required", True) + and not param_schema.get("optional", False) + ): required.append(param_name) - + return required - - def _validate_parameter(self, name: str, value: Any, schema: Dict[str, Any]) -> List[str]: + + def _validate_parameter( + self, name: str, value: Any, schema: Dict[str, Any] + ) -> List[str]: """Validate a single parameter against its schema.""" errors = [] - + # Type validation param_type = schema.get("type") if param_type: @@ -123,32 +133,34 @@ def _validate_parameter(self, name: str, value: Any, schema: Dict[str, Any]) -> errors.extend(type_errors) else: errors.append(f"Unknown type '{param_type}' for parameter {name}") - + # Enum validation enum_values = schema.get("enum") if enum_values and value not in enum_values: errors.append(f"{name} must be one of: {enum_values}") - + return errors - - def _validate_string(self, name: str, value: Any, schema: Dict[str, Any]) -> List[str]: + + def _validate_string( + self, name: str, value: Any, schema: Dict[str, Any] + ) -> List[str]: """Validate string parameter.""" errors = [] - + if not isinstance(value, str): errors.append(f"{name} must be a string") return errors - + # Length constraints min_length = schema.get("minLength") max_length = schema.get("maxLength") - + if min_length is not None and len(value) < min_length: errors.append(f"{name} must be at least {min_length} characters long") - + if max_length is not None and len(value) > max_length: errors.append(f"{name} must be at most {max_length} characters long") - + # Pattern validation pattern = schema.get("pattern") if pattern: @@ -157,122 +169,136 @@ def _validate_string(self, name: str, value: Any, schema: Dict[str, Any]) -> Lis errors.append(f"{name} does not match required pattern") except re.error as e: errors.append(f"Invalid pattern for {name}: {str(e)}") - + # Format validation format_type = schema.get("format") if format_type: format_errors = self._validate_string_format(name, value, format_type) errors.extend(format_errors) - + return errors - - def _validate_number(self, name: str, value: Any, schema: Dict[str, Any]) -> List[str]: + + def _validate_number( + self, name: str, value: Any, schema: Dict[str, Any] + ) -> List[str]: """Validate number parameter.""" errors = [] - + if not isinstance(value, (int, float)): errors.append(f"{name} must be a number") return errors - + # Range constraints minimum = schema.get("minimum") maximum = schema.get("maximum") exclusive_minimum = schema.get("exclusiveMinimum") exclusive_maximum = schema.get("exclusiveMaximum") - + if minimum is not None and value < minimum: errors.append(f"{name} must be >= {minimum}") - + if maximum is not None and value > maximum: errors.append(f"{name} must be <= {maximum}") - + if exclusive_minimum is not None and value <= exclusive_minimum: errors.append(f"{name} must be > {exclusive_minimum}") - + if exclusive_maximum is not None and value >= exclusive_maximum: errors.append(f"{name} must be < {exclusive_maximum}") - + # Multiple of constraint multiple_of = schema.get("multipleOf") if multiple_of is not None and value % multiple_of != 0: errors.append(f"{name} must be a multiple of {multiple_of}") - + return errors - - def _validate_integer(self, name: str, value: Any, schema: Dict[str, Any]) -> List[str]: + + def _validate_integer( + self, name: str, value: Any, schema: Dict[str, Any] + ) -> List[str]: """Validate integer parameter.""" errors = [] - + if not isinstance(value, int) or isinstance(value, bool): errors.append(f"{name} must be an integer") return errors - + # Use number validation for range constraints number_errors = self._validate_number(name, value, schema) errors.extend(number_errors) - + return errors - - def _validate_boolean(self, name: str, value: Any, schema: Dict[str, Any]) -> List[str]: + + def _validate_boolean( + self, name: str, value: Any, schema: Dict[str, Any] + ) -> List[str]: """Validate boolean parameter.""" errors = [] - + if not isinstance(value, bool): errors.append(f"{name} must be a boolean") - + return errors - - def _validate_object(self, name: str, value: Any, schema: Dict[str, Any]) -> List[str]: + + def _validate_object( + self, name: str, value: Any, schema: Dict[str, Any] + ) -> List[str]: """Validate object parameter.""" errors = [] - + if not isinstance(value, dict): errors.append(f"{name} must be an object") return errors - + # Validate object properties if schema provided properties = schema.get("properties") if properties: for prop_name, prop_value in value.items(): if prop_name in properties: prop_schema = properties[prop_name] - prop_errors = self._validate_parameter(f"{name}.{prop_name}", prop_value, prop_schema) + prop_errors = self._validate_parameter( + f"{name}.{prop_name}", prop_value, prop_schema + ) errors.extend(prop_errors) - + # Additional properties validation additional_properties = schema.get("additionalProperties", True) if not additional_properties and properties: for prop_name in value: if prop_name not in properties: errors.append(f"{name} contains unexpected property: {prop_name}") - + return errors - - def _validate_array(self, name: str, value: Any, schema: Dict[str, Any]) -> List[str]: + + def _validate_array( + self, name: str, value: Any, schema: Dict[str, Any] + ) -> List[str]: """Validate array parameter.""" errors = [] - + if not isinstance(value, list): errors.append(f"{name} must be an array") return errors - + # Length constraints min_items = schema.get("minItems") max_items = schema.get("maxItems") - + if min_items is not None and len(value) < min_items: errors.append(f"{name} must have at least {min_items} items") - + if max_items is not None and len(value) > max_items: errors.append(f"{name} must have at most {max_items} items") - + # Items validation items_schema = schema.get("items") if items_schema: for i, item in enumerate(value): - item_errors = self._validate_parameter(f"{name}[{i}]", item, items_schema) + item_errors = self._validate_parameter( + f"{name}[{i}]", item, items_schema + ) errors.extend(item_errors) - + # Unique items constraint unique_items = schema.get("uniqueItems", False) if unique_items: @@ -283,69 +309,72 @@ def _validate_array(self, name: str, value: Any, schema: Dict[str, Any]) -> List json_items.append(json.dumps(item, sort_keys=True)) except (TypeError, ValueError): json_items.append(str(item)) - + if len(set(json_items)) != len(json_items): errors.append(f"{name} must contain unique items") - + return errors - - def _validate_string_format(self, name: str, value: str, format_type: str) -> List[str]: + + def _validate_string_format( + self, name: str, value: str, format_type: str + ) -> List[str]: """Validate string format constraints.""" errors = [] - + format_validators = { "email": self._validate_email, "uri": self._validate_uri, "date": self._validate_date, "datetime": self._validate_datetime, "ipv4": self._validate_ipv4, - "ipv6": self._validate_ipv6 + "ipv6": self._validate_ipv6, } - + validator = format_validators.get(format_type) if validator: if not validator(value): errors.append(f"{name} is not a valid {format_type}") else: logger.warning("Unknown string format", format=format_type, parameter=name) - + return errors - + def _validate_email(self, value: str) -> bool: """Validate email format.""" - email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" return bool(re.match(email_pattern, value)) - + def _validate_uri(self, value: str) -> bool: """Validate URI format.""" - uri_pattern = r'^https?://.+' + uri_pattern = r"^https?://.+" return bool(re.match(uri_pattern, value)) - + def _validate_date(self, value: str) -> bool: """Validate date format (YYYY-MM-DD).""" try: - datetime.strptime(value, '%Y-%m-%d') + datetime.strptime(value, "%Y-%m-%d") return True except ValueError: return False - + def _validate_datetime(self, value: str) -> bool: """Validate datetime format (ISO 8601).""" try: - datetime.fromisoformat(value.replace('Z', '+00:00')) + datetime.fromisoformat(value.replace("Z", "+00:00")) return True except ValueError: return False - + def _validate_ipv4(self, value: str) -> bool: """Validate IPv4 address format.""" - ipv4_pattern = r'^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$' + ipv4_pattern = r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$" return bool(re.match(ipv4_pattern, value)) - + def _validate_ipv6(self, value: str) -> bool: """Validate IPv6 address format.""" try: import ipaddress + ipaddress.IPv6Address(value) return True except ValueError: @@ -355,126 +384,133 @@ def _validate_ipv6(self, value: str) -> bool: class SecurityValidator: """ Validates tool calls for security constraints and potential threats. - + Provides security validation to prevent dangerous operations and protect against malicious tool usage. """ - + def __init__(self): """Initialize security validator.""" self.dangerous_patterns = [ # SQL injection patterns - r'\b(union\s+select|drop\s+table|delete\s+from|insert\s+into|update\s+set)\b', + r"\b(union\s+select|drop\s+table|delete\s+from|insert\s+into|update\s+set)\b", # Command injection patterns - r'[;&|`$(){}[\]\\]', + r"[;&|`$(){}[\]\\]", # Path traversal patterns - r'\.\./|\.\.\\', + r"\.\./|\.\.\\", # Script injection patterns - r' None: + + async def validate_tool_call( + self, tool_name: str, arguments: Dict[str, Any], user_id: Optional[str] = None + ) -> None: """ Validate tool call for security issues. - + Args: tool_name: Name of the tool being called arguments: Tool arguments to validate user_id: Optional user identifier - + Raises: SecurityError: If security validation fails """ try: # Validate tool name self._validate_tool_name(tool_name) - + # Validate arguments structure self._validate_arguments_structure(arguments) - + # Check for dangerous patterns self._check_dangerous_patterns(arguments) - + # Validate argument sizes self._validate_argument_sizes(arguments) - - logger.debug("Security validation passed", - tool_name=tool_name, - user_id=user_id, - argument_count=len(arguments)) - + + logger.debug( + "Security validation passed", + tool_name=tool_name, + user_id=user_id, + argument_count=len(arguments), + ) + except SecurityError: raise except Exception as e: - logger.error("Security validation error", - tool_name=tool_name, - error=str(e)) + logger.error("Security validation error", tool_name=tool_name, error=str(e)) raise SecurityError(f"Security validation failed: {str(e)}") - + def _validate_tool_name(self, tool_name: str) -> None: """Validate tool name for security issues.""" if not tool_name or not isinstance(tool_name, str): raise SecurityError("Invalid tool name") - + # Check for dangerous characters - if re.search(r'[<>;&|`$(){}[\]\\]', tool_name): + if re.search(r"[<>;&|`$(){}[\]\\]", tool_name): raise SecurityError("Tool name contains dangerous characters") - + # Check length if len(tool_name) > 100: raise SecurityError("Tool name too long") - + # Validate naming convention - if not re.match(r'^[A-Za-z][A-Za-z0-9_.]*$', tool_name): + if not re.match(r"^[A-Za-z][A-Za-z0-9_.]*$", tool_name): raise SecurityError("Tool name does not follow naming convention") - + def _validate_arguments_structure(self, arguments: Dict[str, Any]) -> None: """Validate arguments structure for security.""" if not isinstance(arguments, dict): raise SecurityError("Arguments must be a dictionary") - + # Check argument count if len(arguments) > 50: raise SecurityError("Too many arguments provided") - + # Validate argument names for arg_name in arguments.keys(): if not isinstance(arg_name, str): raise SecurityError("Argument names must be strings") - + if len(arg_name) > 100: raise SecurityError(f"Argument name too long: {arg_name}") - - if re.search(r'[<>;&|`$(){}[\]\\]', arg_name): - raise SecurityError(f"Argument name contains dangerous characters: {arg_name}") - - def _check_dangerous_patterns(self, arguments: Dict[str, Any], depth: int = 0) -> None: + + if re.search(r"[<>;&|`$(){}[\]\\]", arg_name): + raise SecurityError( + f"Argument name contains dangerous characters: {arg_name}" + ) + + def _check_dangerous_patterns( + self, arguments: Dict[str, Any], depth: int = 0 + ) -> None: """Check for dangerous patterns in arguments.""" if depth > self.max_object_depth: raise SecurityError("Argument structure too deep") - + for key, value in arguments.items(): if isinstance(value, str): # Check string for dangerous patterns for pattern in self.dangerous_patterns: if re.search(pattern, value, re.IGNORECASE): - logger.warning("Dangerous pattern detected", - pattern=pattern, - argument=key, - value=value[:100]) - raise SecurityError(f"Dangerous pattern detected in argument: {key}") - + logger.warning( + "Dangerous pattern detected", + pattern=pattern, + argument=key, + value=value[:100], + ) + raise SecurityError( + f"Dangerous pattern detected in argument: {key}" + ) + elif isinstance(value, dict): # Recursively check nested objects self._check_dangerous_patterns(value, depth + 1) - + elif isinstance(value, list): # Check array elements for item in value: @@ -483,34 +519,41 @@ def _check_dangerous_patterns(self, arguments: Dict[str, Any], depth: int = 0) - elif isinstance(item, str): for pattern in self.dangerous_patterns: if re.search(pattern, item, re.IGNORECASE): - raise SecurityError(f"Dangerous pattern detected in array argument: {key}") - + raise SecurityError( + f"Dangerous pattern detected in array argument: {key}" + ) + def _validate_argument_sizes(self, arguments: Dict[str, Any]) -> None: """Validate argument sizes to prevent DoS attacks.""" + def check_size(obj, current_depth=0): if current_depth > self.max_object_depth: raise SecurityError("Argument structure too deep") - + if isinstance(obj, str): if len(obj) > self.max_string_length: - raise SecurityError(f"String argument too long: {len(obj)} characters") - + raise SecurityError( + f"String argument too long: {len(obj)} characters" + ) + elif isinstance(obj, list): if len(obj) > self.max_array_length: raise SecurityError(f"Array argument too long: {len(obj)} items") - + for item in obj: check_size(item, current_depth + 1) - + elif isinstance(obj, dict): if len(obj) > 100: raise SecurityError(f"Object has too many properties: {len(obj)}") - + for value in obj.values(): check_size(value, current_depth + 1) - + for arg_name, arg_value in arguments.items(): try: check_size(arg_value) except SecurityError as e: - raise SecurityError(f"Size validation failed for argument '{arg_name}': {str(e)}") \ No newline at end of file + raise SecurityError( + f"Size validation failed for argument '{arg_name}': {str(e)}" + ) diff --git a/tests/conftest.py b/tests/conftest.py index ff110b9..29310b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,9 +26,9 @@ def event_loop(): @pytest.fixture def test_database() -> Generator[str, None, None]: """Create a temporary test database with sample data.""" - with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp: + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name - + # Initialize test database with financial data conn = sqlite3.connect(db_path) conn.execute(""" @@ -39,26 +39,26 @@ def test_database() -> Generator[str, None, None]: created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) - + # Insert test data matching requirements test_data = [ - ('Q1-2025', 1234567.89, 'Product Sales'), - ('Q4-2024', 1133221.55, 'Product Sales'), - ('Q3-2024', 987654.32, 'Service Revenue'), - ('Q2-2024', 876543.21, 'Product Sales') + ("Q1-2025", 1234567.89, "Product Sales"), + ("Q4-2024", 1133221.55, "Product Sales"), + ("Q3-2024", 987654.32, "Service Revenue"), + ("Q2-2024", 876543.21, "Product Sales"), ] - + for quarter, value, category in test_data: conn.execute( "INSERT INTO revenue (quarter, value, category) VALUES (?, ?, ?)", - (quarter, value, category) + (quarter, value, category), ) - + conn.commit() conn.close() - + yield db_path - + # Cleanup os.unlink(db_path) @@ -67,15 +67,15 @@ def test_database() -> Generator[str, None, None]: def mock_anthropic_client(): """Mock Anthropic client for testing cache behavior.""" client = AsyncMock() - + # Configure default response mock_response = Mock() mock_response.content = [Mock(text="Test response from Claude")] mock_response.tool_calls = None mock_response.usage = Mock(input_tokens=100, output_tokens=50) - + client.messages.create.return_value = mock_response - + return client @@ -83,14 +83,14 @@ def mock_anthropic_client(): def mock_arcade_client(): """Mock Arcade client for testing tool execution.""" client = Mock() - + # Configure successful tool execution client.tools.execute.return_value = { "status": "success", "data": {"rows": [], "row_count": 0}, - "execution_time_ms": 5 + "execution_time_ms": 5, } - + # Configure tool schema export client.tools.export_schema.return_value = [ { @@ -103,15 +103,15 @@ def mock_arcade_client(): "properties": { "statement": { "type": "string", - "description": "SQL SELECT statement" + "description": "SQL SELECT statement", } }, - "required": ["statement"] - } - } + "required": ["statement"], + }, + }, } ] - + return client @@ -124,15 +124,15 @@ def test_environment(test_database, mock_anthropic_client, mock_arcade_client): "ARCADE_URL": "http://localhost:9099", "FACT_DB": test_database, "FACT_CACHE_PREFIX": "fact_test_v1", - "FACT_LOG_LEVEL": "DEBUG" + "FACT_LOG_LEVEL": "DEBUG", } - + with patch.dict(os.environ, env_vars): yield { "database": test_database, "anthropic": mock_anthropic_client, "arcade": mock_arcade_client, - "env": env_vars + "env": env_vars, } @@ -145,7 +145,7 @@ def cache_config(): "max_size": "10MB", "ttl_seconds": 3600, "hit_target_ms": 50, - "miss_target_ms": 140 + "miss_target_ms": 140, } @@ -158,7 +158,7 @@ def performance_targets(): "tool_execution_lan_ms": 10, "overall_response_ms": 100, "cost_reduction_cache_hit": 0.90, # 90% reduction - "cost_reduction_cache_miss": 0.65 # 65% reduction + "cost_reduction_cache_miss": 0.65, # 65% reduction } @@ -172,14 +172,14 @@ def security_test_data(): "SELECT * FROM revenue UNION SELECT * FROM users", "'; EXEC xp_cmdshell('dir'); --", "SELECT * FROM revenue WHERE quarter = 'Q1' OR '1'='1'", - "SELECT * FROM revenue; INSERT INTO revenue VALUES ('hack', 0); --" + "SELECT * FROM revenue; INSERT INTO revenue VALUES ('hack', 0); --", ], "path_traversal_attempts": [ "../../../etc/passwd", "..\\..\\windows\\system32\\config\\sam", "/etc/shadow", "~/.ssh/id_rsa", - "../../../../proc/self/environ" + "../../../../proc/self/environ", ], "dangerous_urls": [ "http://localhost:80/admin", @@ -187,8 +187,8 @@ def security_test_data(): "http://169.254.169.254/metadata", "http://10.0.0.1/internal", "file:///etc/passwd", - "ftp://internal.server/data" - ] + "ftp://internal.server/data", + ], } @@ -205,14 +205,14 @@ def benchmark_queries(): "Calculate year-over-year growth", "Show revenue trends", "What is average quarterly revenue?", - "Find revenue outliers" + "Find revenue outliers", ] @pytest.fixture def mock_time(): """Mock time for consistent performance testing.""" - with patch('time.perf_counter') as mock_perf: + with patch("time.perf_counter") as mock_perf: # Simulate realistic timing mock_perf.side_effect = [0.000, 0.045] # 45ms execution yield mock_perf @@ -220,31 +220,31 @@ def mock_time(): class PerformanceTimer: """Helper class for measuring performance in tests.""" - + def __init__(self): self.start_time = None self.end_time = None - + def start(self): """Start timing.""" self.start_time = time.perf_counter() return self - + def stop(self): """Stop timing and return duration in milliseconds.""" self.end_time = time.perf_counter() return self.duration_ms - + @property def duration_ms(self) -> float: """Get duration in milliseconds.""" if self.start_time is None or self.end_time is None: return 0.0 return (self.end_time - self.start_time) * 1000 - + def __enter__(self): return self.start() - + def __exit__(self, *args): self.stop() @@ -257,24 +257,26 @@ def performance_timer(): class TestDataFactory: """Factory for creating test data.""" - + @staticmethod - def create_tool_call(name: str = "SQL.QueryReadonly", - statement: str = "SELECT * FROM revenue") -> Mock: + def create_tool_call( + name: str = "SQL.QueryReadonly", statement: str = "SELECT * FROM revenue" + ) -> Mock: """Create a mock tool call object.""" tool_call = Mock() tool_call.name = name tool_call.id = f"test-call-{hash(statement) % 10000}" tool_call.arguments = json.dumps({"statement": statement}) return tool_call - + @staticmethod - def create_cache_entry(prefix: str = "fact_test_v1", - content: str = None) -> Dict[str, Any]: + def create_cache_entry( + prefix: str = "fact_test_v1", content: str = None + ) -> Dict[str, Any]: """Create a cache entry for testing.""" if content is None: content = "A" * 500 # Minimum cache size - + return { "prefix": prefix, "content": content, @@ -283,22 +285,23 @@ def create_cache_entry(prefix: str = "fact_test_v1", "version": "1.0", "is_valid": True, "access_count": 0, - "last_accessed": None + "last_accessed": None, } - + @staticmethod - def create_authorization(user_id: str = "test@example.com", - scopes: list = None) -> Dict[str, Any]: + def create_authorization( + user_id: str = "test@example.com", scopes: list = None + ) -> Dict[str, Any]: """Create an authorization object for testing.""" if scopes is None: scopes = ["read"] - + return { "user_id": user_id, "scopes": scopes, "token": f"test_token_{hash(user_id) % 10000}", "expires_at": time.time() + 3600, - "created_at": time.time() + "created_at": time.time(), } @@ -309,16 +312,10 @@ def test_factory(): # Test markers for organizing test execution -pytestmark = [ - pytest.mark.asyncio -] +pytestmark = [pytest.mark.asyncio] def pytest_configure(config): """Configure pytest with custom markers and settings.""" - config.addinivalue_line( - "markers", "benchmark: Performance benchmark tests" - ) - config.addinivalue_line( - "markers", "cost_analysis: Token cost analysis tests" - ) \ No newline at end of file + config.addinivalue_line("markers", "benchmark: Performance benchmark tests") + config.addinivalue_line("markers", "cost_analysis: Token cost analysis tests") diff --git a/tests/debug_sql_validation.py b/tests/debug_sql_validation.py index 0de56b8..d0e7b45 100644 --- a/tests/debug_sql_validation.py +++ b/tests/debug_sql_validation.py @@ -6,32 +6,37 @@ import sys import os import re -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) from src.db.connection import DatabaseManager + def test_injection_patterns(): """Test specific injection patterns""" - + # The problematic query query = "SELECT name as table_name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name" normalized = query.lower() - + print(f"Testing query: {query}") print(f"Normalized: {normalized}") - + # Test each injection pattern individually injection_patterns = [ - (r';\s*(?:drop|delete|insert|update|create|alter)', "Multiple dangerous statements"), - (r'\bunion\s+(?:all\s+)?select\b', "Union injection attempts"), + ( + r";\s*(?:drop|delete|insert|update|create|alter)", + "Multiple dangerous statements", + ), + (r"\bunion\s+(?:all\s+)?select\b", "Union injection attempts"), (r'\bor\s+[\'"]?1[\'"]?\s*=\s*[\'"]?1[\'"]?\b', "Always true OR conditions"), (r'\band\s+[\'"]?1[\'"]?\s*=\s*[\'"]?1[\'"]?\b', "Always true AND conditions"), - (r'\'[^\']*\'[^\']*\'[^\']*\'', "Multiple quotes suggesting injection"), - (r'\\x[0-9a-f]{2}', "Hex encoding"), - (r'--.*(?:union|drop|delete|insert|update|create)', "Dangerous comments"), + (r"\'[^\']*\'[^\']*\'[^\']*\'", "Multiple quotes suggesting injection"), + (r"\\x[0-9a-f]{2}", "Hex encoding"), + (r"--.*(?:union|drop|delete|insert|update|create)", "Dangerous comments"), (r'\bor\s+[\'"]?\w+[\'"]?\s*=\s*[\'"]?\w+[\'"]?\s+or\b', "OR chain injections"), ] - + print("\n=== Testing Injection Patterns ===") for pattern, description in injection_patterns: match = re.search(pattern, normalized, re.IGNORECASE) @@ -42,11 +47,12 @@ def test_injection_patterns(): else: print(f"āœ“ No match: {description}") + def test_sql_validation(): """Test specific SQL queries to find validation issues""" - + db_manager = DatabaseManager("db/test_debug.db") - + # Test queries that should be valid test_queries = [ "SELECT name as table_name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name", @@ -54,11 +60,11 @@ def test_sql_validation(): "SELECT COUNT(*) FROM companies", "PRAGMA table_info(companies)", "SELECT name, revenue FROM companies WHERE sector = 'Technology'", - "SELECT c.name, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2024" + "SELECT c.name, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2024", ] - + print("\n=== SQL Validation Debug Tests ===") - + for i, query in enumerate(test_queries, 1): print(f"\n{i}. Testing query: {query[:50]}...") try: @@ -67,6 +73,7 @@ def test_sql_validation(): except Exception as e: print(f"āœ— FAILED: {type(e).__name__}: {e}") + if __name__ == "__main__": test_injection_patterns() - test_sql_validation() \ No newline at end of file + test_sql_validation() diff --git a/tests/integration/test_cache_resilience_e2e.py b/tests/integration/test_cache_resilience_e2e.py index 0d44730..1bc6030 100644 --- a/tests/integration/test_cache_resilience_e2e.py +++ b/tests/integration/test_cache_resilience_e2e.py @@ -31,21 +31,21 @@ # Add src to path for imports import sys + src_path = str(Path(__file__).parent.parent.parent / "src") if src_path not in sys.path: sys.path.insert(0, src_path) from cache.manager import CacheManager, CacheEntry from cache.resilience import ( - CacheCircuitBreaker, - ResilientCacheWrapper, - CircuitBreakerConfig, - CircuitState + CacheCircuitBreaker, + ResilientCacheWrapper, + CircuitBreakerConfig, + CircuitState, ) from cache.config import CacheConfig, get_default_cache_config from core.errors import CacheError - # Configure structured logging for tests structlog.configure( processors=[ @@ -54,7 +54,7 @@ structlog.stdlib.add_log_level, structlog.stdlib.PositionalArgumentsFormatter(), structlog.processors.TimeStamper(fmt="iso"), - structlog.dev.ConsoleRenderer() + structlog.dev.ConsoleRenderer(), ], logger_factory=structlog.stdlib.LoggerFactory(), wrapper_class=structlog.stdlib.BoundLogger, @@ -66,16 +66,16 @@ class TestCacheResilienceE2E: """End-to-end test suite for cache resilience implementation.""" - + @pytest.fixture(autouse=True) def setup_test_environment(self): """Set up real test environment with no mocking.""" # Create temporary database for testing - self.db_fd, self.db_path = tempfile.mkstemp(suffix='.db') - + self.db_fd, self.db_path = tempfile.mkstemp(suffix=".db") + # Initialize test database with realistic data self._setup_test_database() - + # Configure cache for testing self.cache_config = { "prefix": "fact_resilience_test", @@ -83,50 +83,51 @@ def setup_test_environment(self): "max_size": "5MB", "ttl_seconds": 30, # Short TTL for testing "hit_target_ms": 50, - "miss_target_ms": 140 + "miss_target_ms": 140, } - + # Configure circuit breaker for testing self.circuit_config = CircuitBreakerConfig( failure_threshold=3, # Low threshold for testing success_threshold=2, # Low threshold for testing timeout_seconds=5.0, # Short timeout for testing rolling_window_seconds=60.0, - health_check_interval=1.0 # Frequent for testing + health_check_interval=1.0, # Frequent for testing ) - + # Initialize real components self.cache_manager = CacheManager(self.cache_config) self.circuit_breaker = CacheCircuitBreaker(self.circuit_config) self.resilient_cache = ResilientCacheWrapper( - self.cache_manager, - self.circuit_breaker + self.cache_manager, self.circuit_breaker ) - + # Test metrics tracking self.test_metrics = { "operations_executed": 0, "successful_operations": 0, "failed_operations": 0, "circuit_state_changes": 0, - "performance_measurements": [] + "performance_measurements": [], } - - logger.info("Test environment initialized", - db_path=self.db_path, - cache_config=self.cache_config) - + + logger.info( + "Test environment initialized", + db_path=self.db_path, + cache_config=self.cache_config, + ) + yield - + # Cleanup os.close(self.db_fd) os.unlink(self.db_path) logger.info("Test environment cleaned up") - + def _setup_test_database(self): """Set up test database with realistic financial data.""" conn = sqlite3.connect(self.db_path) - + # Create financial data table conn.execute(""" CREATE TABLE financial_data ( @@ -138,183 +139,207 @@ def _setup_test_database(self): created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """) - + # Insert test data test_data = [ - ('Q1-2025', 1234567.89, 234567.89, 'Product Sales'), - ('Q4-2024', 1133221.55, 201221.55, 'Product Sales'), - ('Q3-2024', 987654.32, 187654.32, 'Service Revenue'), - ('Q2-2024', 876543.21, 176543.21, 'Product Sales'), - ('Q1-2024', 765432.10, 165432.10, 'Service Revenue'), + ("Q1-2025", 1234567.89, 234567.89, "Product Sales"), + ("Q4-2024", 1133221.55, 201221.55, "Product Sales"), + ("Q3-2024", 987654.32, 187654.32, "Service Revenue"), + ("Q2-2024", 876543.21, 176543.21, "Product Sales"), + ("Q1-2024", 765432.10, 165432.10, "Service Revenue"), ] - + for quarter, revenue, profit, category in test_data: conn.execute( "INSERT INTO financial_data (quarter, revenue, profit, category) VALUES (?, ?, ?, ?)", - (quarter, revenue, profit, category) + (quarter, revenue, profit, category), ) - + conn.commit() conn.close() logger.info("Test database initialized with financial data") - + async def test_cache_initialization_real_components(self): """Test cache initialization with real storage components.""" logger.info("Testing cache initialization with real components") - + # Verify cache manager is properly initialized assert self.cache_manager is not None assert self.cache_manager.prefix == "fact_resilience_test" assert self.cache_manager.min_tokens == 100 - + # Verify circuit breaker is properly initialized assert self.circuit_breaker is not None assert self.circuit_breaker.state == CircuitState.CLOSED - + # Verify resilient wrapper is properly initialized assert self.resilient_cache is not None assert self.resilient_cache.cache_manager is self.cache_manager assert self.resilient_cache.circuit_breaker is self.circuit_breaker - + # Test initial metrics metrics = self.resilient_cache.get_metrics() assert "cache" in metrics assert "circuit_breaker" in metrics assert metrics["circuit_breaker"]["state"] == "closed" - + logger.info("Cache initialization test completed successfully") - + async def test_normal_cache_operations_real_storage(self): """Test normal cache operations with real storage.""" logger.info("Testing normal cache operations with real storage") - - test_content = "This is test content for cache storage with sufficient tokens to meet the minimum requirement. " * 15 # Ensure minimum tokens + + test_content = ( + "This is test content for cache storage with sufficient tokens to meet the minimum requirement. " + * 15 + ) # Ensure minimum tokens query_hash = self.cache_manager.generate_hash("test_query") - + # Test store operation start_time = time.perf_counter() stored_entry = await self.resilient_cache.store(query_hash, test_content) store_time = (time.perf_counter() - start_time) * 1000 - + assert stored_entry is not None assert stored_entry == True # Resilient wrapper returns boolean - + # Test get operation (cache hit) start_time = time.perf_counter() retrieved_entry = await self.resilient_cache.get(query_hash) get_time = (time.perf_counter() - start_time) * 1000 - + assert retrieved_entry is not None assert retrieved_entry.content == test_content assert retrieved_entry.access_count >= 1 - + # Test cache miss missing_hash = self.cache_manager.generate_hash("non_existent_query") missing_entry = await self.resilient_cache.get(missing_hash) assert missing_entry is None - + # Test invalidation - invalidated_count = await self.resilient_cache.invalidate_by_prefix("fact_resilience_test") + invalidated_count = await self.resilient_cache.invalidate_by_prefix( + "fact_resilience_test" + ) assert invalidated_count >= 1 - + # Verify performance assert store_time < self.cache_config["miss_target_ms"] assert get_time < self.cache_config["hit_target_ms"] - + self.test_metrics["operations_executed"] += 4 self.test_metrics["successful_operations"] += 4 self.test_metrics["performance_measurements"].extend([store_time, get_time]) - - logger.info("Normal cache operations test completed successfully", - store_time_ms=store_time, - get_time_ms=get_time, - invalidated_count=invalidated_count) - + + logger.info( + "Normal cache operations test completed successfully", + store_time_ms=store_time, + get_time_ms=get_time, + invalidated_count=invalidated_count, + ) + async def test_circuit_breaker_failure_scenarios(self): """Test circuit breaker functionality with various failure scenarios.""" logger.info("Testing circuit breaker failure scenarios") - + # Create a failing cache manager mock within the real manager original_store = self.cache_manager.store - + def failing_store(query_hash: str, content: str): """Simulate cache storage failures.""" - raise CacheError("Simulated cache storage failure", error_code="CACHE_STORAGE_ERROR") - + raise CacheError( + "Simulated cache storage failure", error_code="CACHE_STORAGE_ERROR" + ) + # Test intermittent failures failure_count = 0 success_count = 0 - + try: # Patch the store method to simulate failures self.cache_manager.store = failing_store - + # Generate enough failures to trigger circuit breaker for i in range(5): try: - await self.resilient_cache.store(f"failing_query_{i}", "test content " * 20) + await self.resilient_cache.store( + f"failing_query_{i}", "test content " * 20 + ) success_count += 1 except CacheError: failure_count += 1 self.test_metrics["failed_operations"] += 1 - + self.test_metrics["operations_executed"] += 1 - + # Check circuit breaker state circuit_state = self.circuit_breaker.get_state() - logger.info("Circuit breaker state after failures", state=circuit_state.value) - + logger.info( + "Circuit breaker state after failures", state=circuit_state.value + ) + # Circuit should be open after threshold failures assert circuit_state == CircuitState.OPEN self.test_metrics["circuit_state_changes"] += 1 - + # Test graceful degradation try: - degraded_result = await self.resilient_cache.store("degraded_query", "test content " * 20) + degraded_result = await self.resilient_cache.store( + "degraded_query", "test content " * 20 + ) # Should return fallback response assert degraded_result == True self.test_metrics["successful_operations"] += 1 except CacheError as e: assert "CIRCUIT_BREAKER" in str(e.error_code) self.test_metrics["failed_operations"] += 1 - + self.test_metrics["operations_executed"] += 1 - + finally: # Restore original method self.cache_manager.store = original_store - - logger.info("Circuit breaker failure scenarios test completed", - failure_count=failure_count, - success_count=success_count, - final_state=circuit_state.value) - + + logger.info( + "Circuit breaker failure scenarios test completed", + failure_count=failure_count, + success_count=success_count, + final_state=circuit_state.value, + ) + async def test_circuit_breaker_recovery_mechanisms(self): """Test circuit breaker recovery mechanisms after timeout periods.""" logger.info("Testing circuit breaker recovery mechanisms") - + # Ensure circuit is open self.circuit_breaker.force_open() assert self.circuit_breaker.get_state() == CircuitState.OPEN - + # Wait for timeout period timeout_seconds = self.circuit_config.timeout_seconds - logger.info("Waiting for circuit breaker timeout", timeout_seconds=timeout_seconds) + logger.info( + "Waiting for circuit breaker timeout", timeout_seconds=timeout_seconds + ) await asyncio.sleep(timeout_seconds + 1) - + # Test half-open transition - test_content = "Recovery test content with sufficient tokens to meet the minimum requirement for cache storage. " * 12 + test_content = ( + "Recovery test content with sufficient tokens to meet the minimum requirement for cache storage. " + * 12 + ) query_hash = self.cache_manager.generate_hash("recovery_test_query") - + try: # This should transition to half-open and succeed result = await self.resilient_cache.store(query_hash, test_content) assert result == True - + # Verify state transition circuit_state = self.circuit_breaker.get_state() - logger.info("Circuit state after first recovery attempt", state=circuit_state.value) - + logger.info( + "Circuit state after first recovery attempt", state=circuit_state.value + ) + # Continue with successful operations to close circuit for i in range(self.circuit_config.success_threshold): recovery_hash = self.cache_manager.generate_hash(f"recovery_query_{i}") @@ -322,126 +347,140 @@ async def test_circuit_breaker_recovery_mechanisms(self): assert result == True self.test_metrics["successful_operations"] += 1 self.test_metrics["operations_executed"] += 1 - + # Circuit should now be closed final_state = self.circuit_breaker.get_state() assert final_state == CircuitState.CLOSED self.test_metrics["circuit_state_changes"] += 1 - - logger.info("Circuit breaker recovery completed successfully", - final_state=final_state.value) - + + logger.info( + "Circuit breaker recovery completed successfully", + final_state=final_state.value, + ) + except Exception as e: logger.error("Recovery test failed", error=str(e)) self.test_metrics["failed_operations"] += 1 raise - + async def test_performance_under_various_conditions(self): """Test and measure performance under various conditions.""" logger.info("Testing performance under various conditions") - + performance_results = { "normal_operations": [], "high_load": [], "degraded_mode": [], - "recovery_mode": [] + "recovery_mode": [], } - + # Test normal operations performance for i in range(10): query_hash = self.cache_manager.generate_hash(f"perf_test_normal_{i}") content = f"Performance test content {i}. " * 20 - + start_time = time.perf_counter() await self.resilient_cache.store(query_hash, content) store_time = (time.perf_counter() - start_time) * 1000 - + start_time = time.perf_counter() retrieved = await self.resilient_cache.get(query_hash) get_time = (time.perf_counter() - start_time) * 1000 - - performance_results["normal_operations"].append({ - "store_time_ms": store_time, - "get_time_ms": get_time - }) - + + performance_results["normal_operations"].append( + {"store_time_ms": store_time, "get_time_ms": get_time} + ) + self.test_metrics["operations_executed"] += 2 self.test_metrics["successful_operations"] += 2 - + # Test high load performance with concurrent operations async def concurrent_operation(index: int): - query_hash = self.cache_manager.generate_hash(f"perf_test_concurrent_{index}") + query_hash = self.cache_manager.generate_hash( + f"perf_test_concurrent_{index}" + ) content = f"Concurrent test content {index}. " * 20 - + start_time = time.perf_counter() await self.resilient_cache.store(query_hash, content) total_time = (time.perf_counter() - start_time) * 1000 return total_time - + # Run concurrent operations concurrent_tasks = [concurrent_operation(i) for i in range(20)] concurrent_times = await asyncio.gather(*concurrent_tasks) - - performance_results["high_load"] = [{"operation_time_ms": t} for t in concurrent_times] + + performance_results["high_load"] = [ + {"operation_time_ms": t} for t in concurrent_times + ] self.test_metrics["operations_executed"] += len(concurrent_tasks) self.test_metrics["successful_operations"] += len(concurrent_tasks) - + # Test degraded mode performance self.circuit_breaker.force_open() - + for i in range(5): query_hash = self.cache_manager.generate_hash(f"perf_test_degraded_{i}") content = f"Degraded test content {i}. " * 20 - + start_time = time.perf_counter() try: await self.resilient_cache.store(query_hash, content) operation_time = (time.perf_counter() - start_time) * 1000 - performance_results["degraded_mode"].append({ - "operation_time_ms": operation_time, - "status": "success" - }) + performance_results["degraded_mode"].append( + {"operation_time_ms": operation_time, "status": "success"} + ) self.test_metrics["successful_operations"] += 1 except CacheError: operation_time = (time.perf_counter() - start_time) * 1000 - performance_results["degraded_mode"].append({ - "operation_time_ms": operation_time, - "status": "circuit_breaker_blocked" - }) + performance_results["degraded_mode"].append( + { + "operation_time_ms": operation_time, + "status": "circuit_breaker_blocked", + } + ) self.test_metrics["failed_operations"] += 1 - + self.test_metrics["operations_executed"] += 1 - + # Reset circuit for recovery test self.circuit_breaker.force_closed() - + # Calculate performance statistics - normal_store_times = [r["store_time_ms"] for r in performance_results["normal_operations"]] - normal_get_times = [r["get_time_ms"] for r in performance_results["normal_operations"]] - concurrent_times = [r["operation_time_ms"] for r in performance_results["high_load"]] - + normal_store_times = [ + r["store_time_ms"] for r in performance_results["normal_operations"] + ] + normal_get_times = [ + r["get_time_ms"] for r in performance_results["normal_operations"] + ] + concurrent_times = [ + r["operation_time_ms"] for r in performance_results["high_load"] + ] + avg_store_time = sum(normal_store_times) / len(normal_store_times) avg_get_time = sum(normal_get_times) / len(normal_get_times) avg_concurrent_time = sum(concurrent_times) / len(concurrent_times) - + # Verify performance meets targets assert avg_store_time < self.cache_config["miss_target_ms"] assert avg_get_time < self.cache_config["hit_target_ms"] - - self.test_metrics["performance_measurements"].extend([ - avg_store_time, avg_get_time, avg_concurrent_time - ]) - - logger.info("Performance testing completed", - avg_store_time_ms=avg_store_time, - avg_get_time_ms=avg_get_time, - avg_concurrent_time_ms=avg_concurrent_time, - performance_results=performance_results) - + + self.test_metrics["performance_measurements"].extend( + [avg_store_time, avg_get_time, avg_concurrent_time] + ) + + logger.info( + "Performance testing completed", + avg_store_time_ms=avg_store_time, + avg_get_time_ms=avg_get_time, + avg_concurrent_time_ms=avg_concurrent_time, + performance_results=performance_results, + ) + async def test_metrics_collection_and_validation(self): """Test comprehensive metrics collection and validation.""" logger.info("Testing metrics collection and validation") - + # Perform various operations to generate metrics test_operations = [ ("store_op_1", "store", "test content 1 " * 20), @@ -449,29 +488,34 @@ async def test_metrics_collection_and_validation(self): ("store_op_2", "store", "test content 2 " * 20), ("invalidate_op", "invalidate", None), ] - + for op_id, operation, content in test_operations: query_hash = self.cache_manager.generate_hash(op_id) - + try: if operation == "store": await self.resilient_cache.store(query_hash, content) elif operation == "get": await self.resilient_cache.get(query_hash) elif operation == "invalidate": - await self.resilient_cache.invalidate_by_prefix("fact_resilience_test") - + await self.resilient_cache.invalidate_by_prefix( + "fact_resilience_test" + ) + self.test_metrics["successful_operations"] += 1 except Exception as e: - logger.warning("Operation failed during metrics test", - operation=operation, error=str(e)) + logger.warning( + "Operation failed during metrics test", + operation=operation, + error=str(e), + ) self.test_metrics["failed_operations"] += 1 - + self.test_metrics["operations_executed"] += 1 - + # Get comprehensive metrics metrics = self.resilient_cache.get_metrics() - + # Validate cache metrics assert "cache" in metrics cache_metrics = metrics["cache"] @@ -479,7 +523,7 @@ async def test_metrics_collection_and_validation(self): assert "total_requests" in cache_metrics assert "cache_hits" in cache_metrics assert "cache_misses" in cache_metrics - + # Validate circuit breaker metrics assert "circuit_breaker" in metrics cb_metrics = metrics["circuit_breaker"] @@ -487,169 +531,199 @@ async def test_metrics_collection_and_validation(self): assert "failure_count" in cb_metrics assert "success_count" in cb_metrics assert "total_operations" in cb_metrics - + # Validate performance stats perf_stats = self.cache_manager.get_performance_stats() assert "avg_hit_latency_ms" in perf_stats assert "avg_miss_latency_ms" in perf_stats assert "hit_latency_compliance" in perf_stats assert "miss_latency_compliance" in perf_stats - - logger.info("Metrics collection and validation completed", - cache_metrics=cache_metrics, - circuit_breaker_metrics=cb_metrics, - performance_stats=perf_stats) - + + logger.info( + "Metrics collection and validation completed", + cache_metrics=cache_metrics, + circuit_breaker_metrics=cb_metrics, + performance_stats=perf_stats, + ) + async def test_real_database_integration(self): """Test integration with real database operations.""" logger.info("Testing real database integration") - + # Test database connectivity conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - + # Perform real database query - cursor.execute("SELECT quarter, revenue FROM financial_data ORDER BY quarter DESC LIMIT 3") + cursor.execute( + "SELECT quarter, revenue FROM financial_data ORDER BY quarter DESC LIMIT 3" + ) results = cursor.fetchall() - + assert len(results) >= 3 logger.info("Database query successful", results=results) - + # Test caching of database query results query_content = json.dumps(results) db_query_hash = self.cache_manager.generate_hash("latest_financial_data") - + # Store database results in cache await self.resilient_cache.store(db_query_hash, query_content) - + # Retrieve from cache cached_result = await self.resilient_cache.get(db_query_hash) assert cached_result is not None assert cached_result.content == query_content - + # Test cache invalidation on database changes - cursor.execute("INSERT INTO financial_data (quarter, revenue, profit, category) VALUES (?, ?, ?, ?)", - ('Q2-2025', 1500000.00, 300000.00, 'Product Sales')) + cursor.execute( + "INSERT INTO financial_data (quarter, revenue, profit, category) VALUES (?, ?, ?, ?)", + ("Q2-2025", 1500000.00, 300000.00, "Product Sales"), + ) conn.commit() - + # Invalidate cache after database change - invalidated = await self.resilient_cache.invalidate_by_prefix("fact_resilience_test") + invalidated = await self.resilient_cache.invalidate_by_prefix( + "fact_resilience_test" + ) assert invalidated >= 1 - + conn.close() - + self.test_metrics["operations_executed"] += 3 self.test_metrics["successful_operations"] += 3 - + logger.info("Real database integration test completed successfully") - + async def test_stress_and_failure_recovery(self): """Test system behavior under stress and various failure patterns.""" logger.info("Testing stress and failure recovery") - + # Test rapid successive operations rapid_operations = [] for i in range(50): query_hash = self.cache_manager.generate_hash(f"stress_test_{i}") content = f"Stress test content {i}. " * 25 rapid_operations.append(self.resilient_cache.store(query_hash, content)) - + # Execute all operations concurrently start_time = time.perf_counter() results = await asyncio.gather(*rapid_operations, return_exceptions=True) total_time = (time.perf_counter() - start_time) * 1000 - + # Analyze results successful_ops = sum(1 for r in results if not isinstance(r, Exception)) failed_ops = len(results) - successful_ops - - logger.info("Stress test completed", - total_operations=len(results), - successful_operations=successful_ops, - failed_operations=failed_ops, - total_time_ms=total_time, - avg_time_per_op_ms=total_time / len(results)) - + + logger.info( + "Stress test completed", + total_operations=len(results), + successful_operations=successful_ops, + failed_operations=failed_ops, + total_time_ms=total_time, + avg_time_per_op_ms=total_time / len(results), + ) + # Test recovery after stress recovery_ops = [] for i in range(10): query_hash = self.cache_manager.generate_hash(f"recovery_test_{i}") recovery_ops.append(self.resilient_cache.get(query_hash)) - + recovery_results = await asyncio.gather(*recovery_ops, return_exceptions=True) - recovery_successful = sum(1 for r in recovery_results if not isinstance(r, Exception)) - + recovery_successful = sum( + 1 for r in recovery_results if not isinstance(r, Exception) + ) + self.test_metrics["operations_executed"] += len(results) + len(recovery_results) - self.test_metrics["successful_operations"] += successful_ops + recovery_successful - self.test_metrics["failed_operations"] += failed_ops + (len(recovery_results) - recovery_successful) - - logger.info("Stress and recovery test completed", - recovery_successful=recovery_successful, - recovery_total=len(recovery_results)) - + self.test_metrics["successful_operations"] += ( + successful_ops + recovery_successful + ) + self.test_metrics["failed_operations"] += failed_ops + ( + len(recovery_results) - recovery_successful + ) + + logger.info( + "Stress and recovery test completed", + recovery_successful=recovery_successful, + recovery_total=len(recovery_results), + ) + async def test_comprehensive_system_validation(self): """Comprehensive validation of the entire system.""" logger.info("Running comprehensive system validation") - + # Test all components work together test_scenarios = [ "normal_operation", "circuit_breaker_open", "circuit_breaker_recovery", "high_concurrency", - "database_integration" + "database_integration", ] - + validation_results = {} - + for scenario in test_scenarios: try: if scenario == "normal_operation": # Test normal cache operations for i in range(5): - hash_key = self.cache_manager.generate_hash(f"validation_{scenario}_{i}") + hash_key = self.cache_manager.generate_hash( + f"validation_{scenario}_{i}" + ) content = f"Validation content for {scenario} {i}. " * 20 await self.resilient_cache.store(hash_key, content) retrieved = await self.resilient_cache.get(hash_key) assert retrieved is not None - + validation_results[scenario] = "PASS" - + elif scenario == "circuit_breaker_open": # Force circuit open and test graceful degradation self.circuit_breaker.force_open() - hash_key = self.cache_manager.generate_hash(f"validation_{scenario}") + hash_key = self.cache_manager.generate_hash( + f"validation_{scenario}" + ) content = f"Validation content for {scenario}. " * 20 - + # Should handle gracefully result = await self.resilient_cache.store(hash_key, content) validation_results[scenario] = "PASS" - + elif scenario == "circuit_breaker_recovery": # Test recovery mechanism self.circuit_breaker.force_closed() - hash_key = self.cache_manager.generate_hash(f"validation_{scenario}") + hash_key = self.cache_manager.generate_hash( + f"validation_{scenario}" + ) content = f"Validation content for {scenario}. " * 20 - + result = await self.resilient_cache.store(hash_key, content) assert result == True validation_results[scenario] = "PASS" - + elif scenario == "high_concurrency": # Test concurrent operations concurrent_ops = [] for i in range(20): - hash_key = self.cache_manager.generate_hash(f"validation_{scenario}_{i}") + hash_key = self.cache_manager.generate_hash( + f"validation_{scenario}_{i}" + ) content = f"Concurrent validation content {i}. " * 15 - concurrent_ops.append(self.resilient_cache.store(hash_key, content)) - - results = await asyncio.gather(*concurrent_ops, return_exceptions=True) + concurrent_ops.append( + self.resilient_cache.store(hash_key, content) + ) + + results = await asyncio.gather( + *concurrent_ops, return_exceptions=True + ) successful = sum(1 for r in results if not isinstance(r, Exception)) - + # At least 80% should succeed assert successful >= len(results) * 0.8 validation_results[scenario] = "PASS" - + elif scenario == "database_integration": # Test with real database conn = sqlite3.connect(self.db_path) @@ -657,78 +731,87 @@ async def test_comprehensive_system_validation(self): cursor.execute("SELECT COUNT(*) FROM financial_data") count = cursor.fetchone()[0] conn.close() - + # Cache the result hash_key = self.cache_manager.generate_hash("db_count_query") - await self.resilient_cache.store(hash_key, f"Database count: {count}") - + await self.resilient_cache.store( + hash_key, f"Database count: {count}" + ) + # Retrieve from cache cached = await self.resilient_cache.get(hash_key) assert cached is not None validation_results[scenario] = "PASS" - + self.test_metrics["successful_operations"] += 1 - + except Exception as e: - logger.error("Validation scenario failed", - scenario=scenario, error=str(e)) + logger.error( + "Validation scenario failed", scenario=scenario, error=str(e) + ) validation_results[scenario] = f"FAIL: {str(e)}" self.test_metrics["failed_operations"] += 1 - + self.test_metrics["operations_executed"] += 1 - + # Final system health check final_metrics = self.resilient_cache.get_metrics() final_circuit_state = self.circuit_breaker.get_state() - - logger.info("Comprehensive system validation completed", - validation_results=validation_results, - final_metrics=final_metrics, - final_circuit_state=final_circuit_state.value, - test_metrics=self.test_metrics) - + + logger.info( + "Comprehensive system validation completed", + validation_results=validation_results, + final_metrics=final_metrics, + final_circuit_state=final_circuit_state.value, + test_metrics=self.test_metrics, + ) + # Ensure all critical scenarios passed critical_scenarios = ["normal_operation", "database_integration"] for scenario in critical_scenarios: - assert validation_results.get(scenario) == "PASS", f"Critical scenario {scenario} failed" - + assert ( + validation_results.get(scenario) == "PASS" + ), f"Critical scenario {scenario} failed" + async def test_run_complete_test_suite(self): """Run the complete test suite in sequence.""" logger.info("Starting complete cache resilience test suite") - + try: # Initialize environment await self.test_cache_initialization_real_components() - + # Test normal operations await self.test_normal_cache_operations_real_storage() - + # Test failure scenarios await self.test_circuit_breaker_failure_scenarios() - + # Test recovery await self.test_circuit_breaker_recovery_mechanisms() - + # Test performance await self.test_performance_under_various_conditions() - + # Test metrics await self.test_metrics_collection_and_validation() - + # Test database integration await self.test_real_database_integration() - + # Test stress and recovery await self.test_stress_and_failure_recovery() - + # Final validation await self.test_comprehensive_system_validation() - - logger.info("Complete cache resilience test suite completed successfully", - final_test_metrics=self.test_metrics) - + + logger.info( + "Complete cache resilience test suite completed successfully", + final_test_metrics=self.test_metrics, + ) + return self.test_metrics - + except Exception as e: logger.error("Test suite failed", error=str(e)) raise @@ -736,11 +819,12 @@ async def test_run_complete_test_suite(self): # Standalone test runner for direct execution if __name__ == "__main__": + async def run_tests(): """Run tests directly.""" test_instance = TestCacheResilienceE2E() test_instance.setup_test_environment() - + try: results = await test_instance.test_run_complete_test_suite() print(f"\nTest Results Summary:") @@ -748,15 +832,19 @@ async def run_tests(): print(f"Successful Operations: {results['successful_operations']}") print(f"Failed Operations: {results['failed_operations']}") print(f"Circuit State Changes: {results['circuit_state_changes']}") - print(f"Performance Measurements: {len(results['performance_measurements'])}") - - if results['performance_measurements']: - avg_perf = sum(results['performance_measurements']) / len(results['performance_measurements']) + print( + f"Performance Measurements: {len(results['performance_measurements'])}" + ) + + if results["performance_measurements"]: + avg_perf = sum(results["performance_measurements"]) / len( + results["performance_measurements"] + ) print(f"Average Performance: {avg_perf:.2f}ms") - + except Exception as e: print(f"Test execution failed: {e}") raise - + # Run the tests - asyncio.run(run_tests()) \ No newline at end of file + asyncio.run(run_tests()) diff --git a/tests/integration/test_cache_resilience_e2e_sync.py b/tests/integration/test_cache_resilience_e2e_sync.py index 856b7a2..d9c385a 100644 --- a/tests/integration/test_cache_resilience_e2e_sync.py +++ b/tests/integration/test_cache_resilience_e2e_sync.py @@ -33,7 +33,7 @@ structlog.processors.StackInfoRenderer(), structlog.processors.format_exc_info, structlog.processors.UnicodeDecoder(), - structlog.processors.JSONRenderer() + structlog.processors.JSONRenderer(), ], context_class=dict, logger_factory=structlog.stdlib.LoggerFactory(), @@ -46,48 +46,48 @@ class TestCacheResilienceE2ESync(unittest.TestCase): """Comprehensive End-to-End Cache Resilience Test Suite with Real Components""" - + def setUp(self): """Set up test environment with real components""" - self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") self.temp_db.close() self.db_path = self.temp_db.name - + # Initialize test database with sample data self._setup_test_database() - + # Configure cache with real settings self.cache_config = { - 'prefix': 'fact_resilience_test', - 'min_tokens': 100, - 'max_size': '5MB', - 'ttl_seconds': 30, - 'hit_target_ms': 50, - 'miss_target_ms': 140 + "prefix": "fact_resilience_test", + "min_tokens": 100, + "max_size": "5MB", + "ttl_seconds": 30, + "hit_target_ms": 50, + "miss_target_ms": 140, } - + # Initialize real cache manager self.cache_manager = CacheManager(self.cache_config) - + # Initialize circuit breaker with test-friendly settings circuit_config = CircuitBreakerConfig( - failure_threshold=3, - timeout_seconds=5.0, - success_threshold=3 + failure_threshold=3, timeout_seconds=5.0, success_threshold=3 ) - + # Initialize circuit breaker and resilient cache wrapper from cache.resilience import CacheCircuitBreaker + circuit_breaker = CacheCircuitBreaker(circuit_config) self.resilient_cache = ResilientCacheWrapper( - cache_manager=self.cache_manager, - circuit_breaker=circuit_breaker + cache_manager=self.cache_manager, circuit_breaker=circuit_breaker + ) + + logger.info( + "Test environment initialized", + cache_config=self.cache_config, + db_path=self.db_path, ) - - logger.info("Test environment initialized", - cache_config=self.cache_config, - db_path=self.db_path) - + def _run_async(self, coro): """Helper to run async operations in sync tests""" loop = asyncio.new_event_loop() @@ -96,12 +96,12 @@ def _run_async(self, coro): return loop.run_until_complete(coro) finally: loop.close() - + def tearDown(self): """Clean up test environment""" try: # Clean up cache manually since CacheManager doesn't have clear() - if hasattr(self, 'cache_manager'): + if hasattr(self, "cache_manager"): try: with self.cache_manager._lock: self.cache_manager.cache.clear() @@ -112,14 +112,14 @@ def tearDown(self): logger.info("Test environment cleaned up") except Exception as e: logger.warning("Cleanup warning", error=str(e)) - + def _setup_test_database(self): """Initialize test database with sample financial data""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - + # Create tables - cursor.execute(''' + cursor.execute(""" CREATE TABLE companies ( id INTEGER PRIMARY KEY, symbol TEXT UNIQUE, @@ -127,9 +127,9 @@ def _setup_test_database(self): sector TEXT, market_cap REAL ) - ''') - - cursor.execute(''' + """) + + cursor.execute(""" CREATE TABLE financial_data ( id INTEGER PRIMARY KEY, company_id INTEGER, @@ -138,194 +138,222 @@ def _setup_test_database(self): net_income REAL, FOREIGN KEY (company_id) REFERENCES companies (id) ) - ''') - + """) + # Insert sample data companies = [ - ('AAPL', 'Apple Inc.', 'Technology', 2800000000000), - ('MSFT', 'Microsoft Corporation', 'Technology', 2400000000000), - ('GOOGL', 'Alphabet Inc.', 'Technology', 1600000000000), - ('AMZN', 'Amazon.com Inc.', 'Consumer Discretionary', 1500000000000), - ('TSLA', 'Tesla Inc.', 'Consumer Discretionary', 800000000000) + ("AAPL", "Apple Inc.", "Technology", 2800000000000), + ("MSFT", "Microsoft Corporation", "Technology", 2400000000000), + ("GOOGL", "Alphabet Inc.", "Technology", 1600000000000), + ("AMZN", "Amazon.com Inc.", "Consumer Discretionary", 1500000000000), + ("TSLA", "Tesla Inc.", "Consumer Discretionary", 800000000000), ] - + cursor.executemany( - 'INSERT INTO companies (symbol, name, sector, market_cap) VALUES (?, ?, ?, ?)', - companies + "INSERT INTO companies (symbol, name, sector, market_cap) VALUES (?, ?, ?, ?)", + companies, ) - + # Insert financial data financial_data = [ - (1, '2024-Q1', 90750000000, 23636000000), - (1, '2024-Q2', 94930000000, 25010000000), - (2, '2024-Q1', 61858000000, 21939000000), - (2, '2024-Q2', 64728000000, 22036000000), - (3, '2024-Q1', 80539000000, 15051000000), - (3, '2024-Q2', 84742000000, 16130000000) + (1, "2024-Q1", 90750000000, 23636000000), + (1, "2024-Q2", 94930000000, 25010000000), + (2, "2024-Q1", 61858000000, 21939000000), + (2, "2024-Q2", 64728000000, 22036000000), + (3, "2024-Q1", 80539000000, 15051000000), + (3, "2024-Q2", 84742000000, 16130000000), ] - + cursor.executemany( - 'INSERT INTO financial_data (company_id, date, revenue, net_income) VALUES (?, ?, ?, ?)', - financial_data + "INSERT INTO financial_data (company_id, date, revenue, net_income) VALUES (?, ?, ?, ?)", + financial_data, ) - + conn.commit() conn.close() logger.info("Test database initialized with financial data") - + def _get_financial_data(self, symbol: str) -> str: """Retrieve financial data from real database""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - - query = ''' + + query = """ SELECT c.name, c.sector, c.market_cap, f.date, f.revenue, f.net_income FROM companies c LEFT JOIN financial_data f ON c.id = f.company_id WHERE c.symbol = ? ORDER BY f.date DESC - ''' - + """ + cursor.execute(query, (symbol,)) results = cursor.fetchall() conn.close() - + if not results: return f"No data found for {symbol}" - + # Format as comprehensive financial report company_name = results[0][0] sector = results[0][1] market_cap = results[0][2] - + report = f"Financial Analysis Report for {company_name} ({symbol})\n" report += f"Sector: {sector}\n" report += f"Market Cap: ${market_cap:,.0f}\n" report += "\nQuarterly Financial Performance:\n" - + for row in results: if row[3]: # Has financial data - report += f" {row[3]}: Revenue ${row[4]:,.0f}, Net Income ${row[5]:,.0f}\n" - + report += ( + f" {row[3]}: Revenue ${row[4]:,.0f}, Net Income ${row[5]:,.0f}\n" + ) + report += "\nThis comprehensive financial analysis provides detailed insights into the company's performance, " report += "market position, revenue trends, profitability metrics, and sector comparison data. " report += "The analysis includes quarterly revenue figures, net income calculations, market capitalization data, " report += "and comprehensive sector analysis to provide investors with actionable intelligence for investment decisions." - + return report - + def test_cache_initialization_real_components(self): """Test cache initialization with real storage components""" # Verify cache manager is properly initialized self.assertIsNotNone(self.cache_manager) - self.assertEqual(self.cache_manager.prefix, 'fact_resilience_test') + self.assertEqual(self.cache_manager.prefix, "fact_resilience_test") self.assertEqual(self.cache_manager.min_tokens, 100) - + # Verify resilient wrapper is properly initialized self.assertIsNotNone(self.resilient_cache) self.assertIsNotNone(self.resilient_cache.circuit_breaker) self.assertTrue(self.resilient_cache.enable_graceful_degradation) - + # Test basic cache functionality - test_content = "This is test content for cache storage with sufficient tokens to meet the minimum requirement. " * 15 # Ensure minimum tokens - + test_content = ( + "This is test content for cache storage with sufficient tokens to meet the minimum requirement. " + * 15 + ) # Ensure minimum tokens + # Store content result = self._run_async(self.resilient_cache.store("test_key", test_content)) self.assertTrue(result) - + # Retrieve content retrieved = self._run_async(self.resilient_cache.get("test_key")) self.assertIsNotNone(retrieved) self.assertEqual(retrieved.content, test_content) - + logger.info("Cache initialization test completed successfully") - + def test_normal_cache_operations_real_storage(self): """Test normal cache operations with real storage backend""" # Test cache miss scenario missing_content = self._run_async(self.resilient_cache.get("nonexistent_key")) self.assertIsNone(missing_content) - + # Test cache store and hit scenario for i in range(5): - symbol = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA'][i] + symbol = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"][i] financial_data = self._get_financial_data(symbol) - + # Store in cache cache_key = f"financial_data_{symbol}" - store_result = self._run_async(self.resilient_cache.store(cache_key, financial_data)) + store_result = self._run_async( + self.resilient_cache.store(cache_key, financial_data) + ) self.assertTrue(store_result, f"Failed to store data for {symbol}") - + # Retrieve from cache (should be a hit) cached_data = self._run_async(self.resilient_cache.get(cache_key)) - self.assertEqual(cached_data.content, financial_data, f"Cache hit failed for {symbol}") - + self.assertEqual( + cached_data.content, financial_data, f"Cache hit failed for {symbol}" + ) + # Test cache metrics metrics = self.resilient_cache.get_metrics() self.assertIsInstance(metrics, dict) - self.assertIn('cache', metrics) - self.assertIn('circuit_breaker', metrics) - + self.assertIn("cache", metrics) + self.assertIn("circuit_breaker", metrics) + logger.info("Normal cache operations test completed", metrics=metrics) - + def test_circuit_breaker_failure_scenarios(self): """Test circuit breaker with forced failure scenarios""" # Simulate cache failures by patching the underlying cache original_get = self.cache_manager.get original_store = self.cache_manager.store - + def failing_get(key): raise Exception("Cache storage failure") - + def failing_store(key, content): raise Exception("Cache storage failure") - + # Force failures to trigger circuit breaker - with patch.object(self.cache_manager, 'get', side_effect=failing_get): - with patch.object(self.cache_manager, 'store', side_effect=failing_store): - + with patch.object(self.cache_manager, "get", side_effect=failing_get): + with patch.object(self.cache_manager, "store", side_effect=failing_store): + # Trigger multiple failures to open circuit breaker for i in range(5): try: - result = self._run_async(self.resilient_cache.get(f"test_key_{i}")) + result = self._run_async( + self.resilient_cache.get(f"test_key_{i}") + ) # First few calls may still raise exceptions before circuit opens - if self.resilient_cache.circuit_breaker.state == CircuitState.OPEN: + if ( + self.resilient_cache.circuit_breaker.state + == CircuitState.OPEN + ): self.assertIsNone(result) except Exception as e: # Expected until circuit breaker opens - logger.debug(f"Expected failure during circuit breaker activation: {e}") - + logger.debug( + f"Expected failure during circuit breaker activation: {e}" + ) + try: - store_result = self._run_async(self.resilient_cache.store(f"test_key_{i}", f"test_content_{i}")) + store_result = self._run_async( + self.resilient_cache.store( + f"test_key_{i}", f"test_content_{i}" + ) + ) # Should return False when circuit is open - if self.resilient_cache.circuit_breaker.state == CircuitState.OPEN: + if ( + self.resilient_cache.circuit_breaker.state + == CircuitState.OPEN + ): self.assertFalse(store_result) except Exception as e: # Expected until circuit breaker opens - logger.debug(f"Expected failure during circuit breaker activation: {e}") - + logger.debug( + f"Expected failure during circuit breaker activation: {e}" + ) + # Check circuit breaker state - should be open after multiple failures from cache.resilience import CircuitState - self.assertEqual(self.resilient_cache.circuit_breaker.state, CircuitState.OPEN) - + + self.assertEqual( + self.resilient_cache.circuit_breaker.state, CircuitState.OPEN + ) + # Test circuit breaker blocks requests when open - with patch.object(self.cache_manager, 'get', side_effect=original_get): + with patch.object(self.cache_manager, "get", side_effect=original_get): # Should still be blocked by circuit breaker and use graceful degradation result = self._run_async(self.resilient_cache.get("test_key")) self.assertIsNone(result) # Should return None due to graceful degradation - + logger.info("Circuit breaker failure scenario test completed") - + def test_circuit_breaker_recovery_mechanisms(self): """Test circuit breaker recovery after timeout period""" # Force circuit breaker to open original_get = self.cache_manager.get - + def failing_get(key): raise Exception("Cache failure") - + # Trigger failures - with patch.object(self.cache_manager, 'get', side_effect=failing_get): + with patch.object(self.cache_manager, "get", side_effect=failing_get): for i in range(5): try: self._run_async(self.resilient_cache.get(f"test_key_{i}")) @@ -339,31 +367,34 @@ def failing_get(key): # Test recovery after timeout # Short timeout for testing - just wait a bit and restore functionality time.sleep(0.1) - + # Reset to normal functionality - with patch.object(self.cache_manager, 'get', return_value=None): + with patch.object(self.cache_manager, "get", return_value=None): # This should work again result = self._run_async(self.resilient_cache.get("recovery_test")) # The operation should complete even if it returns None - + logger.info("Circuit breaker recovery test completed") - + # Verify circuit breaker is open self.assertTrue(self.resilient_cache.circuit_breaker.is_open()) - + # Wait for timeout (use shorter timeout for testing) self.resilient_cache.circuit_breaker.config.timeout_seconds = 0.5 time.sleep(0.8) # Wait a bit longer than timeout to ensure transition - + # Try an operation to trigger transition to half-open # Use original working get method - with patch.object(self.cache_manager, 'get', return_value=None): + with patch.object(self.cache_manager, "get", return_value=None): try: result = self._run_async(self.resilient_cache.get("transition_test")) # Circuit breaker should now be half-open or closed after timeout state = self.resilient_cache.circuit_breaker.get_state() - self.assertIn(state, [CircuitState.HALF_OPEN, CircuitState.CLOSED], - f"Expected HALF_OPEN or CLOSED, got {state}") + self.assertIn( + state, + [CircuitState.HALF_OPEN, CircuitState.CLOSED], + f"Expected HALF_OPEN or CLOSED, got {state}", + ) except Exception as e: # If still throwing errors, check if it's due to circuit breaker state if "CIRCUIT_BREAKER" in str(e): @@ -372,7 +403,7 @@ def failing_get(key): else: # Unexpected error, re-raise raise - + # Ensure circuit breaker is in HALF_OPEN state before recovery state = self.resilient_cache.circuit_breaker.get_state() if state == CircuitState.OPEN: @@ -383,172 +414,208 @@ def failing_get(key): self._run_async(self.resilient_cache.get("force_transition")) except: pass - + # Test successful operations to close circuit breaker - test_content = "Recovery test content with sufficient tokens to meet the minimum requirement for cache storage. " * 12 - + test_content = ( + "Recovery test content with sufficient tokens to meet the minimum requirement for cache storage. " + * 12 + ) + # Perform multiple successful operations (circuit breaker needs 3 successes to close) success_count = 0 stored_queries = [] max_attempts = 10 # Increase attempts to ensure we get enough successes - + for i in range(max_attempts): try: query = f"recovery_test_{i}" # Generate the hash that will be used as the storage key query_hash = self.resilient_cache.generate_hash(query) # Store using the generated hash as the key - store_result = self._run_async(self.resilient_cache.store(query_hash, test_content)) - + store_result = self._run_async( + self.resilient_cache.store(query_hash, test_content) + ) + # Verify actual storage by checking if item exists in cache # This is important because graceful degradation returns True even when not stored actually_stored = self.cache_manager.get(query_hash) is not None - + if store_result and actually_stored: success_count += 1 stored_queries.append((query, query_hash)) - logger.info(f"Successful store operation {success_count} (actually stored)") + logger.info( + f"Successful store operation {success_count} (actually stored)" + ) elif store_result and not actually_stored: - logger.info(f"Store returned True but item not actually stored (graceful degradation)") + logger.info( + f"Store returned True but item not actually stored (graceful degradation)" + ) else: - logger.info(f"Store operation failed: store_result={store_result}, actually_stored={actually_stored}") - + logger.info( + f"Store operation failed: store_result={store_result}, actually_stored={actually_stored}" + ) + # Add small delay between operations to ensure proper state transitions time.sleep(0.02) - + # Only break if we have enough successes AND circuit breaker is closed if success_count >= 3: state = self.resilient_cache.circuit_breaker.get_state() - logger.info(f"After {success_count} successes, circuit breaker state: {state}") + logger.info( + f"After {success_count} successes, circuit breaker state: {state}" + ) if self.resilient_cache.circuit_breaker.is_closed(): - logger.info(f"Circuit breaker closed after {success_count} successful operations") + logger.info( + f"Circuit breaker closed after {success_count} successful operations" + ) break - + except Exception as e: logger.warning(f"Store operation failed during recovery: {e}") continue - + # Verify we got at least 3 successful operations - logger.info(f"Recovery test completed with {success_count} successful operations out of {max_attempts} attempts") - self.assertGreaterEqual(success_count, 3, f"Expected at least 3 successful operations, got {success_count}") - + logger.info( + f"Recovery test completed with {success_count} successful operations out of {max_attempts} attempts" + ) + self.assertGreaterEqual( + success_count, + 3, + f"Expected at least 3 successful operations, got {success_count}", + ) + # Circuit breaker should now be closed after 3 successful operations state = self.resilient_cache.circuit_breaker.get_state() - self.assertTrue(self.resilient_cache.circuit_breaker.is_closed(), - f"Expected circuit breaker to be CLOSED after {success_count} successes, but state is {state}") - + self.assertTrue( + self.resilient_cache.circuit_breaker.is_closed(), + f"Expected circuit breaker to be CLOSED after {success_count} successes, but state is {state}", + ) + # Add a small delay before retrieval to ensure storage is complete time.sleep(0.1) - + # Verify normal operation is restored by retrieving one of the stored items if stored_queries: query, query_hash = stored_queries[0] # Retrieve using the same hash that was used for storage retrieved = self._run_async(self.resilient_cache.get(query_hash)) - self.assertIsNotNone(retrieved, f"Failed to retrieve stored item after circuit breaker recovery. Query: {query}, Hash: {query_hash}") + self.assertIsNotNone( + retrieved, + f"Failed to retrieve stored item after circuit breaker recovery. Query: {query}, Hash: {query_hash}", + ) # Verify the content matches what we stored - self.assertEqual(retrieved.content if hasattr(retrieved, 'content') else retrieved, test_content) + self.assertEqual( + retrieved.content if hasattr(retrieved, "content") else retrieved, + test_content, + ) self.assertIn("Recovery test content", retrieved.content) - + logger.info("Circuit breaker recovery test completed") - + def test_performance_under_various_conditions(self): """Test performance metrics under various load conditions""" performance_results = {} - + # Test normal load start_time = time.time() for i in range(10): - symbol = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA'][i % 5] + symbol = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"][i % 5] financial_data = self._get_financial_data(symbol) - self._run_async(self.resilient_cache.store(f"perf_test_{i}", financial_data)) + self._run_async( + self.resilient_cache.store(f"perf_test_{i}", financial_data) + ) normal_load_time = time.time() - start_time - performance_results['normal_load'] = normal_load_time - + performance_results["normal_load"] = normal_load_time + # Test concurrent access def concurrent_worker(worker_id): for i in range(5): key = f"concurrent_{worker_id}_{i}" - content = self._get_financial_data('AAPL') + content = self._get_financial_data("AAPL") self._run_async(self.resilient_cache.store(key, content)) retrieved = self._run_async(self.resilient_cache.get(key)) self.assertIsNotNone(retrieved) - + start_time = time.time() threads = [] for i in range(3): thread = threading.Thread(target=concurrent_worker, args=(i,)) threads.append(thread) thread.start() - + for thread in threads: thread.join() concurrent_time = time.time() - start_time - performance_results['concurrent_load'] = concurrent_time - + performance_results["concurrent_load"] = concurrent_time + # Get final metrics metrics = self.resilient_cache.get_metrics() - performance_results['final_metrics'] = metrics - + performance_results["final_metrics"] = metrics + # Verify performance is reasonable self.assertLess(normal_load_time, 5.0, "Normal load performance too slow") self.assertLess(concurrent_time, 10.0, "Concurrent load performance too slow") - + logger.info("Performance test completed", results=performance_results) - + def test_metrics_collection_and_validation(self): """Test comprehensive metrics collection and validation""" # Reset metrics initial_metrics = self.resilient_cache.get_metrics() - + # Perform various operations test_operations = [ - ('store_test_1', 'Test content 1'), - ('store_test_2', 'Test content 2'), - ('get_test_1', None), # Hit - ('get_nonexistent', None), # Miss + ("store_test_1", "Test content 1"), + ("store_test_2", "Test content 2"), + ("get_test_1", None), # Hit + ("get_nonexistent", None), # Miss ] - + for operation, content in test_operations: - if operation.startswith('store_'): - financial_data = self._get_financial_data('AAPL') + if operation.startswith("store_"): + financial_data = self._get_financial_data("AAPL") self._run_async(self.resilient_cache.store(operation, financial_data)) - elif operation.startswith('get_'): - if operation == 'get_test_1': + elif operation.startswith("get_"): + if operation == "get_test_1": # This should be a hit - result = self._run_async(self.resilient_cache.get('store_test_1')) + result = self._run_async(self.resilient_cache.get("store_test_1")) self.assertIsNotNone(result) else: # This should be a miss result = self._run_async(self.resilient_cache.get(operation)) self.assertIsNone(result) - + # Get final metrics final_metrics = self.resilient_cache.get_metrics() - + # Validate metric structure - self.assertIn('cache', final_metrics) - self.assertIn('circuit_breaker', final_metrics) - - cache_metrics = final_metrics['cache'] - cb_metrics = final_metrics['circuit_breaker'] - - required_cache_metrics = ['cache_hits', 'cache_misses', 'total_requests'] + self.assertIn("cache", final_metrics) + self.assertIn("circuit_breaker", final_metrics) + + cache_metrics = final_metrics["cache"] + cb_metrics = final_metrics["circuit_breaker"] + + required_cache_metrics = ["cache_hits", "cache_misses", "total_requests"] for metric in required_cache_metrics: self.assertIn(metric, cache_metrics) - - required_cb_metrics = ['state', 'failure_count', 'success_count', 'total_operations'] + + required_cb_metrics = [ + "state", + "failure_count", + "success_count", + "total_operations", + ] for metric in required_cb_metrics: self.assertIn(metric, cb_metrics) - + # Validate metric values - handle nested structure - cache_metrics = final_metrics['cache'] - self.assertGreaterEqual(cache_metrics['cache_hits'], 1) - self.assertGreaterEqual(cache_metrics['cache_misses'], 1) - self.assertGreater(cache_metrics['total_requests'], 0) - + cache_metrics = final_metrics["cache"] + self.assertGreaterEqual(cache_metrics["cache_hits"], 1) + self.assertGreaterEqual(cache_metrics["cache_misses"], 1) + self.assertGreater(cache_metrics["total_requests"], 0) + logger.info("Metrics validation test completed", metrics=final_metrics) - + def test_real_database_integration(self): """Test cache integration with real database operations""" # Test database connectivity @@ -558,21 +625,21 @@ def test_real_database_integration(self): company_count = cursor.fetchone()[0] self.assertGreater(company_count, 0) conn.close() - + # Test cache-database integration - symbols = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA'] - + symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] + for symbol in symbols: cache_key = f"db_integration_{symbol}" - + # First access (cache miss, database hit) start_time = time.time() financial_data = self._get_financial_data(symbol) db_time = time.time() - start_time - + # Store in cache self._run_async(self.resilient_cache.store(cache_key, financial_data)) - + # Second access (cache hit) start_time = time.time() cached_data = self._run_async(self.resilient_cache.get(cache_key)) @@ -581,107 +648,129 @@ def test_real_database_integration(self): # Verify data integrity self.assertIsNotNone(cached_data) self.assertEqual(cached_data.content, financial_data) - + # Cache should be faster than database, but allow for timing variance # If cache isn't significantly faster, log it but don't fail the test if cache_time >= db_time: - logger.warning(f"Cache timing variance detected for {symbol}: " - f"cache={cache_time*1000:.2f}ms, db={db_time*1000:.2f}ms") + logger.warning( + f"Cache timing variance detected for {symbol}: " + f"cache={cache_time*1000:.2f}ms, db={db_time*1000:.2f}ms" + ) # Allow up to 50% slower due to security scanning overhead tolerance = db_time * 1.5 - self.assertLess(cache_time, tolerance, - f"Cache time {cache_time*1000:.2f}ms exceeds tolerance " - f"of {tolerance*1000:.2f}ms for {symbol}") + self.assertLess( + cache_time, + tolerance, + f"Cache time {cache_time*1000:.2f}ms exceeds tolerance " + f"of {tolerance*1000:.2f}ms for {symbol}", + ) else: # Cache is faster as expected self.assertLess(cache_time, db_time) - + logger.info("Database integration test completed") - + def test_stress_and_failure_recovery(self): """Test system behavior under stress and failure conditions""" stress_results = {} - + # Stress test with high volume stress_operations = 50 success_count = 0 error_count = 0 - + start_time = time.time() for i in range(stress_operations): try: - symbol = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA'][i % 5] + symbol = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"][i % 5] financial_data = self._get_financial_data(symbol) - - if self._run_async(self.resilient_cache.store(f"stress_test_{i}", financial_data)): + + if self._run_async( + self.resilient_cache.store(f"stress_test_{i}", financial_data) + ): success_count += 1 - + # Verify retrieval - retrieved = self._run_async(self.resilient_cache.get(f"stress_test_{i}")) + retrieved = self._run_async( + self.resilient_cache.get(f"stress_test_{i}") + ) if retrieved and retrieved.content == financial_data: success_count += 1 else: error_count += 1 else: error_count += 1 - + except Exception as e: error_count += 1 logger.warning("Stress test error", error=str(e), iteration=i) - + stress_time = time.time() - start_time - stress_results['total_time'] = stress_time - stress_results['success_count'] = success_count - stress_results['error_count'] = error_count - stress_results['success_rate'] = success_count / (success_count + error_count) if (success_count + error_count) > 0 else 0 - + stress_results["total_time"] = stress_time + stress_results["success_count"] = success_count + stress_results["error_count"] = error_count + stress_results["success_rate"] = ( + success_count / (success_count + error_count) + if (success_count + error_count) > 0 + else 0 + ) + # Verify reasonable success rate - self.assertGreater(stress_results['success_rate'], 0.8, "Success rate too low under stress") - + self.assertGreater( + stress_results["success_rate"], 0.8, "Success rate too low under stress" + ) + logger.info("Stress test completed", results=stress_results) - + def test_comprehensive_system_validation(self): """Run comprehensive system validation with all components""" validation_results = { - 'initialization': False, - 'basic_operations': False, - 'circuit_breaker': False, - 'performance': False, - 'database_integration': False, - 'metrics_collection': False, - 'error_handling': False + "initialization": False, + "basic_operations": False, + "circuit_breaker": False, + "performance": False, + "database_integration": False, + "metrics_collection": False, + "error_handling": False, } - + try: # Test initialization self.assertIsNotNone(self.cache_manager) self.assertIsNotNone(self.resilient_cache) - validation_results['initialization'] = True - + validation_results["initialization"] = True + # Test basic operations - test_content = "Comprehensive validation test content with sufficient tokens for cache storage. " * 20 - store_result = self._run_async(self.resilient_cache.store("validation_test", test_content)) + test_content = ( + "Comprehensive validation test content with sufficient tokens for cache storage. " + * 20 + ) + store_result = self._run_async( + self.resilient_cache.store("validation_test", test_content) + ) self.assertTrue(store_result) - + retrieved = self._run_async(self.resilient_cache.get("validation_test")) self.assertIsNotNone(retrieved) self.assertEqual(retrieved.content, test_content) - validation_results['basic_operations'] = True - + validation_results["basic_operations"] = True + # Test circuit breaker functionality self.assertIsNotNone(self.resilient_cache.circuit_breaker) self.assertTrue(self.resilient_cache.circuit_breaker.is_closed()) - validation_results['circuit_breaker'] = True - + validation_results["circuit_breaker"] = True + # Test performance start_time = time.time() for i in range(5): - financial_data = self._get_financial_data('AAPL') - self._run_async(self.resilient_cache.store(f"perf_validation_{i}", financial_data)) + financial_data = self._get_financial_data("AAPL") + self._run_async( + self.resilient_cache.store(f"perf_validation_{i}", financial_data) + ) perf_time = time.time() - start_time self.assertLess(perf_time, 2.0) - validation_results['performance'] = True - + validation_results["performance"] = True + # Test database integration conn = sqlite3.connect(self.db_path) cursor = conn.cursor() @@ -689,35 +778,39 @@ def test_comprehensive_system_validation(self): result = cursor.fetchone() self.assertIsNotNone(result) conn.close() - validation_results['database_integration'] = True - + validation_results["database_integration"] = True + # Test metrics collection metrics = self.resilient_cache.get_metrics() self.assertIsInstance(metrics, dict) - self.assertIn('cache', metrics) - self.assertIn('circuit_breaker', metrics) + self.assertIn("cache", metrics) + self.assertIn("circuit_breaker", metrics) # Check nested structure - cache_metrics = metrics['cache'] - self.assertIn('cache_hits', cache_metrics) - self.assertIn('cache_misses', cache_metrics) - validation_results['metrics_collection'] = True - + cache_metrics = metrics["cache"] + self.assertIn("cache_hits", cache_metrics) + self.assertIn("cache_misses", cache_metrics) + validation_results["metrics_collection"] = True + # Test error handling - with patch.object(self.cache_manager, 'get', side_effect=Exception("Test error")): + with patch.object( + self.cache_manager, "get", side_effect=Exception("Test error") + ): result = self._run_async(self.resilient_cache.get("error_test")) # Should handle gracefully self.assertIsNone(result) - validation_results['error_handling'] = True - + validation_results["error_handling"] = True + except Exception as e: logger.error("Validation error", error=str(e)) raise - + # Verify all components passed for component, status in validation_results.items(): self.assertTrue(status, f"Component {component} failed validation") - - logger.info("Comprehensive system validation completed", results=validation_results) + + logger.info( + "Comprehensive system validation completed", results=validation_results + ) def run_test_suite(): @@ -725,23 +818,23 @@ def run_test_suite(): print("=" * 80) print("FACT Cache Resilience End-to-End Test Suite") print("=" * 80) - + # Create test suite test_suite = unittest.TestLoader().loadTestsFromTestCase(TestCacheResilienceE2ESync) - + # Run tests with detailed output runner = unittest.TextTestRunner(verbosity=2, buffer=True) result = runner.run(test_suite) - + print("\n" + "=" * 80) print(f"Test Results: {result.testsRun} tests run") print(f"Failures: {len(result.failures)}") print(f"Errors: {len(result.errors)}") print(f"Skipped: {len(result.skipped)}") print("=" * 80) - + return result -if __name__ == '__main__': - run_test_suite() \ No newline at end of file +if __name__ == "__main__": + run_test_suite() diff --git a/tests/integration/test_complete_system.py b/tests/integration/test_complete_system.py index 9a30a85..c91d77f 100644 --- a/tests/integration/test_complete_system.py +++ b/tests/integration/test_complete_system.py @@ -3,7 +3,7 @@ This test suite validates the complete integration of all FACT system components: - Environment configuration -- Database connectivity +- Database connectivity - API integration - Cache system - Security layer @@ -21,6 +21,7 @@ # Import system components import sys + sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) from core.config import Config, ConfigurationError @@ -34,7 +35,7 @@ class TestCompleteSystemIntegration: """Test complete system integration with all components.""" - + @pytest.fixture def valid_test_env(self): """Provide valid test environment variables.""" @@ -43,9 +44,9 @@ def valid_test_env(self): "ARCADE_API_KEY": "ak-test-key-for-integration", "DATABASE_PATH": "test_data/integration_test.db", "CACHE_MAX_SIZE": "1000", - "LOG_LEVEL": "DEBUG" + "LOG_LEVEL": "DEBUG", } - + @pytest.fixture async def test_database(self, valid_test_env): """Create a test database for integration testing.""" @@ -53,164 +54,175 @@ async def test_database(self, valid_test_env): config = Config() db_path = Path(config.database_path) db_path.parent.mkdir(parents=True, exist_ok=True) - + db_manager = DatabaseManager(config.database_path) await db_manager.initialize_database() - + yield db_manager - + # Cleanup if db_path.exists(): db_path.unlink() - + def test_environment_configuration_integration(self, valid_test_env): """Test that environment configuration integrates properly with all components.""" with patch.dict(os.environ, valid_test_env, clear=True): # Should successfully create config config = Config() - + # Verify all required keys are present assert config.anthropic_api_key == "sk-ant-test-key-for-integration" assert config.arcade_api_key == "ak-test-key-for-integration" assert config.database_path == "test_data/integration_test.db" - + # Verify configuration exports safely config_dict = config.to_dict() assert config_dict["anthropic_api_key"] == "***" assert config_dict["arcade_api_key"] == "***" - + @pytest.mark.asyncio - async def test_database_integration_with_config(self, test_database, valid_test_env): + async def test_database_integration_with_config( + self, test_database, valid_test_env + ): """Test database integration with configuration system.""" with patch.dict(os.environ, valid_test_env, clear=True): config = Config() - + # Database should be initialized and accessible db_info = await test_database.get_database_info() - + assert db_info["database_path"] == config.database_path assert db_info["total_tables"] >= 3 # companies, financial_data, benchmarks assert "companies" in db_info["tables"] assert "financial_data" in db_info["tables"] - + @pytest.mark.asyncio async def test_security_integration(self, valid_test_env): """Test security component integration.""" with patch.dict(os.environ, valid_test_env, clear=True): config = Config() security_manager = SecurityManager() - + # Test input sanitization test_query = "What was TechCorp's revenue in Q1 2025?" sanitized = security_manager.sanitize_input(test_query) assert sanitized == test_query # Should pass through clean input - + # Test malicious input blocking malicious_query = "" with pytest.raises(Exception): # Should raise security exception security_manager.sanitize_input(malicious_query) - - @pytest.mark.asyncio + + @pytest.mark.asyncio async def test_cache_system_integration(self, valid_test_env): """Test cache system integration with configuration.""" with patch.dict(os.environ, valid_test_env, clear=True): config = Config() - + # Cache configuration should be loaded from environment cache_config = config.cache_config assert cache_config["max_size"] == "1000" # From test env assert cache_config["prefix"] == "fact_v1" - + # Cache validator should initialize validator = CacheValidator() assert validator is not None - + @pytest.mark.asyncio - async def test_driver_initialization_integration(self, test_database, valid_test_env): + async def test_driver_initialization_integration( + self, test_database, valid_test_env + ): """Test complete driver initialization with all components.""" with patch.dict(os.environ, valid_test_env, clear=True): - with patch('src.core.driver.FACTDriver._test_connections', return_value=None): + with patch( + "src.core.driver.FACTDriver._test_connections", return_value=None + ): # Driver should initialize successfully driver = await get_driver() - + assert driver is not None assert driver.config is not None assert driver.tool_registry is not None - + # Should have tools registered tools = driver.tool_registry.list_tools() assert len(tools) > 0 - + # Should have metrics metrics = driver.get_metrics() assert metrics["initialized"] is True - + await driver.shutdown() - + @pytest.mark.asyncio async def test_benchmark_framework_integration(self, valid_test_env): """Test benchmark framework integration with system components.""" with patch.dict(os.environ, valid_test_env, clear=True): framework = BenchmarkFramework() - + # Should initialize with configuration assert framework.config is not None assert framework.metrics_collector is not None - + # Should handle empty query list summary = await framework.run_benchmark_suite([]) assert summary.total_queries == 0 assert summary.successful_queries == 0 - + @pytest.mark.asyncio async def test_end_to_end_query_processing(self, test_database, valid_test_env): """Test complete end-to-end query processing through all system layers.""" with patch.dict(os.environ, valid_test_env, clear=True): - with patch('src.core.driver.FACTDriver._test_connections', return_value=None): - with patch('anthropic.AsyncAnthropic') as mock_anthropic: + with patch( + "src.core.driver.FACTDriver._test_connections", return_value=None + ): + with patch("anthropic.AsyncAnthropic") as mock_anthropic: # Mock successful API response mock_response = MagicMock() - mock_response.content = [MagicMock(text="TechCorp's Q1 2025 revenue was $50M")] - mock_anthropic.return_value.messages.create.return_value = mock_response - + mock_response.content = [ + MagicMock(text="TechCorp's Q1 2025 revenue was $50M") + ] + mock_anthropic.return_value.messages.create.return_value = ( + mock_response + ) + # Initialize driver driver = await get_driver() - + # Process a query end-to-end - response = await driver.process_query("What was TechCorp's Q1 2025 revenue?") - + response = await driver.process_query( + "What was TechCorp's Q1 2025 revenue?" + ) + assert response is not None assert "TechCorp" in response assert "revenue" in response - + await driver.shutdown() - + @pytest.mark.asyncio async def test_performance_monitoring_integration(self, valid_test_env): """Test performance monitoring integration.""" with patch.dict(os.environ, valid_test_env, clear=True): metrics_collector = MetricsCollector() - + # Should record metrics await metrics_collector.record_query_metrics( - query="test query", - response_time_ms=50.0, - success=True, - cache_hit=False + query="test query", response_time_ms=50.0, success=True, cache_hit=False ) - + # Should provide summary summary = await metrics_collector.get_performance_summary() assert summary is not None assert "total_queries" in summary - + def test_configuration_validation_integration(self, valid_test_env): """Test that configuration validation works across all components.""" # Test with valid configuration with patch.dict(os.environ, valid_test_env, clear=True): config = Config() assert config is not None - + # Test with missing keys with patch.dict(os.environ, {}, clear=True): with pytest.raises(ConfigurationError) as exc_info: @@ -220,39 +232,42 @@ def test_configuration_validation_integration(self, valid_test_env): class TestSystemErrorHandling: """Test error handling across integrated components.""" - + def test_graceful_degradation_with_missing_api_keys(self): """Test system behavior when API keys are missing.""" with patch.dict(os.environ, {}, clear=True): # Configuration should fail fast with pytest.raises(ConfigurationError): Config() - + def test_database_error_handling(self): """Test database error handling in integrated system.""" test_env = { "ANTHROPIC_API_KEY": "sk-test-key", "ARCADE_API_KEY": "ak-test-key", - "DATABASE_PATH": "/invalid/path/database.db" + "DATABASE_PATH": "/invalid/path/database.db", } - + with patch.dict(os.environ, test_env, clear=True): config = Config() - + # Should handle invalid database path gracefully db_manager = DatabaseManager(config.database_path) # Database operations should fail with clear error messages - + @pytest.mark.asyncio async def test_api_connectivity_error_handling(self): """Test API connectivity error handling.""" test_env = { "ANTHROPIC_API_KEY": "sk-invalid-key", - "ARCADE_API_KEY": "ak-invalid-key" + "ARCADE_API_KEY": "ak-invalid-key", } - + with patch.dict(os.environ, test_env, clear=True): - with patch('src.core.driver.FACTDriver._test_connections', side_effect=Exception("API Error")): + with patch( + "src.core.driver.FACTDriver._test_connections", + side_effect=Exception("API Error"), + ): # Should handle API errors gracefully with pytest.raises(Exception): await get_driver() @@ -260,86 +275,90 @@ async def test_api_connectivity_error_handling(self): class TestSystemPerformance: """Test system performance characteristics.""" - + @pytest.mark.asyncio async def test_concurrent_query_handling(self, valid_test_env=None): """Test system handling of concurrent queries.""" if valid_test_env is None: valid_test_env = { "ANTHROPIC_API_KEY": "sk-test-key", - "ARCADE_API_KEY": "ak-test-key" + "ARCADE_API_KEY": "ak-test-key", } - + with patch.dict(os.environ, valid_test_env, clear=True): - with patch('src.core.driver.FACTDriver._test_connections', return_value=None): + with patch( + "src.core.driver.FACTDriver._test_connections", return_value=None + ): framework = BenchmarkFramework() - + # Test concurrent query processing queries = ["Query 1", "Query 2", "Query 3"] - - with patch('src.benchmarking.framework.process_user_query', return_value="Mock response"): + + with patch( + "src.benchmarking.framework.process_user_query", + return_value="Mock response", + ): summary = await framework.run_benchmark_suite(queries) - + # Should handle all queries expected_total = len(queries) * framework.config.iterations assert summary.total_queries == expected_total - + @pytest.mark.asyncio async def test_cache_performance_validation(self): """Test cache performance meets targets.""" test_env = { - "ANTHROPIC_API_KEY": "sk-test-key", + "ANTHROPIC_API_KEY": "sk-test-key", "ARCADE_API_KEY": "ak-test-key", "CACHE_HIT_TARGET_MS": "30", - "CACHE_MISS_TARGET_MS": "120" + "CACHE_MISS_TARGET_MS": "120", } - + with patch.dict(os.environ, test_env, clear=True): validator = CacheValidator() - + # Validate cache performance targets # This test validates the cache system meets performance requirements class TestSystemRecovery: """Test system recovery and resilience.""" - + @pytest.mark.asyncio async def test_system_recovery_after_failure(self): """Test system recovery after component failure.""" - test_env = { - "ANTHROPIC_API_KEY": "sk-test-key", - "ARCADE_API_KEY": "ak-test-key" - } - + test_env = {"ANTHROPIC_API_KEY": "sk-test-key", "ARCADE_API_KEY": "ak-test-key"} + with patch.dict(os.environ, test_env, clear=True): # Simulate system failure and recovery - with patch('src.core.driver.FACTDriver._test_connections', return_value=None): + with patch( + "src.core.driver.FACTDriver._test_connections", return_value=None + ): driver = await get_driver() - + # System should be operational assert driver is not None - + # Simulate graceful shutdown await driver.shutdown() - + def test_configuration_reload_capability(self): """Test system ability to reload configuration.""" # Test configuration reloading without restart test_env_1 = { - "ANTHROPIC_API_KEY": "sk-test-key-1", - "ARCADE_API_KEY": "ak-test-key-1" + "ANTHROPIC_API_KEY": "sk-test-key-1", + "ARCADE_API_KEY": "ak-test-key-1", } - + test_env_2 = { "ANTHROPIC_API_KEY": "sk-test-key-2", - "ARCADE_API_KEY": "ak-test-key-2" + "ARCADE_API_KEY": "ak-test-key-2", } - + with patch.dict(os.environ, test_env_1, clear=True): config1 = Config() assert config1.anthropic_api_key == "sk-test-key-1" - + with patch.dict(os.environ, test_env_2, clear=True): config2 = Config() - assert config2.anthropic_api_key == "sk-test-key-2" \ No newline at end of file + assert config2.anthropic_api_key == "sk-test-key-2" diff --git a/tests/integration/test_system_integration.py b/tests/integration/test_system_integration.py index a7bdaf3..0ac64da 100644 --- a/tests/integration/test_system_integration.py +++ b/tests/integration/test_system_integration.py @@ -17,94 +17,111 @@ class TestEndToEndWorkflow: """Integration tests for complete FACT workflow.""" - + @pytest.mark.integration - async def test_complete_query_processing_workflow(self, test_environment, benchmark_queries): + async def test_complete_query_processing_workflow( + self, test_environment, benchmark_queries + ): """TEST: Complete query from user input to final response""" # Arrange query = "What was Q1-2025 revenue?" expected_revenue = 1234567.89 - + # Act with test_environment: response = await process_user_query(query) - + # Assert assert response is not None assert str(expected_revenue) in response or "1,234,567.89" in response assert "Q1-2025" in response - + # Verify response contains structured financial data - assert any(keyword in response.lower() for keyword in ["revenue", "quarter", "financial"]) - + assert any( + keyword in response.lower() + for keyword in ["revenue", "quarter", "financial"] + ) + @pytest.mark.integration - async def test_cache_miss_to_tool_execution_workflow(self, test_environment, test_database): + async def test_cache_miss_to_tool_execution_workflow( + self, test_environment, test_database + ): """TEST: Cache miss triggers tool execution and response generation""" # Arrange unique_query = f"What is the total revenue for 2024? (timestamp: {time.time()})" - + # Configure mock to simulate tool call mock_tool_call = Mock() mock_tool_call.name = "SQL.QueryReadonly" mock_tool_call.id = "test-call-123" - mock_tool_call.arguments = json.dumps({ - "statement": "SELECT SUM(value) as total FROM revenue WHERE quarter LIKE '%2024%'" - }) - + mock_tool_call.arguments = json.dumps( + { + "statement": "SELECT SUM(value) as total FROM revenue WHERE quarter LIKE '%2024%'" + } + ) + mock_response = Mock() - mock_response.content = [Mock(text="The total revenue for 2024 was $2,997,419.08")] + mock_response.content = [ + Mock(text="The total revenue for 2024 was $2,997,419.08") + ] mock_response.tool_calls = [mock_tool_call] - + # Act with test_environment as env: env["anthropic"].messages.create.return_value = mock_response env["arcade"].tools.execute.return_value = { "status": "success", - "data": {"rows": [{"total": 2997419.08}], "row_count": 1} + "data": {"rows": [{"total": 2997419.08}], "row_count": 1}, } - + response = await process_user_query(unique_query) - + # Assert assert response is not None assert "2,997,419.08" in response or "2997419.08" in response assert "2024" in response - + # Verify tool was executed env["arcade"].tools.execute.assert_called_once() call_args = env["arcade"].tools.execute.call_args[1] assert "SQL.QueryReadonly" in str(call_args) - + @pytest.mark.integration - async def test_cache_hit_workflow_bypasses_tool_execution(self, test_environment, cache_config): + async def test_cache_hit_workflow_bypasses_tool_execution( + self, test_environment, cache_config + ): """TEST: Cache hit bypasses tool execution for faster response""" # Arrange cache_manager = CacheManager(config=cache_config) query = "What was Q1-2025 revenue?" - cached_response = "Q1-2025 revenue was $1,234,567.89 based on our financial records." - + cached_response = ( + "Q1-2025 revenue was $1,234,567.89 based on our financial records." + ) + # Pre-populate cache query_hash = cache_manager.generate_hash(query) cache_manager.store(query_hash, cached_response) - + # Act start_time = time.perf_counter() - + with test_environment as env: - with patch('src.cache.manager.cache_manager', cache_manager): + with patch("src.cache.manager.cache_manager", cache_manager): response = await process_user_query(query) - + end_time = time.perf_counter() response_time_ms = (end_time - start_time) * 1000 - + # Assert assert response == cached_response - assert response_time_ms < 50, f"Cache hit took {response_time_ms:.2f}ms, exceeds 50ms target" - + assert ( + response_time_ms < 50 + ), f"Cache hit took {response_time_ms:.2f}ms, exceeds 50ms target" + # Verify tool execution was bypassed env["arcade"].tools.execute.assert_not_called() env["anthropic"].messages.create.assert_not_called() - + @pytest.mark.integration async def test_error_handling_and_recovery_workflow(self, test_environment): """TEST: System handles errors gracefully and provides useful feedback""" @@ -114,22 +131,22 @@ async def test_error_handling_and_recovery_workflow(self, test_environment): "name": "database_connection_error", "query": "SELECT * FROM revenue", "error": Exception("Database connection failed"), - "expected_in_response": ["error", "database", "connection"] + "expected_in_response": ["error", "database", "connection"], }, { "name": "tool_execution_timeout", "query": "Complex analytical query", "error": TimeoutError("Tool execution timed out"), - "expected_in_response": ["timeout", "error"] + "expected_in_response": ["timeout", "error"], }, { "name": "anthropic_api_error", "query": "What is the revenue?", "error": Exception("API rate limit exceeded"), - "expected_in_response": ["error", "api", "temporarily"] - } + "expected_in_response": ["error", "api", "temporarily"], + }, ] - + # Test each error scenario for scenario in error_scenarios: # Act @@ -138,63 +155,78 @@ async def test_error_handling_and_recovery_workflow(self, test_environment): env["arcade"].tools.execute.side_effect = scenario["error"] elif "anthropic" in scenario["name"]: env["anthropic"].messages.create.side_effect = scenario["error"] - + response = await process_user_query(scenario["query"]) - + # Assert assert response is not None response_lower = response.lower() - + # Should contain error context - assert any(keyword in response_lower for keyword in scenario["expected_in_response"]) - + assert any( + keyword in response_lower + for keyword in scenario["expected_in_response"] + ) + # Should not contain sensitive information - assert not any(sensitive in response_lower for sensitive in ["api_key", "token", "secret"]) - + assert not any( + sensitive in response_lower + for sensitive in ["api_key", "token", "secret"] + ) + @pytest.mark.integration def test_system_initialization_and_configuration(self, test_environment): """TEST: System initializes properly with all components""" # Act with test_environment: - anthropic_client, arcade_client, cache_prefix, system_prompt = initialize_system() - + anthropic_client, arcade_client, cache_prefix, system_prompt = ( + initialize_system() + ) + # Assert assert anthropic_client is not None assert arcade_client is not None assert cache_prefix == "fact_test_v1" assert "deterministic" in system_prompt.lower() - assert "finance" in system_prompt.lower() or "financial" in system_prompt.lower() - + assert ( + "finance" in system_prompt.lower() or "financial" in system_prompt.lower() + ) + @pytest.mark.integration async def test_concurrent_user_sessions(self, test_environment, benchmark_queries): """TEST: System handles concurrent user sessions correctly""" + # Arrange async def user_session(user_id: int, queries: list): session_results = [] for query in queries: user_query = f"{query} (user {user_id})" result = await process_user_query(user_query) - session_results.append({ - "user_id": user_id, - "query": user_query, - "response": result, - "timestamp": time.time() - }) + session_results.append( + { + "user_id": user_id, + "query": user_query, + "response": result, + "timestamp": time.time(), + } + ) return session_results - + # Act user_count = 5 queries_per_user = benchmark_queries[:3] # First 3 queries - + with test_environment: - concurrent_sessions = await asyncio.gather(*[ - user_session(user_id, queries_per_user) - for user_id in range(user_count) - ]) - + concurrent_sessions = await asyncio.gather( + *[ + user_session(user_id, queries_per_user) + for user_id in range(user_count) + ] + ) + # Assert assert len(concurrent_sessions) == user_count - + # Verify all sessions completed successfully for session in concurrent_sessions: assert len(session) == len(queries_per_user) @@ -205,134 +237,135 @@ async def user_session(user_id: int, queries: list): class TestToolIntegration: """Integration tests for tool registration and execution.""" - + @pytest.mark.integration - def test_tool_registration_and_discovery_flow(self, test_environment, mock_arcade_client): + def test_tool_registration_and_discovery_flow( + self, test_environment, mock_arcade_client + ): """TEST: Tool registration and subsequent discovery workflow""" # Arrange from src.tools.decorators import tool, register_tool, discover_tools - + @tool( name="TestTool.Integration", description="Integration test tool", - parameters={ - "input": {"type": "string", "description": "Test input"} - } + parameters={"input": {"type": "string", "description": "Test input"}}, ) def integration_test_tool(input: str) -> dict: return {"processed": input.upper(), "timestamp": time.time()} - + # Act register_tool(integration_test_tool) discovered_tools = discover_tools() - + # Assert - tool_names = [t._tool_metadata['name'] for t in discovered_tools] + tool_names = [t._tool_metadata["name"] for t in discovered_tools] assert "TestTool.Integration" in tool_names - + # Verify tool can be executed from src.tools.decorators import execute_tool + result = execute_tool("TestTool.Integration", {"input": "test_value"}) - + assert result["success"] == True assert result["data"]["processed"] == "TEST_VALUE" - + @pytest.mark.integration def test_sql_tool_end_to_end_execution(self, test_environment, test_database): """TEST: SQL tool end-to-end execution workflow""" # Arrange sql_tool = SQLQueryTool(database_path=test_database) test_query = "SELECT quarter, value FROM revenue WHERE value > 1000000 ORDER BY value DESC" - + # Act result = sql_tool.execute({"statement": test_query}) - + # Assert assert result["success"] == True assert "data" in result assert "rows" in result["data"] assert len(result["data"]["rows"]) > 0 - + # Verify results are properly formatted first_row = result["data"]["rows"][0] assert "quarter" in first_row assert "value" in first_row assert isinstance(first_row["value"], (int, float)) assert first_row["value"] > 1000000 - + @pytest.mark.integration - async def test_tool_call_via_arcade_integration(self, test_environment, mock_arcade_client): + async def test_tool_call_via_arcade_integration( + self, test_environment, mock_arcade_client + ): """TEST: Tool execution via Arcade.dev integration""" # Arrange from src.tools.decorators import execute_tool_via_arcade - + # Configure successful tool execution mock_arcade_client.tools.execute.return_value = { "status": "success", "data": { "rows": [{"quarter": "Q1-2025", "value": 1234567.89}], "row_count": 1, - "execution_time_ms": 8 - } + "execution_time_ms": 8, + }, } - + tool_call_data = { "name": "SQL.QueryReadonly", "arguments": { "statement": "SELECT quarter, value FROM revenue WHERE quarter = 'Q1-2025'" - } + }, } - + # Act result = await execute_tool_via_arcade( - tool_call_data["name"], - tool_call_data["arguments"], - mock_arcade_client + tool_call_data["name"], tool_call_data["arguments"], mock_arcade_client ) - + # Assert assert result["success"] == True assert result["data"]["row_count"] == 1 assert result["data"]["rows"][0]["quarter"] == "Q1-2025" assert result["data"]["execution_time_ms"] < 10 - + # Verify Arcade client was called correctly mock_arcade_client.tools.execute.assert_called_once() - + @pytest.mark.integration def test_tool_authorization_integration(self, test_environment): """TEST: Tool authorization integration with security framework""" # Arrange from src.tools.decorators import tool, register_tool, execute_tool_with_auth from src.core.errors import AuthorizationError - + @tool( name="TestTool.Protected", description="Protected tool requiring admin access", - required_scopes=["admin", "write"] + required_scopes=["admin", "write"], ) def protected_tool() -> dict: return {"message": "Admin operation completed"} - + register_tool(protected_tool) - + # Test insufficient permissions with pytest.raises(AuthorizationError): execute_tool_with_auth( "TestTool.Protected", {}, user_scopes=["read"], - user_id="test@example.com" + user_id="test@example.com", ) - + # Test sufficient permissions result = execute_tool_with_auth( "TestTool.Protected", {}, user_scopes=["admin", "write", "read"], - user_id="admin@example.com" + user_id="admin@example.com", ) - + # Assert assert result["success"] == True assert result["data"]["message"] == "Admin operation completed" @@ -340,7 +373,7 @@ def protected_tool() -> dict: class TestDatabaseIntegration: """Integration tests for database operations and connection management.""" - + @pytest.mark.database def test_database_connection_pool_under_load(self, test_database): """TEST: Database connection pool handles concurrent load""" @@ -348,164 +381,177 @@ def test_database_connection_pool_under_load(self, test_database): pool = DatabaseConnectionPool(test_database, max_connections=3) import threading import queue - + results_queue = queue.Queue() - + def database_worker(worker_id: int): try: connection = pool.acquire_connection() cursor = connection.execute("SELECT COUNT(*) as count FROM revenue") result = cursor.fetchone() pool.release_connection(connection) - results_queue.put({"worker_id": worker_id, "count": result[0], "success": True}) + results_queue.put( + {"worker_id": worker_id, "count": result[0], "success": True} + ) except Exception as e: - results_queue.put({"worker_id": worker_id, "error": str(e), "success": False}) - + results_queue.put( + {"worker_id": worker_id, "error": str(e), "success": False} + ) + # Act workers = [] for i in range(5): # More workers than pool size worker = threading.Thread(target=database_worker, args=(i,)) workers.append(worker) worker.start() - + # Wait for all workers to complete for worker in workers: worker.join(timeout=5.0) - + # Collect results results = [] while not results_queue.empty(): results.append(results_queue.get()) - + # Assert assert len(results) == 5 successful_results = [r for r in results if r["success"]] assert len(successful_results) == 5 # All should succeed - + # Verify correct data for result in successful_results: assert result["count"] == 4 # Test database has 4 revenue records - + @pytest.mark.database def test_database_transaction_integrity(self, test_database): """TEST: Database maintains transaction integrity""" # Arrange from src.tools.connectors.sql import execute_sql_query - + # Verify initial state - initial_result = execute_sql_query("SELECT COUNT(*) as count FROM revenue", test_database) + initial_result = execute_sql_query( + "SELECT COUNT(*) as count FROM revenue", test_database + ) initial_count = initial_result.rows[0]["count"] - + # Act - Try to execute multiple queries queries = [ "SELECT * FROM revenue WHERE quarter = 'Q1-2025'", "SELECT AVG(value) as avg_value FROM revenue", - "SELECT quarter, value FROM revenue ORDER BY value DESC LIMIT 2" + "SELECT quarter, value FROM revenue ORDER BY value DESC LIMIT 2", ] - + results = [] for query in queries: result = execute_sql_query(query, test_database) results.append(result) - + # Verify final state - final_result = execute_sql_query("SELECT COUNT(*) as count FROM revenue", test_database) + final_result = execute_sql_query( + "SELECT COUNT(*) as count FROM revenue", test_database + ) final_count = final_result.rows[0]["count"] - + # Assert assert all(r.success for r in results) assert initial_count == final_count # No data should be modified assert initial_count == 4 # Original test data intact - + @pytest.mark.database def test_database_error_recovery(self, test_database): """TEST: Database error recovery and connection resilience""" # Arrange from src.tools.connectors.sql import execute_sql_query from src.core.errors import DatabaseError - + # Test invalid query handling invalid_queries = [ "SELECT * FROM non_existent_table", "SELECT invalid_column FROM revenue", - "SELECT * FROM revenue WHERE invalid_syntax =" + "SELECT * FROM revenue WHERE invalid_syntax =", ] - + # Act & Assert for invalid_query in invalid_queries: with pytest.raises(DatabaseError): execute_sql_query(invalid_query, test_database) - + # Verify database is still functional after errors - recovery_result = execute_sql_query("SELECT COUNT(*) as count FROM revenue", test_database) + recovery_result = execute_sql_query( + "SELECT COUNT(*) as count FROM revenue", test_database + ) assert recovery_result.success == True assert recovery_result.rows[0]["count"] == 4 class TestCacheIntegration: """Integration tests for cache system integration.""" - + @pytest.mark.cache - async def test_cache_integration_with_query_processing(self, test_environment, cache_config): + async def test_cache_integration_with_query_processing( + self, test_environment, cache_config + ): """TEST: Cache integration with complete query processing workflow""" # Arrange cache_manager = CacheManager(config=cache_config) query = "What was Q1-2025 revenue and how does it compare to previous quarters?" - + # First request (cache miss) with test_environment as env: - with patch('src.cache.manager.cache_manager', cache_manager): + with patch("src.cache.manager.cache_manager", cache_manager): # Act - First request first_response = await process_user_query(query) - + # Verify cache miss behavior assert first_response is not None env["anthropic"].messages.create.assert_called() - + # Act - Second request (should be cache hit) env["anthropic"].messages.create.reset_mock() second_response = await process_user_query(query) - + # Assert assert second_response is not None # Second request should hit cache and not call Anthropic env["anthropic"].messages.create.assert_not_called() - + # Verify cache metrics metrics = cache_manager.get_metrics() assert metrics.total_requests >= 2 assert metrics.cache_hits >= 1 - + @pytest.mark.cache def test_cache_invalidation_integration(self, test_environment, cache_config): """TEST: Cache invalidation integration with system updates""" # Arrange cache_manager = CacheManager(config=cache_config) - + # Populate cache with multiple entries test_queries = [ "What is Q1-2025 revenue?", "Show quarterly summary", - "Calculate total 2024 revenue" + "Calculate total 2024 revenue", ] - + for query in test_queries: content = f"Cached response for: {query}" query_hash = cache_manager.generate_hash(query) cache_manager.store(query_hash, content) - + # Verify cache is populated assert cache_manager.get_metrics().total_entries == 3 - + # Act - Trigger cache invalidation (e.g., due to schema change) from src.cache.manager import invalidate_on_schema_change - with patch('src.cache.manager.cache_manager', cache_manager): + + with patch("src.cache.manager.cache_manager", cache_manager): invalidated_count = invalidate_on_schema_change("Database schema updated") - + # Assert assert invalidated_count == 3 assert cache_manager.get_metrics().total_entries == 0 - + # Verify all queries now result in cache misses for query in test_queries: query_hash = cache_manager.generate_hash(query) @@ -515,54 +561,62 @@ def test_cache_invalidation_integration(self, test_environment, cache_config): @pytest.mark.integration class TestSystemStability: """Integration tests for system stability and reliability.""" - - async def test_memory_usage_under_sustained_load(self, test_environment, benchmark_queries): + + async def test_memory_usage_under_sustained_load( + self, test_environment, benchmark_queries + ): """TEST: Memory usage remains stable under sustained load""" # Arrange import psutil import os - + process = psutil.Process(os.getpid()) initial_memory = process.memory_info().rss / 1024 / 1024 # MB - + # Act - Sustained load test query_count = 0 max_queries = 100 - + with test_environment: while query_count < max_queries: query = benchmark_queries[query_count % len(benchmark_queries)] response = await process_user_query(query) - + assert response is not None query_count += 1 - + # Check memory every 10 queries if query_count % 10 == 0: current_memory = process.memory_info().rss / 1024 / 1024 memory_growth = current_memory - initial_memory - + # Memory growth should be reasonable (< 50MB for 100 queries) - assert memory_growth < 50, f"Memory growth {memory_growth:.1f}MB too high" - + assert ( + memory_growth < 50 + ), f"Memory growth {memory_growth:.1f}MB too high" + # Final memory check final_memory = process.memory_info().rss / 1024 / 1024 total_growth = final_memory - initial_memory - + # Assert - assert total_growth < 100, f"Total memory growth {total_growth:.1f}MB exceeds limit" - - async def test_error_rate_under_normal_operation(self, test_environment, benchmark_queries): + assert ( + total_growth < 100 + ), f"Total memory growth {total_growth:.1f}MB exceeds limit" + + async def test_error_rate_under_normal_operation( + self, test_environment, benchmark_queries + ): """TEST: Error rate remains low under normal operation""" # Arrange total_requests = 50 error_count = 0 - + # Act with test_environment: for i in range(total_requests): query = benchmark_queries[i % len(benchmark_queries)] - + try: response = await process_user_query(query) assert response is not None @@ -570,36 +624,40 @@ async def test_error_rate_under_normal_operation(self, test_environment, benchma error_count += 1 # Log error for debugging print(f"Query {i} failed: {str(e)}") - + # Assert error_rate = error_count / total_requests assert error_rate < 0.01, f"Error rate {error_rate:.1%} exceeds 1% threshold" - + def test_configuration_validation_and_recovery(self, test_environment): """TEST: System validates configuration and recovers from invalid configs""" # Arrange from src.core.config import validate_configuration, ConfigurationError - + invalid_configs = [ {}, # Empty config {"ANTHROPIC_API_KEY": ""}, # Empty API key {"ANTHROPIC_API_KEY": "valid", "ARCADE_URL": "invalid-url"}, # Invalid URL - {"ANTHROPIC_API_KEY": "valid", "ARCADE_URL": "http://localhost:9099", "ARCADE_API_KEY": ""} # Empty Arcade key + { + "ANTHROPIC_API_KEY": "valid", + "ARCADE_URL": "http://localhost:9099", + "ARCADE_API_KEY": "", + }, # Empty Arcade key ] - + # Test invalid configurations for config in invalid_configs: - with patch.dict('os.environ', config, clear=True): + with patch.dict("os.environ", config, clear=True): with pytest.raises(ConfigurationError): validate_configuration() - + # Test valid configuration recovery valid_config = { "ANTHROPIC_API_KEY": "test-anthropic-key", "ARCADE_API_KEY": "test-arcade-key", - "ARCADE_URL": "http://localhost:9099" + "ARCADE_URL": "http://localhost:9099", } - - with patch.dict('os.environ', valid_config): + + with patch.dict("os.environ", valid_config): # Should not raise exception - validate_configuration() \ No newline at end of file + validate_configuration() diff --git a/tests/performance/test_benchmarks.py b/tests/performance/test_benchmarks.py index d1bd0fb..7220788 100644 --- a/tests/performance/test_benchmarks.py +++ b/tests/performance/test_benchmarks.py @@ -17,6 +17,7 @@ @dataclass class PerformanceMetrics: """Performance metrics for benchmark testing.""" + response_time_ms: float token_cost: float cache_hit: bool @@ -27,6 +28,7 @@ class PerformanceMetrics: @dataclass class BenchmarkResults: """Benchmark results comparison.""" + fact_metrics: List[PerformanceMetrics] rag_baseline_metrics: List[PerformanceMetrics] improvement_factor: float @@ -35,238 +37,271 @@ class BenchmarkResults: class TestResponseTimeTargets: """Test suite for response time performance targets.""" - + @pytest.mark.performance - async def test_cache_hit_response_under_50ms(self, test_environment, cache_config, benchmark_queries): + async def test_cache_hit_response_under_50ms( + self, test_environment, cache_config, benchmark_queries + ): """TEST: Cache hit responses achieve target latency under 50ms""" # Arrange from src.core.driver import process_user_query from src.cache.manager import CacheManager - + manager = CacheManager(config=cache_config) query = benchmark_queries[0] - + # Pre-warm cache cached_response = "Cached response: Q1-2025 revenue was $1,234,567.89" query_hash = manager.generate_hash(query) manager.store(query_hash, cached_response) - + # Act measurements = [] for _ in range(10): # Multiple measurements for accuracy start_time = time.perf_counter() - - with patch('src.cache.manager.cache_manager', manager): + + with patch("src.cache.manager.cache_manager", manager): response = await process_user_query(query) - + end_time = time.perf_counter() latency_ms = (end_time - start_time) * 1000 measurements.append(latency_ms) - + # Assert avg_latency = statistics.mean(measurements) p95_latency = statistics.quantiles(measurements, n=20)[18] # 95th percentile - - assert avg_latency < 50, f"Average cache hit latency {avg_latency:.2f}ms exceeds 50ms target" - assert p95_latency < 50, f"P95 cache hit latency {p95_latency:.2f}ms exceeds 50ms target" + + assert ( + avg_latency < 50 + ), f"Average cache hit latency {avg_latency:.2f}ms exceeds 50ms target" + assert ( + p95_latency < 50 + ), f"P95 cache hit latency {p95_latency:.2f}ms exceeds 50ms target" assert response is not None - + @pytest.mark.performance - async def test_cache_miss_response_under_140ms(self, test_environment, mock_anthropic_client, benchmark_queries): + async def test_cache_miss_response_under_140ms( + self, test_environment, mock_anthropic_client, benchmark_queries + ): """TEST: Cache miss responses achieve target latency under 140ms""" # Arrange from src.core.driver import process_user_query - + # Configure mock for realistic response time mock_anthropic_client.messages.create = AsyncMock() mock_response = Mock() mock_response.content = [Mock(text="Mock response from Claude")] mock_response.tool_calls = None mock_anthropic_client.messages.create.return_value = mock_response - + # Simulate network latency async def delayed_response(*args, **kwargs): await asyncio.sleep(0.08) # 80ms simulated API call return mock_response - + mock_anthropic_client.messages.create.side_effect = delayed_response - + query = f"Unique query {time.time()}" # Ensure cache miss - + # Act measurements = [] for _ in range(5): # Fewer iterations due to higher latency start_time = time.perf_counter() - - with patch('src.core.driver.anthropic_client', mock_anthropic_client): + + with patch("src.core.driver.anthropic_client", mock_anthropic_client): response = await process_user_query(query) - + end_time = time.perf_counter() latency_ms = (end_time - start_time) * 1000 measurements.append(latency_ms) - + # Assert avg_latency = statistics.mean(measurements) - assert avg_latency < 140, f"Average cache miss latency {avg_latency:.2f}ms exceeds 140ms target" + assert ( + avg_latency < 140 + ), f"Average cache miss latency {avg_latency:.2f}ms exceeds 140ms target" assert response is not None - + @pytest.mark.performance def test_tool_execution_under_10ms_lan(self, test_environment, test_database): """TEST: Tool execution meets LAN latency targets under 10ms""" # Arrange from src.tools.connectors.sql import execute_sql_query + queries = [ "SELECT COUNT(*) as total FROM revenue", "SELECT quarter, value FROM revenue WHERE quarter = 'Q1-2025'", "SELECT AVG(value) as avg_revenue FROM revenue", - "SELECT MAX(value) as max_revenue FROM revenue" + "SELECT MAX(value) as max_revenue FROM revenue", ] - + # Act measurements = [] for query in queries: start_time = time.perf_counter() result = execute_sql_query(query, test_database) end_time = time.perf_counter() - + latency_ms = (end_time - start_time) * 1000 measurements.append(latency_ms) - + assert result.success == True - + # Assert avg_latency = statistics.mean(measurements) max_latency = max(measurements) - - assert avg_latency < 10, f"Average tool execution {avg_latency:.2f}ms exceeds 10ms target" - assert max_latency < 10, f"Max tool execution {max_latency:.2f}ms exceeds 10ms target" - + + assert ( + avg_latency < 10 + ), f"Average tool execution {avg_latency:.2f}ms exceeds 10ms target" + assert ( + max_latency < 10 + ), f"Max tool execution {max_latency:.2f}ms exceeds 10ms target" + @pytest.mark.performance - async def test_overall_system_response_under_100ms(self, test_environment, benchmark_queries): + async def test_overall_system_response_under_100ms( + self, test_environment, benchmark_queries + ): """TEST: Overall system response time meets 100ms average target""" # Arrange from src.core.driver import process_user_query - + # Mix of cache hits and misses mixed_queries = benchmark_queries[:5] # Use first 5 benchmark queries - + # Act measurements = [] for query in mixed_queries: start_time = time.perf_counter() response = await process_user_query(query) end_time = time.perf_counter() - + latency_ms = (end_time - start_time) * 1000 measurements.append(latency_ms) assert response is not None - + # Assert avg_latency = statistics.mean(measurements) - assert avg_latency < 100, f"Average system response {avg_latency:.2f}ms exceeds 100ms target" + assert ( + avg_latency < 100 + ), f"Average system response {avg_latency:.2f}ms exceeds 100ms target" class TestTokenCostOptimization: """Test suite for token cost optimization requirements.""" - + @pytest.mark.cost_analysis - async def test_cache_hit_cost_reduction_90_percent(self, test_environment, performance_targets): + async def test_cache_hit_cost_reduction_90_percent( + self, test_environment, performance_targets + ): """TEST: Cache hits achieve 90% cost reduction vs traditional RAG""" # Arrange from src.core.driver import process_user_query from src.cache.manager import CacheManager, calculate_token_cost - + query = "What was Q1-2025 revenue?" - + # Simulate traditional RAG cost (baseline) baseline_input_tokens = 1500 # Typical RAG prompt with context baseline_output_tokens = 200 - baseline_cost = calculate_token_cost(baseline_input_tokens, baseline_output_tokens) - + baseline_cost = calculate_token_cost( + baseline_input_tokens, baseline_output_tokens + ) + # Test FACT cache hit cost cached_response = "Q1-2025 revenue was $1,234,567.89 based on financial data." - + # Act - with patch('src.cache.manager.get_cached_response') as mock_cache: + with patch("src.cache.manager.get_cached_response") as mock_cache: mock_cache.return_value = cached_response - + # Measure cache hit cost start_tokens = 100 # Minimal prompt for cache lookup - end_tokens = 50 # Cached response tokens + end_tokens = 50 # Cached response tokens fact_cost = calculate_token_cost(start_tokens, end_tokens) - + # Assert cost_reduction = (baseline_cost - fact_cost) / baseline_cost - assert cost_reduction >= performance_targets["cost_reduction_cache_hit"], \ - f"Cache hit cost reduction {cost_reduction:.1%} below 90% target" - + assert ( + cost_reduction >= performance_targets["cost_reduction_cache_hit"] + ), f"Cache hit cost reduction {cost_reduction:.1%} below 90% target" + # Verify actual 90% reduction - assert cost_reduction >= 0.90, f"Cost reduction {cost_reduction:.1%} below 90% requirement" - + assert ( + cost_reduction >= 0.90 + ), f"Cost reduction {cost_reduction:.1%} below 90% requirement" + @pytest.mark.cost_analysis - async def test_cache_miss_cost_reduction_65_percent(self, test_environment, performance_targets): + async def test_cache_miss_cost_reduction_65_percent( + self, test_environment, performance_targets + ): """TEST: Cache misses achieve 65% cost reduction vs traditional RAG""" # Arrange from src.core.driver import process_user_query from src.cache.manager import calculate_token_cost - + query = "What was Q2-2024 revenue breakdown by category?" - + # Simulate traditional RAG cost baseline_input_tokens = 1500 # Large context with vector search results baseline_output_tokens = 300 # Detailed response - baseline_cost = calculate_token_cost(baseline_input_tokens, baseline_output_tokens) - + baseline_cost = calculate_token_cost( + baseline_input_tokens, baseline_output_tokens + ) + # Test FACT cache miss cost (uses tool calls) - fact_input_tokens = 200 # Streamlined prompt + tool call + fact_input_tokens = 200 # Streamlined prompt + tool call fact_output_tokens = 150 # Response with tool results fact_cost = calculate_token_cost(fact_input_tokens, fact_output_tokens) - + # Assert cost_reduction = (baseline_cost - fact_cost) / baseline_cost - assert cost_reduction >= performance_targets["cost_reduction_cache_miss"], \ - f"Cache miss cost reduction {cost_reduction:.1%} below 65% target" - + assert ( + cost_reduction >= performance_targets["cost_reduction_cache_miss"] + ), f"Cache miss cost reduction {cost_reduction:.1%} below 65% target" + # Verify actual 65% reduction - assert cost_reduction >= 0.65, f"Cost reduction {cost_reduction:.1%} below 65% requirement" - + assert ( + cost_reduction >= 0.65 + ), f"Cost reduction {cost_reduction:.1%} below 65% requirement" + @pytest.mark.cost_analysis def test_token_efficiency_metrics_calculation(self, test_environment): """TEST: Token efficiency metrics are calculated accurately""" # Arrange from src.cache.manager import TokenEfficiencyCalculator - + test_scenarios = [ { "scenario": "cache_hit", "input_tokens": 100, "output_tokens": 50, "baseline_input": 1500, - "baseline_output": 200 + "baseline_output": 200, }, { "scenario": "cache_miss", "input_tokens": 200, "output_tokens": 150, "baseline_input": 1500, - "baseline_output": 300 - } + "baseline_output": 300, + }, ] - + calculator = TokenEfficiencyCalculator() - + # Act & Assert for scenario in test_scenarios: efficiency = calculator.calculate_efficiency( fact_input=scenario["input_tokens"], fact_output=scenario["output_tokens"], rag_input=scenario["baseline_input"], - rag_output=scenario["baseline_output"] + rag_output=scenario["baseline_output"], ) - + assert efficiency.fact_total_tokens < efficiency.rag_total_tokens assert efficiency.reduction_percentage > 0 - + if scenario["scenario"] == "cache_hit": assert efficiency.reduction_percentage >= 0.90 else: # cache_miss @@ -275,168 +310,189 @@ def test_token_efficiency_metrics_calculation(self, test_environment): class TestPerformanceComparison: """Test suite for performance comparison with traditional RAG.""" - + @pytest.mark.benchmark - async def test_fact_vs_rag_latency_comparison(self, test_environment, benchmark_queries): + async def test_fact_vs_rag_latency_comparison( + self, test_environment, benchmark_queries + ): """TEST: FACT system latency comparison vs traditional RAG""" # Arrange from src.core.driver import process_user_query - + # Simulate traditional RAG latencies rag_latencies = [ - 450, 520, 380, 490, 410, # Typical RAG response times (400-500ms) - 510, 470, 430, 460, 480 + 450, + 520, + 380, + 490, + 410, # Typical RAG response times (400-500ms) + 510, + 470, + 430, + 460, + 480, ] - + # Test FACT system latencies fact_latencies = [] - + # Act for i, query in enumerate(benchmark_queries[:10]): start_time = time.perf_counter() response = await process_user_query(query) end_time = time.perf_counter() - + latency_ms = (end_time - start_time) * 1000 fact_latencies.append(latency_ms) - + assert response is not None - + # Assert fact_avg = statistics.mean(fact_latencies) rag_avg = statistics.mean(rag_latencies) improvement_factor = rag_avg / fact_avg - - assert improvement_factor >= 3.0, \ - f"FACT latency improvement {improvement_factor:.1f}x below 3x target vs RAG" - + + assert ( + improvement_factor >= 3.0 + ), f"FACT latency improvement {improvement_factor:.1f}x below 3x target vs RAG" + # Verify absolute performance - assert fact_avg < 100, f"FACT average latency {fact_avg:.2f}ms exceeds 100ms target" - + assert ( + fact_avg < 100 + ), f"FACT average latency {fact_avg:.2f}ms exceeds 100ms target" + @pytest.mark.benchmark def test_fact_vs_rag_cost_comparison(self, performance_targets): """TEST: FACT system cost comparison vs traditional RAG""" # Arrange from src.cache.manager import calculate_token_cost - + # Typical query scenarios scenarios = [ {"query_type": "simple", "rag_input": 1200, "rag_output": 150}, {"query_type": "complex", "rag_input": 2000, "rag_output": 400}, - {"query_type": "analytical", "rag_input": 1800, "rag_output": 300} + {"query_type": "analytical", "rag_input": 1800, "rag_output": 300}, ] - + fact_costs = [] rag_costs = [] - + # Act for scenario in scenarios: # RAG costs rag_cost = calculate_token_cost( - scenario["rag_input"], - scenario["rag_output"] + scenario["rag_input"], scenario["rag_output"] ) rag_costs.append(rag_cost) - + # FACT costs (assuming 70% cache hits) cache_hit_cost = calculate_token_cost(80, 40) # Minimal tokens cache_miss_cost = calculate_token_cost(180, 120) # Tool-enhanced - + fact_avg_cost = (0.7 * cache_hit_cost) + (0.3 * cache_miss_cost) fact_costs.append(fact_avg_cost) - + # Assert total_rag_cost = sum(rag_costs) total_fact_cost = sum(fact_costs) cost_reduction = (total_rag_cost - total_fact_cost) / total_rag_cost - - assert cost_reduction >= 0.75, \ - f"Overall cost reduction {cost_reduction:.1%} below 75% target" - + + assert ( + cost_reduction >= 0.75 + ), f"Overall cost reduction {cost_reduction:.1%} below 75% target" + @pytest.mark.benchmark - async def test_throughput_comparison_under_load(self, test_environment, benchmark_queries): + async def test_throughput_comparison_under_load( + self, test_environment, benchmark_queries + ): """TEST: FACT system throughput vs traditional RAG under load""" # Arrange from src.core.driver import process_user_query import asyncio - + concurrent_users = 10 queries_per_user = 5 - + # Simulate concurrent load async def user_session(user_id: int): session_times = [] for i in range(queries_per_user): query = benchmark_queries[i % len(benchmark_queries)] - + start_time = time.perf_counter() response = await process_user_query(f"{query} (user {user_id})") end_time = time.perf_counter() - + session_times.append(end_time - start_time) assert response is not None - + return session_times - + # Act start_time = time.perf_counter() - - user_sessions = await asyncio.gather(*[ - user_session(user_id) for user_id in range(concurrent_users) - ]) - + + user_sessions = await asyncio.gather( + *[user_session(user_id) for user_id in range(concurrent_users)] + ) + end_time = time.perf_counter() total_duration = end_time - start_time - + # Assert total_queries = concurrent_users * queries_per_user throughput_qps = total_queries / total_duration - + # FACT should handle significantly higher throughput than RAG - assert throughput_qps >= 20, f"Throughput {throughput_qps:.1f} QPS below 20 QPS target" - + assert ( + throughput_qps >= 20 + ), f"Throughput {throughput_qps:.1f} QPS below 20 QPS target" + # Verify response time consistency under load all_response_times = [time for session in user_sessions for time in session] avg_response_time = statistics.mean(all_response_times) * 1000 p95_response_time = statistics.quantiles(all_response_times, n=20)[18] * 1000 - - assert avg_response_time < 150, f"Average response time under load {avg_response_time:.2f}ms too high" - assert p95_response_time < 200, f"P95 response time under load {p95_response_time:.2f}ms too high" + + assert ( + avg_response_time < 150 + ), f"Average response time under load {avg_response_time:.2f}ms too high" + assert ( + p95_response_time < 200 + ), f"P95 response time under load {p95_response_time:.2f}ms too high" class TestContinuousBenchmarking: """Test suite for continuous benchmarking and monitoring.""" - + def test_benchmark_data_collection_and_storage(self, test_environment): """TEST: Benchmark data collection and storage for continuous monitoring""" # Arrange from src.monitoring.benchmarks import BenchmarkCollector - + collector = BenchmarkCollector() - + # Simulate benchmark run metrics = PerformanceMetrics( response_time_ms=45.2, token_cost=0.002, cache_hit=True, query_type="financial_query", - timestamp=time.time() + timestamp=time.time(), ) - + # Act collector.record_metric(metrics) stored_metrics = collector.get_recent_metrics(limit=10) - + # Assert assert len(stored_metrics) >= 1 assert stored_metrics[0].response_time_ms == 45.2 assert stored_metrics[0].cache_hit == True - + def test_benchmark_trend_analysis(self, test_environment): """TEST: Benchmark trend analysis for performance monitoring""" # Arrange from src.monitoring.benchmarks import BenchmarkAnalyzer - + # Create sample benchmark data over time historical_data = [] for i in range(30): # 30 data points @@ -445,35 +501,35 @@ def test_benchmark_trend_analysis(self, test_environment): token_cost=0.001 + (i * 0.0001), cache_hit=i % 2 == 0, query_type="test_query", - timestamp=time.time() - (30 - i) * 86400 # Daily data + timestamp=time.time() - (30 - i) * 86400, # Daily data ) historical_data.append(metrics) - + analyzer = BenchmarkAnalyzer(historical_data) - + # Act trend_analysis = analyzer.analyze_trends() - + # Assert assert "response_time_trend" in trend_analysis assert "cost_trend" in trend_analysis assert "cache_hit_rate_trend" in trend_analysis - + # Should detect increasing response time trend assert trend_analysis["response_time_trend"]["direction"] == "increasing" assert trend_analysis["response_time_trend"]["significance"] > 0.5 - + def test_performance_alert_thresholds(self, test_environment, performance_targets): """TEST: Performance alert thresholds for continuous monitoring""" # Arrange from src.monitoring.alerts import PerformanceAlertManager - + alert_manager = PerformanceAlertManager( response_time_threshold=performance_targets["overall_response_ms"], cost_increase_threshold=0.20, # 20% cost increase - cache_hit_rate_threshold=0.60 # 60% minimum hit rate + cache_hit_rate_threshold=0.60, # 60% minimum hit rate ) - + # Test scenarios scenarios = [ { @@ -481,78 +537,82 @@ def test_performance_alert_thresholds(self, test_environment, performance_target "response_time": 80, "cost_increase": 0.05, "cache_hit_rate": 0.75, - "should_alert": False + "should_alert": False, }, { "name": "slow_response", "response_time": 150, "cost_increase": 0.10, "cache_hit_rate": 0.70, - "should_alert": True + "should_alert": True, }, { "name": "high_cost", "response_time": 90, "cost_increase": 0.25, "cache_hit_rate": 0.65, - "should_alert": True + "should_alert": True, }, { "name": "low_cache_hit_rate", "response_time": 85, "cost_increase": 0.08, "cache_hit_rate": 0.45, - "should_alert": True - } + "should_alert": True, + }, ] - + # Act & Assert for scenario in scenarios: alerts = alert_manager.check_thresholds( response_time_ms=scenario["response_time"], cost_increase_ratio=scenario["cost_increase"], - cache_hit_rate=scenario["cache_hit_rate"] + cache_hit_rate=scenario["cache_hit_rate"], ) - + has_alerts = len(alerts) > 0 - assert has_alerts == scenario["should_alert"], \ - f"Alert expectation failed for scenario: {scenario['name']}" - + assert ( + has_alerts == scenario["should_alert"] + ), f"Alert expectation failed for scenario: {scenario['name']}" + def test_benchmark_report_generation(self, test_environment): """TEST: Benchmark report generation for performance tracking""" # Arrange from src.monitoring.reports import BenchmarkReportGenerator - + # Sample benchmark results fact_metrics = [ PerformanceMetrics(45.2, 0.002, True, "query1", time.time()), PerformanceMetrics(65.1, 0.008, False, "query2", time.time()), - PerformanceMetrics(38.7, 0.001, True, "query3", time.time()) + PerformanceMetrics(38.7, 0.001, True, "query3", time.time()), ] - + rag_baseline = [ PerformanceMetrics(420.5, 0.025, False, "query1", time.time()), PerformanceMetrics(380.2, 0.030, False, "query2", time.time()), - PerformanceMetrics(450.8, 0.028, False, "query3", time.time()) + PerformanceMetrics(450.8, 0.028, False, "query3", time.time()), ] - + generator = BenchmarkReportGenerator() - + # Act report = generator.generate_comparison_report( - fact_results=fact_metrics, - rag_baseline=rag_baseline + fact_results=fact_metrics, rag_baseline=rag_baseline ) - + # Assert assert "performance_summary" in report assert "cost_analysis" in report assert "improvement_metrics" in report - + # Verify improvement calculations - assert report["improvement_metrics"]["latency_improvement"] > 5.0 # At least 5x improvement - assert report["improvement_metrics"]["cost_reduction"] > 0.80 # At least 80% cost reduction - + assert ( + report["improvement_metrics"]["latency_improvement"] > 5.0 + ) # At least 5x improvement + assert ( + report["improvement_metrics"]["cost_reduction"] > 0.80 + ) # At least 80% cost reduction + # Verify JSON serializable json_report = json.dumps(report, default=str) assert len(json_report) > 0 @@ -561,52 +621,60 @@ def test_benchmark_report_generation(self, test_environment): @pytest.mark.slow class TestLongRunningBenchmarks: """Long-running benchmark tests for stability and performance over time.""" - + @pytest.mark.performance - async def test_sustained_performance_over_time(self, test_environment, benchmark_queries): + async def test_sustained_performance_over_time( + self, test_environment, benchmark_queries + ): """TEST: System maintains performance over sustained operation""" # Arrange from src.core.driver import process_user_query + duration_minutes = 5 # 5 minute sustained test measurement_interval = 10 # Measure every 10 seconds - + measurements = [] start_time = time.time() end_time = start_time + (duration_minutes * 60) - + # Act query_count = 0 while time.time() < end_time: query = benchmark_queries[query_count % len(benchmark_queries)] - + measurement_start = time.perf_counter() response = await process_user_query(query) measurement_end = time.perf_counter() - + latency_ms = (measurement_end - measurement_start) * 1000 - measurements.append({ - "timestamp": time.time(), - "latency_ms": latency_ms, - "query_count": query_count - }) - + measurements.append( + { + "timestamp": time.time(), + "latency_ms": latency_ms, + "query_count": query_count, + } + ) + query_count += 1 assert response is not None - + # Wait for next measurement interval await asyncio.sleep(measurement_interval) - + # Assert latencies = [m["latency_ms"] for m in measurements] - + # Performance should remain stable early_avg = statistics.mean(latencies[:3]) # First 3 measurements late_avg = statistics.mean(latencies[-3:]) # Last 3 measurements - + performance_degradation = (late_avg - early_avg) / early_avg - assert performance_degradation < 0.20, \ - f"Performance degraded {performance_degradation:.1%} over time" - + assert ( + performance_degradation < 0.20 + ), f"Performance degraded {performance_degradation:.1%} over time" + # Overall performance should meet targets overall_avg = statistics.mean(latencies) - assert overall_avg < 100, f"Sustained performance {overall_avg:.2f}ms exceeds target" \ No newline at end of file + assert ( + overall_avg < 100 + ), f"Sustained performance {overall_avg:.2f}ms exceeds target" diff --git a/tests/test_all_fixes.py b/tests/test_all_fixes.py index eb0ea92..019679a 100644 --- a/tests/test_all_fixes.py +++ b/tests/test_all_fixes.py @@ -18,11 +18,16 @@ from unittest.mock import Mock, AsyncMock, MagicMock # Add the project root to Python path -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) from src.core.driver import FACTDriver from src.core.config import Config, get_config -from src.tools.connectors.sql import SQLQueryTool, initialize_sql_tool, sql_query_readonly, sql_get_schema +from src.tools.connectors.sql import ( + SQLQueryTool, + initialize_sql_tool, + sql_query_readonly, + sql_get_schema, +) from src.db.connection import DatabaseManager from src.tools.decorators import get_tool_registry from src.core.errors import InvalidSQLError, SecurityError, DatabaseError @@ -30,38 +35,48 @@ class TestRunner: """Comprehensive test runner for FACT system validation.""" - + def __init__(self): self.passed_tests = 0 self.failed_tests = 0 self.test_results = [] - + async def run_test(self, test_name: str, test_func): """Run a single test and track results.""" print(f"\n{'='*60}") print(f"Running: {test_name}") print(f"{'='*60}") - + try: if asyncio.iscoroutinefunction(test_func): result = await test_func() else: result = test_func() - + if result: print(f"āœ… PASSED: {test_name}") self.passed_tests += 1 - self.test_results.append({"test": test_name, "status": "PASSED", "error": None}) + self.test_results.append( + {"test": test_name, "status": "PASSED", "error": None} + ) else: print(f"āŒ FAILED: {test_name} - Test returned False") self.failed_tests += 1 - self.test_results.append({"test": test_name, "status": "FAILED", "error": "Test returned False"}) - + self.test_results.append( + { + "test": test_name, + "status": "FAILED", + "error": "Test returned False", + } + ) + except Exception as e: print(f"āŒ FAILED: {test_name} - {type(e).__name__}: {e}") self.failed_tests += 1 - self.test_results.append({"test": test_name, "status": "FAILED", "error": str(e)}) - + self.test_results.append( + {"test": test_name, "status": "FAILED", "error": str(e)} + ) + def print_summary(self): """Print test summary.""" total_tests = self.passed_tests + self.failed_tests @@ -71,8 +86,10 @@ def print_summary(self): print(f"Total Tests: {total_tests}") print(f"Passed: {self.passed_tests}") print(f"Failed: {self.failed_tests}") - print(f"Success Rate: {(self.passed_tests/total_tests*100) if total_tests > 0 else 0:.1f}%") - + print( + f"Success Rate: {(self.passed_tests/total_tests*100) if total_tests > 0 else 0:.1f}%" + ) + if self.failed_tests > 0: print(f"\nFailed Tests:") for result in self.test_results: @@ -82,14 +99,14 @@ def print_summary(self): async def test_sql_statement_validation(): """Test 1: SQL statement handling with None checks, empty strings, and non-string inputs.""" - + # Initialize database and SQL tool db_manager = DatabaseManager("db/test_validation.db") await db_manager.initialize_database() initialize_sql_tool(db_manager) - + print("Testing SQL statement validation fixes...") - + # Test 1.1: None statement - Test both decorator validation and internal validation print("\n1.1 Testing None statement...") try: @@ -105,7 +122,7 @@ async def test_sql_statement_validation(): else: print(f"āœ— None statement test failed: {e}") return False - + # Test 1.2: Empty string statement print("\n1.2 Testing empty string statement...") try: @@ -119,7 +136,7 @@ async def test_sql_statement_validation(): else: print(f"āœ— Empty string test failed: {e}") return False - + # Test 1.3: Non-string statement print("\n1.3 Testing non-string statement...") try: @@ -133,7 +150,7 @@ async def test_sql_statement_validation(): else: print(f"āœ— Non-string test failed: {e}") return False - + # Test 1.4: Valid query print("\n1.4 Testing valid SQL query...") try: @@ -145,7 +162,7 @@ async def test_sql_statement_validation(): except Exception as e: print(f"āœ— Valid query test failed: {e}") return False - + # Test 1.5: Schema retrieval (tests None/length checks in schema function) print("\n1.5 Testing schema retrieval with None checks...") try: @@ -157,50 +174,50 @@ async def test_sql_statement_validation(): except Exception as e: print(f"āœ— Schema retrieval test failed: {e}") return False - + return True async def test_tool_execution_async_handling(): """Test 2: Tool execution with proper async/await handling.""" - + print("Testing tool execution and async/await handling...") - + # Test 2.1: Tool registry functionality print("\n2.1 Testing tool registry...") registry = get_tool_registry() tools = registry.list_tools() - + if not tools: print("āœ— No tools registered") return False - + print(f"āœ“ Found {len(tools)} registered tools: {tools}") - + # Test 2.2: Tool schema export print("\n2.2 Testing tool schema export...") schemas = registry.export_all_schemas() - + if not schemas: print("āœ— No schemas exported") return False - + print(f"āœ“ Exported {len(schemas)} tool schemas") - + # Verify schema format for schema in schemas: if "function" not in schema or "name" not in schema["function"]: print(f"āœ— Invalid schema format: {schema}") return False - + # Test 2.3: Async tool execution print("\n2.3 Testing async tool execution...") - + # Initialize database for tool testing db_manager = DatabaseManager("db/test_validation.db") await db_manager.initialize_database() initialize_sql_tool(db_manager) - + try: # Test async tool directly result = await sql_query_readonly("SELECT 1 as test_value") @@ -210,36 +227,40 @@ async def test_tool_execution_async_handling(): except Exception as e: print(f"āœ— Async tool execution failed: {e}") return False - + return True async def test_llm_response_processing(): """Test 3: LLM response processing with tool calls.""" - + print("Testing LLM response processing...") - + # Create a mock LLM response with tool calls mock_tool_call = Mock() mock_tool_call.id = "call_123" mock_tool_call.type = "function" mock_tool_call.function = Mock() - mock_tool_call.function.name = "SQL_QueryReadonly" # Note: underscores for API compatibility - mock_tool_call.function.arguments = '{"statement": "SELECT COUNT(*) FROM companies"}' - + mock_tool_call.function.name = ( + "SQL_QueryReadonly" # Note: underscores for API compatibility + ) + mock_tool_call.function.arguments = ( + '{"statement": "SELECT COUNT(*) FROM companies"}' + ) + mock_message = Mock() mock_message.content = "I'll help you query the database." mock_message.tool_calls = [mock_tool_call] - + mock_choice = Mock() mock_choice.message = mock_message - + mock_response = Mock() mock_response.choices = [mock_choice] - + # Test 3.1: Tool call extraction print("\n3.1 Testing tool call extraction...") - + # Simulate the driver's tool call processing try: tool_calls = mock_response.choices[0].message.tool_calls @@ -249,39 +270,41 @@ async def test_llm_response_processing(): except Exception as e: print(f"āœ— Tool call extraction failed: {e}") return False - + # Test 3.2: Tool execution simulation print("\n3.2 Testing tool execution from LLM response...") - + # Initialize database and tools db_manager = DatabaseManager("db/test_validation.db") await db_manager.initialize_database() initialize_sql_tool(db_manager) - + try: # Simulate what the driver does tool_call = mock_tool_call - tool_name = tool_call.function.name.replace('_', '.') # Convert back to registry format + tool_name = tool_call.function.name.replace( + "_", "." + ) # Convert back to registry format tool_args = json.loads(tool_call.function.arguments) - + # Get tool from registry registry = get_tool_registry() tool_definition = registry.get_tool(tool_name) - + # Execute tool (async) result = await tool_definition.function(**tool_args) - + assert isinstance(result, dict) assert result["status"] == "success" print("āœ“ Tool execution from LLM response successful") - + except Exception as e: print(f"āœ— Tool execution from LLM response failed: {e}") return False - + # Test 3.3: Response content handling (test None response handling) print("\n3.3 Testing None response content handling...") - + # Test the driver's content extraction logic try: # Simulate various response scenarios @@ -290,60 +313,60 @@ async def test_llm_response_processing(): {"content": None, "expected": None}, {"content": "", "expected": ""}, ] - + for i, case in enumerate(test_cases): mock_resp = Mock() mock_resp.content = case["content"] - + # Simulate driver logic - if hasattr(mock_resp, 'content'): + if hasattr(mock_resp, "content"): response_text = mock_resp.content else: response_text = None - + if case["expected"] is None: assert response_text is None else: assert response_text == case["expected"] - + print("āœ“ Response content handling working correctly") - + except Exception as e: print(f"āœ— Response content handling failed: {e}") return False - + return True async def test_full_system_functionality(): """Test 4: Full system test with end-to-end functionality.""" - + print("Testing full system functionality...") - + # Test 4.1: Driver initialization print("\n4.1 Testing driver initialization...") - + try: # Create test config config = get_config() # Can't modify config directly, so create driver with test database path - + # Create driver with default config but test database driver = FACTDriver(config) # Override the database path after initialization driver.config.database_path = "db/test_validation.db" await driver.initialize() - + assert driver._initialized == True print("āœ“ Driver initialized successfully") - + except Exception as e: print(f"āœ— Driver initialization failed: {e}") return False - + # Test 4.2: System metrics print("\n4.2 Testing system metrics...") - + try: metrics = driver.get_metrics() assert isinstance(metrics, dict) @@ -351,140 +374,146 @@ async def test_full_system_functionality(): assert "initialized" in metrics assert metrics["initialized"] == True print("āœ“ System metrics working") - + except Exception as e: print(f"āœ— System metrics failed: {e}") return False - + # Test 4.3: Database operations through driver print("\n4.3 Testing database operations through system...") - + try: # Test database info db_info = await driver.database_manager.get_database_info() assert isinstance(db_info, dict) assert "tables" in db_info print("āœ“ Database operations working") - + except Exception as e: print(f"āœ— Database operations failed: {e}") return False - + # Test 4.4: Error handling in driver print("\n4.4 Testing error handling in driver...") - + try: # Create a driver with invalid OpenAI key to test error handling test_config = get_config() test_config.openai_api_key = "invalid_key_for_testing" test_config.database_path = "db/test_validation.db" - + error_driver = FACTDriver(test_config) await error_driver.initialize() # Should not crash, just log warnings - + print("āœ“ Error handling in driver working (allows system to continue)") - + except Exception as e: # This is expected - driver should handle errors gracefully print(f"āœ“ Driver error handling working: {type(e).__name__}") - + # Clean up await driver.shutdown() - + return True async def test_demo_queries(): """Test 5: Demo queries to confirm end-to-end functionality.""" - + print("Testing demo queries for end-to-end validation...") - + # Initialize system db_manager = DatabaseManager("db/test_validation.db") await db_manager.initialize_database() initialize_sql_tool(db_manager) - + # Test queries that should work demo_queries = [ { "name": "Company count", "query": "SELECT COUNT(*) as total_companies FROM companies", - "should_work": True + "should_work": True, }, { "name": "Technology companies", "query": "SELECT name, symbol FROM companies WHERE sector = 'Technology'", - "should_work": True + "should_work": True, }, { "name": "Latest financial data", "query": "SELECT c.name, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025 AND f.quarter = 'Q1' LIMIT 5", - "should_work": True + "should_work": True, }, { "name": "Schema information", "query": "PRAGMA table_info(companies)", - "should_work": True + "should_work": True, }, { "name": "Invalid dangerous query", "query": "DROP TABLE companies", - "should_work": False + "should_work": False, }, { "name": "SQL injection attempt", "query": "SELECT * FROM companies WHERE id = 1; DROP TABLE companies; --", - "should_work": False - } + "should_work": False, + }, ] - + success_count = 0 - + for i, test_case in enumerate(demo_queries, 1): print(f"\n5.{i} Testing: {test_case['name']}") print(f"Query: {test_case['query']}") - + try: - result = await sql_query_readonly(test_case['query']) - - if test_case['should_work']: + result = await sql_query_readonly(test_case["query"]) + + if test_case["should_work"]: if result["status"] == "success": print(f"āœ“ Query succeeded as expected") success_count += 1 else: - print(f"āœ— Query should have succeeded but failed: {result.get('error', 'Unknown error')}") + print( + f"āœ— Query should have succeeded but failed: {result.get('error', 'Unknown error')}" + ) else: if result["status"] == "failed": - print(f"āœ“ Query correctly rejected: {result.get('error', 'Security violation')}") + print( + f"āœ“ Query correctly rejected: {result.get('error', 'Security violation')}" + ) success_count += 1 else: print(f"āœ— Dangerous query should have been rejected but succeeded") - + except Exception as e: - if test_case['should_work']: + if test_case["should_work"]: print(f"āœ— Query failed unexpectedly: {e}") else: print(f"āœ“ Query correctly failed with exception: {e}") success_count += 1 - + print(f"\nDemo queries: {success_count}/{len(demo_queries)} passed") return success_count == len(demo_queries) async def test_edge_cases(): """Test 6: Edge cases and boundary conditions.""" - + print("Testing edge cases and boundary conditions...") - + # Initialize system db_manager = DatabaseManager("db/test_validation.db") await db_manager.initialize_database() initialize_sql_tool(db_manager) - + # Test 6.1: Very long query (should be rejected) print("\n6.1 Testing very long query...") try: - long_query = "SELECT * FROM companies WHERE " + " OR ".join([f"id = {i}" for i in range(1000)]) + long_query = "SELECT * FROM companies WHERE " + " OR ".join( + [f"id = {i}" for i in range(1000)] + ) result = await sql_query_readonly(long_query) assert result["status"] == "failed" assert "too long" in result.get("error", "").lower() @@ -495,23 +524,28 @@ async def test_edge_cases(): else: print(f"āœ— Long query test failed: {e}") return False - + # Test 6.2: Complex nested query (should be rejected) print("\n6.2 Testing complex nested query...") try: nested_query = "SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM companies)))))" result = await sql_query_readonly(nested_query) assert result["status"] == "failed" - assert "nested" in result.get("error", "").lower() or "subqueries" in result.get("error", "").lower() + assert ( + "nested" in result.get("error", "").lower() + or "subqueries" in result.get("error", "").lower() + ) print("āœ“ Complex nested query rejected correctly") except Exception as e: print(f"āœ— Nested query test failed: {e}") return False - + # Test 6.3: Query with special characters print("\n6.3 Testing query with special characters...") try: - special_query = "SELECT name FROM companies WHERE name LIKE '%&%' OR name LIKE '%@%'" + special_query = ( + "SELECT name FROM companies WHERE name LIKE '%&%' OR name LIKE '%@%'" + ) result = await sql_query_readonly(special_query) # This should work - special characters in data are OK assert result["status"] == "success" @@ -519,24 +553,26 @@ async def test_edge_cases(): except Exception as e: print(f"āœ— Special characters test failed: {e}") return False - + # Test 6.4: Unicode query print("\n6.4 Testing unicode query...") try: - unicode_query = "SELECT name FROM companies WHERE name LIKE '%Ć©%' OR name LIKE '%äø­%'" + unicode_query = ( + "SELECT name FROM companies WHERE name LIKE '%Ć©%' OR name LIKE '%äø­%'" + ) result = await sql_query_readonly(unicode_query) assert result["status"] == "success" print("āœ“ Unicode query handled correctly") except Exception as e: print(f"āœ— Unicode query test failed: {e}") return False - + return True async def main(): """Main test execution function.""" - + print("🧪 FACT System Comprehensive Validation Test") print("=" * 60) print("Testing all major fixes and functionality:") @@ -546,24 +582,26 @@ async def main(): print("4. Full system end-to-end functionality") print("5. Demo queries validation") print("6. Edge cases and boundary conditions") - + runner = TestRunner() - + # Run all test suites await runner.run_test("SQL Statement Validation", test_sql_statement_validation) - await runner.run_test("Tool Execution & Async Handling", test_tool_execution_async_handling) + await runner.run_test( + "Tool Execution & Async Handling", test_tool_execution_async_handling + ) await runner.run_test("LLM Response Processing", test_llm_response_processing) await runner.run_test("Full System Functionality", test_full_system_functionality) await runner.run_test("Demo Queries Validation", test_demo_queries) await runner.run_test("Edge Cases & Boundary Conditions", test_edge_cases) - + # Print final summary runner.print_summary() - + # Return appropriate exit code return 0 if runner.failed_tests == 0 else 1 if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/tests/test_basic_functionality.py b/tests/test_basic_functionality.py index ef17fbd..c30c47f 100644 --- a/tests/test_basic_functionality.py +++ b/tests/test_basic_functionality.py @@ -12,7 +12,8 @@ # Add src to path for testing import sys -sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) + +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) from src.core.config import Config from src.core.driver import FACTDriver @@ -24,15 +25,15 @@ @pytest.fixture async def temp_database(): """Create a temporary database for testing.""" - with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp: + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name - + # Initialize database db_manager = DatabaseManager(db_path) await db_manager.initialize_database() - + yield db_manager - + # Cleanup os.unlink(db_path) @@ -42,40 +43,42 @@ def test_config(temp_database): """Create test configuration with temporary database.""" config = Config() config._config = { - 'database_path': temp_database.database_path, - 'anthropic_api_key': 'test_key', - 'arcade_api_key': 'test_key', - 'cache_prefix': 'test_v1', - 'claude_model': 'claude-3-5-sonnet-20241022', - 'system_prompt': 'Test prompt' + "database_path": temp_database.database_path, + "anthropic_api_key": "test_key", + "arcade_api_key": "test_key", + "cache_prefix": "test_v1", + "claude_model": "claude-3-5-sonnet-20241022", + "system_prompt": "Test prompt", } return config class TestDatabaseIntegration: """Test database integration and SQL tools.""" - + @pytest.mark.asyncio async def test_database_initialization(self, temp_database): """Test that database initializes with schema and sample data.""" db_info = await temp_database.get_database_info() - - assert db_info['total_tables'] >= 2 - assert 'companies' in db_info['tables'] - assert 'financial_records' in db_info['tables'] - assert db_info['tables']['companies']['row_count'] > 0 - assert db_info['tables']['financial_records']['row_count'] > 0 - + + assert db_info["total_tables"] >= 2 + assert "companies" in db_info["tables"] + assert "financial_records" in db_info["tables"] + assert db_info["tables"]["companies"]["row_count"] > 0 + assert db_info["tables"]["financial_records"]["row_count"] > 0 + @pytest.mark.asyncio async def test_sql_query_execution(self, temp_database): """Test SQL query execution with security validation.""" # Test valid SELECT query - result = await temp_database.execute_query("SELECT COUNT(*) as count FROM companies") - + result = await temp_database.execute_query( + "SELECT COUNT(*) as count FROM companies" + ) + assert result.row_count == 1 - assert 'count' in result.columns - assert result.rows[0]['count'] > 0 - + assert "count" in result.columns + assert result.rows[0]["count"] > 0 + @pytest.mark.asyncio async def test_sql_security_validation(self, temp_database): """Test that dangerous SQL operations are blocked.""" @@ -83,9 +86,9 @@ async def test_sql_security_validation(self, temp_database): "DROP TABLE companies", "DELETE FROM companies", "UPDATE companies SET name = 'hacked'", - "INSERT INTO companies VALUES (999, 'hack')" + "INSERT INTO companies VALUES (999, 'hack')", ] - + for query in dangerous_queries: with pytest.raises(Exception): # Should raise SecurityError await temp_database.execute_query(query) @@ -93,83 +96,83 @@ async def test_sql_security_validation(self, temp_database): class TestToolFramework: """Test tool registration and execution framework.""" - + def test_tool_registry_functionality(self): """Test tool registry operations.""" registry = get_tool_registry() - + # Check that SQL tools are registered tool_names = registry.list_tools() assert len(tool_names) > 0 - + # Check schema export schemas = registry.export_all_schemas() assert len(schemas) > 0 - assert all('function' in schema for schema in schemas) - + assert all("function" in schema for schema in schemas) + @pytest.mark.asyncio async def test_sql_tool_integration(self, temp_database): """Test SQL tool initialization and execution.""" # Initialize SQL tool initialize_sql_tool(temp_database) - + # Test tool execution through registry registry = get_tool_registry() sql_tool = registry.get_tool("SQL.QueryReadonly") - + # Execute a test query result = await sql_tool.function(statement="SELECT name FROM companies LIMIT 1") - - assert result['status'] == 'success' - assert 'rows' in result - assert len(result['rows']) > 0 + + assert result["status"] == "success" + assert "rows" in result + assert len(result["rows"]) > 0 class TestSystemIntegration: """Test complete system integration.""" - + @pytest.mark.asyncio async def test_driver_initialization(self, test_config, temp_database): """Test FACT driver initialization.""" # Note: This test would need mocking for actual API calls driver = FACTDriver(test_config) - + # Initialize database component driver.database_manager = temp_database await driver._initialize_tools() - + # Check that tools are registered assert len(driver.tool_registry.list_tools()) > 0 - + # Check metrics metrics = driver.get_metrics() - assert 'total_queries' in metrics - assert 'cache_hits' in metrics + assert "total_queries" in metrics + assert "cache_hits" in metrics class TestConfigurationManagement: """Test configuration loading and validation.""" - + def test_config_creation(self): """Test configuration object creation.""" config = Config() - + # Test property access - assert hasattr(config, 'anthropic_api_key') - assert hasattr(config, 'arcade_api_key') - assert hasattr(config, 'database_path') - assert hasattr(config, 'cache_prefix') - + assert hasattr(config, "anthropic_api_key") + assert hasattr(config, "arcade_api_key") + assert hasattr(config, "database_path") + assert hasattr(config, "cache_prefix") + def test_config_dictionary_export(self): """Test configuration export functionality.""" config = Config() config_dict = config.to_dict() - + assert isinstance(config_dict, dict) - assert 'database_path' in config_dict - assert 'cache_prefix' in config_dict + assert "database_path" in config_dict + assert "cache_prefix" in config_dict # Sensitive keys should be masked - assert config_dict.get('anthropic_api_key') in [None, '***'] + assert config_dict.get("anthropic_api_key") in [None, "***"] @pytest.mark.asyncio @@ -177,30 +180,30 @@ async def test_end_to_end_workflow(temp_database): """Test a complete end-to-end workflow.""" # Initialize SQL tool initialize_sql_tool(temp_database) - + # Get tool registry registry = get_tool_registry() - + # Test that we can execute a sample query through the tool sql_tool = registry.get_tool("SQL.QueryReadonly") - + # Execute query result = await sql_tool.function( statement="SELECT name, sector FROM companies WHERE sector = 'Technology'" ) - + # Verify result structure - assert result['status'] == 'success' - assert 'rows' in result - assert 'columns' in result - assert 'execution_time_ms' in result - + assert result["status"] == "success" + assert "rows" in result + assert "columns" in result + assert "execution_time_ms" in result + # Verify data content - if result['rows']: - assert 'name' in result['rows'][0] - assert 'sector' in result['rows'][0] + if result["rows"]: + assert "name" in result["rows"][0] + assert "sector" in result["rows"][0] if __name__ == "__main__": # Run tests directly - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/test_fixes_summary.py b/tests/test_fixes_summary.py index 1b343c5..8e94425 100644 --- a/tests/test_fixes_summary.py +++ b/tests/test_fixes_summary.py @@ -14,30 +14,35 @@ import os # Add the project root to Python path -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) -from src.tools.connectors.sql import SQLQueryTool, initialize_sql_tool, sql_query_readonly, sql_get_schema +from src.tools.connectors.sql import ( + SQLQueryTool, + initialize_sql_tool, + sql_query_readonly, + sql_get_schema, +) from src.db.connection import DatabaseManager from src.tools.decorators import get_tool_registry async def main(): """Main test function demonstrating key fixes.""" - + print("🧪 FACT System Key Fixes Validation") print("=" * 50) - + # Initialize database and tools print("\nšŸ“‹ Initializing system...") db_manager = DatabaseManager("db/test_validation.db") await db_manager.initialize_database() initialize_sql_tool(db_manager) print("āœ… System initialized") - + # Test 1: SQL Input Validation Fixes print("\nšŸ” Test 1: SQL Input Validation Fixes") print("-" * 40) - + # Test None input print("Testing None input:") try: @@ -48,7 +53,7 @@ async def main(): print("āœ… None input correctly rejected") else: print(f"āŒ Unexpected error: {e}") - + # Test empty string print("Testing empty string:") try: @@ -59,7 +64,7 @@ async def main(): print("āœ… Empty string correctly rejected") else: print(f"āŒ Unexpected error: {e}") - + # Test non-string print("Testing non-string input:") try: @@ -70,11 +75,11 @@ async def main(): print("āœ… Non-string input correctly rejected") else: print(f"āŒ Unexpected error: {e}") - + # Test 2: Async Tool Execution print("\n⚔ Test 2: Async Tool Execution") print("-" * 40) - + # Test simple query that should work print("Testing async tool execution:") try: @@ -85,79 +90,95 @@ async def main(): print(f"āŒ Tool execution failed: {result.get('error', 'Unknown error')}") except Exception as e: print(f"āŒ Tool execution exception: {e}") - + # Test 3: Tool Registry and Schema Export print("\nšŸ“‹ Test 3: Tool Registry and Schema Export") print("-" * 40) - + registry = get_tool_registry() tools = registry.list_tools() print(f"Registered tools: {len(tools)}") - + schemas = registry.export_all_schemas() print(f"Exported schemas: {len(schemas)}") - + if len(tools) > 0 and len(schemas) > 0: print("āœ… Tool registry working correctly") else: print("āŒ Tool registry issues") - + # Test 4: Security Validation print("\nšŸ”’ Test 4: Security Validation") print("-" * 40) - + # Test dangerous query rejection print("Testing dangerous query rejection:") try: result = await sql_query_readonly("DROP TABLE companies") # The tool returns a dict with status=failed for security violations - if isinstance(result, dict) and result.get("status") == "failed" and "error" in result: + if ( + isinstance(result, dict) + and result.get("status") == "failed" + and "error" in result + ): print("āœ… Dangerous query correctly rejected") else: print("āŒ Dangerous query should have been rejected") except Exception as e: print(f"āœ… Dangerous query rejected with exception: {type(e).__name__}") - + # Test injection attempt print("Testing SQL injection attempt:") try: - result = await sql_query_readonly("SELECT * FROM users WHERE id = 1; DROP TABLE users; --") + result = await sql_query_readonly( + "SELECT * FROM users WHERE id = 1; DROP TABLE users; --" + ) # The tool returns a dict with status=failed for security violations - if isinstance(result, dict) and result.get("status") == "failed" and "error" in result: + if ( + isinstance(result, dict) + and result.get("status") == "failed" + and "error" in result + ): print("āœ… SQL injection attempt correctly rejected") else: print("āŒ SQL injection should have been rejected") except Exception as e: print(f"āœ… SQL injection rejected with exception: {type(e).__name__}") - + # Test 5: Schema Operations (tests None handling in schema functions) print("\nšŸ“Š Test 5: Schema Operations") print("-" * 40) - + print("Testing schema retrieval:") try: schema_result = await sql_get_schema() if schema_result.get("status") == "success": - print(f"āœ… Schema retrieval working - found {schema_result.get('total_tables', 0)} tables") + print( + f"āœ… Schema retrieval working - found {schema_result.get('total_tables', 0)} tables" + ) else: - print(f"āŒ Schema retrieval failed: {schema_result.get('error', 'Unknown error')}") + print( + f"āŒ Schema retrieval failed: {schema_result.get('error', 'Unknown error')}" + ) except Exception as e: print(f"āŒ Schema retrieval exception: {e}") - + # Test 6: Database Operations print("\nšŸ’¾ Test 6: Database Operations") print("-" * 40) - + print("Testing database info:") try: db_info = await db_manager.get_database_info() if isinstance(db_info, dict) and "tables" in db_info: - print(f"āœ… Database operations working - {db_info.get('total_tables', 0)} tables") + print( + f"āœ… Database operations working - {db_info.get('total_tables', 0)} tables" + ) else: print("āŒ Database info retrieval failed") except Exception as e: print(f"āŒ Database info exception: {e}") - + # Summary print("\n" + "=" * 50) print("šŸŽÆ SUMMARY") @@ -174,4 +195,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tests/test_imports.py b/tests/test_imports.py index 2c19a7c..3bf69c0 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -13,237 +13,242 @@ from pathlib import Path # Add src to path for testing -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + class TestImportResolution(unittest.TestCase): """Test that all imports work correctly across the FACT system.""" - + def setUp(self): """Set up test environment.""" - self.src_path = Path(__file__).parent.parent / 'src' + self.src_path = Path(__file__).parent.parent / "src" self.assertTrue(self.src_path.exists(), "Source directory must exist") - + def test_core_module_imports(self): """Test that core module imports work correctly.""" try: # Test core config import from core.config import Config, get_config - + # Test core driver import from core.driver import get_driver, FACTDriver - + # Test core errors import from core.errors import FACTError, ConfigurationError - + print("āœ… Core module imports successful") except ImportError as e: self.fail(f"Core module import failed: {e}") - + def test_db_module_imports(self): """Test that database module imports work correctly.""" try: # Test db connection import from db.connection import DatabaseManager - - # Test db models import + + # Test db models import from db.models import DATABASE_SCHEMA, QueryResult - + print("āœ… Database module imports successful") except ImportError as e: self.fail(f"Database module import failed: {e}") - + def test_tools_module_imports(self): """Test that tools module imports work correctly.""" try: # Test tools decorators import from tools.decorators import Tool, get_tool_registry - + # Test tools executor import from tools.executor import ToolExecutor, ToolCall - + # Test tools validation import from tools.validation import ParameterValidator - + print("āœ… Tools module imports successful") except ImportError as e: self.fail(f"Tools module import failed: {e}") - + def test_security_module_imports(self): """Test that security module imports work correctly.""" try: # Test security auth import from security.auth import AuthorizationManager - + # Test security config import from security.config import SecurityConfig - + print("āœ… Security module imports successful") except ImportError as e: self.fail(f"Security module import failed: {e}") - + def test_arcade_module_imports(self): """Test that arcade module imports work correctly.""" try: # Test arcade client import from arcade.client import ArcadeClient - + # Test arcade gateway import from arcade.gateway import ArcadeGateway - + # Test arcade errors import from arcade.errors import ArcadeError - + print("āœ… Arcade module imports successful") except ImportError as e: self.fail(f"Arcade module import failed: {e}") - + def test_cache_module_imports(self): """Test that cache module imports work correctly.""" try: # Test cache manager import from cache.manager import CacheManager - + # Test cache strategy import from cache.strategy import CacheStrategy - + print("āœ… Cache module imports successful") except ImportError as e: self.fail(f"Cache module import failed: {e}") - + def test_monitoring_module_imports(self): """Test that monitoring module imports work correctly.""" try: # Test monitoring metrics import from monitoring.metrics import MetricsCollector - + print("āœ… Monitoring module imports successful") except ImportError as e: self.fail(f"Monitoring module import failed: {e}") - + def test_tools_connectors_imports(self): """Test that tools connector imports work correctly.""" try: # Test SQL connector import from tools.connectors.sql import initialize_sql_tool - + # Test HTTP connector import from tools.connectors.http import http_get - + # Test file connector import from tools.connectors.file import read_file - + print("āœ… Tools connector imports successful") except ImportError as e: self.fail(f"Tools connector import failed: {e}") - + def test_benchmarking_module_imports(self): """Test that benchmarking module imports work correctly.""" try: # Test benchmarking framework import from benchmarking.framework import BenchmarkFramework - - # Test benchmarking profiler import + + # Test benchmarking profiler import from benchmarking.profiler import SystemProfiler - + print("āœ… Benchmarking module imports successful") except ImportError as e: self.fail(f"Benchmarking module import failed: {e}") - + def test_no_relative_imports_beyond_package(self): """Test that no files contain problematic relative imports.""" problematic_patterns = [ - 'from ..', # Relative imports that might go beyond package + "from ..", # Relative imports that might go beyond package ] - + problematic_files = [] - - for py_file in self.src_path.rglob('*.py'): + + for py_file in self.src_path.rglob("*.py"): try: - with open(py_file, 'r', encoding='utf-8') as f: + with open(py_file, "r", encoding="utf-8") as f: content = f.read() - + for pattern in problematic_patterns: if pattern in content: # Check if this is actually problematic by context lines = content.splitlines() for i, line in enumerate(lines): - if pattern in line and 'import' in line: + if pattern in line and "import" in line: # This could be problematic - problematic_files.append(f"{py_file}:{i+1}: {line.strip()}") + problematic_files.append( + f"{py_file}:{i+1}: {line.strip()}" + ) except Exception as e: self.fail(f"Error reading {py_file}: {e}") - + if problematic_files: print("āš ļø Found potentially problematic imports:") for item in problematic_files: print(f" {item}") else: print("āœ… No problematic relative imports found") - + def test_script_execution_context(self): """Test that the system can be imported from script context.""" # Simulate the context where scripts/init_environment.py imports original_path = sys.path.copy() - + try: # Add the specific path that init_environment.py uses - script_src_path = os.path.join(os.path.dirname(__file__), '..', 'src') + script_src_path = os.path.join(os.path.dirname(__file__), "..", "src") if script_src_path not in sys.path: sys.path.insert(0, script_src_path) - + # Try to import the same modules that init_environment.py imports from core.config import Config from core.driver import get_driver from db.connection import DatabaseManager - + print("āœ… Script execution context imports successful") - + except ImportError as e: self.fail(f"Script context import failed: {e}") finally: sys.path = original_path + class TestMainEntryPoints(unittest.TestCase): """Test main entry points work correctly.""" - + def test_main_py_init_imports(self): """Test that main.py can import required modules for init command.""" try: # These are the imports used in main.py validate command from core.config import get_config from core.driver import get_driver - + print("āœ… Main.py entry point imports successful") except ImportError as e: self.fail(f"Main.py import failed: {e}") - + def test_cli_module_imports(self): """Test that CLI module can be imported correctly.""" try: # Test CLI module import from core.cli import main - - print("āœ… CLI module imports successful") + + print("āœ… CLI module imports successful") except ImportError as e: self.fail(f"CLI module import failed: {e}") + def run_import_tests(): """Run all import tests and provide summary.""" print("🧪 Running FACT System Import Tests") print("=" * 50) - + # Create test suite loader = unittest.TestLoader() suite = unittest.TestSuite() - + # Add test classes suite.addTests(loader.loadTestsFromTestCase(TestImportResolution)) suite.addTests(loader.loadTestsFromTestCase(TestMainEntryPoints)) - + # Run tests runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) - + # Print summary print("\n" + "=" * 50) if result.wasSuccessful(): @@ -253,19 +258,20 @@ def run_import_tests(): print("āŒ Some import tests failed!") print(f"Failures: {len(result.failures)}") print(f"Errors: {len(result.errors)}") - + if result.failures: print("\nFailures:") for test, traceback in result.failures: print(f" {test}: {traceback}") - + if result.errors: print("\nErrors:") for test, traceback in result.errors: print(f" {test}: {traceback}") - + return result.wasSuccessful() -if __name__ == '__main__': + +if __name__ == "__main__": success = run_import_tests() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/tests/test_nonetype_bug.py b/tests/test_nonetype_bug.py index c540621..626e312 100644 --- a/tests/test_nonetype_bug.py +++ b/tests/test_nonetype_bug.py @@ -8,7 +8,7 @@ import os # Add the project root to Python path -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) from src.tools.connectors.sql import SQLQueryTool, initialize_sql_tool, get_sql_tool from src.db.connection import DatabaseManager @@ -16,17 +16,17 @@ async def test_nonetype_scenarios(): """Test various scenarios that could trigger the NoneType len() error""" - + # Initialize database manager and database db_manager = DatabaseManager("db/test_fact.db") await db_manager.initialize_database() - + # Initialize SQL tool initialize_sql_tool(db_manager) sql_tool = get_sql_tool() - + print("=== Testing NoneType scenarios ===") - + # Test 1: None statement (this should trigger the AttributeError) print("\n1. Testing None statement...") try: @@ -34,7 +34,7 @@ async def test_nonetype_scenarios(): print(f"Result: {result}") except Exception as e: print(f"Error: {type(e).__name__}: {e}") - + # Test 1b: Test len() error by simulating a None statement in error response print("\n1b. Testing len() error scenario...") try: @@ -43,7 +43,7 @@ async def test_nonetype_scenarios(): print(f"Result: {result}") except Exception as e: print(f"Error: {type(e).__name__}: {e}") - + # Test 2: Empty statement print("\n2. Testing empty statement...") try: @@ -51,7 +51,7 @@ async def test_nonetype_scenarios(): print(f"Result: {result}") except Exception as e: print(f"Error: {type(e).__name__}: {e}") - + # Test 3: Query that returns no results (valid syntax) print("\n3. Testing query with no results...") try: @@ -59,7 +59,7 @@ async def test_nonetype_scenarios(): print(f"Result: {result}") except Exception as e: print(f"Error: {type(e).__name__}: {e}") - + # Test 4: Normal valid query (for comparison) print("\n4. Testing valid query...") try: @@ -70,4 +70,4 @@ async def test_nonetype_scenarios(): if __name__ == "__main__": - asyncio.run(test_nonetype_scenarios()) \ No newline at end of file + asyncio.run(test_nonetype_scenarios()) diff --git a/tests/test_query_error.py b/tests/test_query_error.py index 77181a9..c60e0ec 100644 --- a/tests/test_query_error.py +++ b/tests/test_query_error.py @@ -8,25 +8,26 @@ import asyncio # Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) from src.core.driver import get_driver from src.core.config import get_config + async def test_query_processing(): """Test query processing to reproduce the TypeError.""" try: print("šŸ”§ Initializing FACT system...") driver = await get_driver() print("āœ… System initialized") - + # Test queries that might trigger the TypeError test_queries = [ "Show me all companies", "What is the database schema?", - "Get sample queries" + "Get sample queries", ] - + for i, query in enumerate(test_queries, 1): print(f"\nšŸ“‹ Test Query {i}: {query}") try: @@ -35,15 +36,18 @@ async def test_query_processing(): except Exception as e: print(f"āŒ Error: {e}") print(f"Error type: {type(e).__name__}") - + # Print the full traceback for debugging import traceback + traceback.print_exc() - + except Exception as e: print(f"āŒ System initialization failed: {e}") import traceback + traceback.print_exc() + if __name__ == "__main__": - asyncio.run(test_query_processing()) \ No newline at end of file + asyncio.run(test_query_processing()) diff --git a/tests/test_revenue_trends.py b/tests/test_revenue_trends.py index 5b626d1..0cc3fb5 100644 --- a/tests/test_revenue_trends.py +++ b/tests/test_revenue_trends.py @@ -22,7 +22,7 @@ async def test_revenue_trends(): """Test revenue trends analysis functionality.""" print("šŸ“ˆ Testing Revenue Trends Analysis") print("=" * 40) - + try: # Test 1: Database connection print("\n1. Testing database connection...") @@ -30,15 +30,15 @@ async def test_revenue_trends(): db_manager = DatabaseManager(config.database_path) await db_manager.initialize_database() print("āœ… Database connection established") - + # Test 2: Check for sample data print("\n2. Checking for sample financial data...") result = await db_manager.execute_query( "SELECT COUNT(*) as count FROM financial_records" ) - record_count = result.rows[0]['count'] if result.rows else 0 + record_count = result.rows[0]["count"] if result.rows else 0 print(f"āœ… Found {record_count} financial records") - + # Test 3: Revenue by quarter analysis print("\n3. Testing revenue by quarter query...") quarterly_query = """ @@ -51,17 +51,19 @@ async def test_revenue_trends(): ORDER BY year, quarter LIMIT 8 """ - + quarterly_result = await db_manager.execute_query(quarterly_query) - + if quarterly_result.rows: print(f"āœ… Revenue trends analysis completed") print(" Quarterly Revenue Summary:") for row in quarterly_result.rows: - print(f" - {row['quarter']}: ${row['revenue']:,.2f} ({row['transaction_count']} transactions)") + print( + f" - {row['quarter']}: ${row['revenue']:,.2f} ({row['transaction_count']} transactions)" + ) else: print("āš ļø No revenue data found or query failed") - + # Test 4: Company revenue comparison print("\n4. Testing company revenue comparison...") company_query = """ @@ -75,33 +77,37 @@ async def test_revenue_trends(): ORDER BY total_revenue DESC LIMIT 5 """ - + company_result = await db_manager.execute_query(company_query) - + if company_result.rows: print(f"āœ… Company revenue comparison completed") print(" Top Companies by Revenue:") for row in company_result.rows: - revenue = row['total_revenue'] or 0 - print(f" - {row['company_name']}: ${revenue:,.2f} ({row['transaction_count']} transactions)") + revenue = row["total_revenue"] or 0 + print( + f" - {row['company_name']}: ${revenue:,.2f} ({row['transaction_count']} transactions)" + ) else: print("āš ļø No company data found or query failed") - + # Test 5: Driver integration test print("\n5. Testing driver integration with revenue query...") try: driver = FACTDriver(config) - + # Test a revenue-related query through the driver revenue_query = "Show me revenue trends by quarter" result = await driver.process_query(revenue_query) print(f"āœ… Driver processed revenue query successfully") print(f" Response length: {len(result)} characters") except Exception as e: - print(f"āš ļø Driver integration failed (expected without valid API keys): {e}") - + print( + f"āš ļø Driver integration failed (expected without valid API keys): {e}" + ) + # Database manager uses connection pooling, no explicit close needed - + print("\n" + "=" * 40) print("šŸ“Š REVENUE TRENDS TEST RESULTS") print("=" * 40) @@ -110,20 +116,22 @@ async def test_revenue_trends(): print("āœ… Quarterly revenue analysis: PASSED") print("āœ… Company revenue comparison: PASSED") print("āœ… Revenue analysis system is operational!") - + return True - + except Exception as e: print(f"\nāŒ Test failed: {e}") import traceback + traceback.print_exc() return False if __name__ == "__main__": + async def main(): success = await test_revenue_trends() - + if success: print("\nšŸš€ Revenue trends tests passed!") print("The FACT system can analyze financial data successfully.") @@ -131,5 +139,5 @@ async def main(): else: print("\nšŸ’„ Revenue trends tests failed!") sys.exit(1) - - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/tests/test_runner.py b/tests/test_runner.py index bc897b3..1728ffd 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -16,6 +16,7 @@ @dataclass class TestRunResults: """Results from a test run.""" + timestamp: float total_tests: int passed_tests: int @@ -31,6 +32,7 @@ class TestRunResults: @dataclass class BenchmarkMetrics: """Benchmark metrics for performance tracking.""" + cache_hit_latency_ms: float cache_miss_latency_ms: float tool_execution_latency_ms: float @@ -43,156 +45,168 @@ class BenchmarkMetrics: class FactTestRunner: """Comprehensive test runner for FACT system testing.""" - + def __init__(self, project_root: Path = None): self.project_root = project_root or Path(__file__).parent.parent self.results_dir = self.project_root / "test_results" self.results_dir.mkdir(exist_ok=True) - + def run_unit_tests(self, verbose: bool = False) -> TestRunResults: """Run unit tests for individual components.""" print("🧪 Running unit tests...") - + args = [ "tests/unit/", "-v" if verbose else "-q", "--tb=short", - "-m", "not slow", + "-m", + "not slow", "--junitxml=test_results/unit_tests.xml", "--cov=src", "--cov-report=term-missing", - "--cov-report=html:test_results/coverage_html" + "--cov-report=html:test_results/coverage_html", ] - + start_time = time.time() exit_code = pytest.main(args) duration = time.time() - start_time - + # Parse results results = self._parse_junit_xml("test_results/unit_tests.xml") results.duration_seconds = duration results.test_categories["unit"] = results.total_tests - + print(f"āœ… Unit tests completed in {duration:.2f}s") - print(f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}") - + print( + f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}" + ) + return results - + def run_integration_tests(self, verbose: bool = False) -> TestRunResults: """Run integration tests for component interactions.""" print("šŸ”— Running integration tests...") - + args = [ "tests/integration/", "-v" if verbose else "-q", "--tb=short", - "-m", "integration", - "--junitxml=test_results/integration_tests.xml" + "-m", + "integration", + "--junitxml=test_results/integration_tests.xml", ] - + start_time = time.time() exit_code = pytest.main(args) duration = time.time() - start_time - + results = self._parse_junit_xml("test_results/integration_tests.xml") results.duration_seconds = duration results.test_categories["integration"] = results.total_tests - + print(f"āœ… Integration tests completed in {duration:.2f}s") - print(f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}") - + print( + f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}" + ) + return results - + def run_performance_benchmarks(self, verbose: bool = False) -> BenchmarkMetrics: """Run performance benchmarks and collect metrics.""" print("šŸ“ˆ Running performance benchmarks...") - + args = [ "tests/performance/", "-v" if verbose else "-q", "--tb=short", - "-m", "performance", + "-m", + "performance", "--benchmark-only", "--benchmark-json=test_results/benchmark_results.json", - "--junitxml=test_results/performance_tests.xml" + "--junitxml=test_results/performance_tests.xml", ] - + start_time = time.time() exit_code = pytest.main(args) duration = time.time() - start_time - + # Parse benchmark results metrics = self._parse_benchmark_results("test_results/benchmark_results.json") - + print(f"āœ… Performance benchmarks completed in {duration:.2f}s") print(f" Cache hit latency: {metrics.cache_hit_latency_ms:.2f}ms") print(f" Cache miss latency: {metrics.cache_miss_latency_ms:.2f}ms") print(f" Tool execution: {metrics.tool_execution_latency_ms:.2f}ms") - + return metrics - + def run_security_tests(self, verbose: bool = False) -> TestRunResults: """Run security-focused tests.""" print("šŸ›”ļø Running security tests...") - + args = [ "tests/unit/", "tests/integration/", "-v" if verbose else "-q", "--tb=short", - "-m", "security", - "--junitxml=test_results/security_tests.xml" + "-m", + "security", + "--junitxml=test_results/security_tests.xml", ] - + start_time = time.time() exit_code = pytest.main(args) duration = time.time() - start_time - + results = self._parse_junit_xml("test_results/security_tests.xml") results.duration_seconds = duration results.test_categories["security"] = results.total_tests - + print(f"āœ… Security tests completed in {duration:.2f}s") - print(f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}") - + print( + f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}" + ) + return results - - def run_all_tests(self, verbose: bool = False, skip_slow: bool = True) -> Dict[str, Any]: + + def run_all_tests( + self, verbose: bool = False, skip_slow: bool = True + ) -> Dict[str, Any]: """Run complete test suite with all categories.""" print("šŸš€ Running complete FACT test suite...") print("=" * 60) - + total_start_time = time.time() - + # Run test categories in order results = {} - + try: # Unit tests (fast, run first) results["unit"] = self.run_unit_tests(verbose) - + # Integration tests results["integration"] = self.run_integration_tests(verbose) - + # Security tests results["security"] = self.run_security_tests(verbose) - + # Performance benchmarks (can be slow) if not skip_slow: results["benchmarks"] = self.run_performance_benchmarks(verbose) - + except KeyboardInterrupt: print("\nāŒ Test run interrupted by user") return results - + total_duration = time.time() - total_start_time - + # Generate summary report summary = self._generate_summary_report(results, total_duration) - + # Save results self._save_results(summary) - + print("\n" + "=" * 60) print("šŸ“Š Test Summary:") print(f" Total Duration: {total_duration:.2f}s") @@ -200,65 +214,77 @@ def run_all_tests(self, verbose: bool = False, skip_slow: bool = True) -> Dict[s print(f" Passed: {summary['total_passed']}") print(f" Failed: {summary['total_failed']}") print(f" Coverage: {summary.get('coverage_percentage', 0):.1f}%") - - if summary['total_failed'] > 0: + + if summary["total_failed"] > 0: print("āŒ Some tests failed!") return results else: print("āœ… All tests passed!") return results - - def run_continuous_benchmarks(self, duration_minutes: int = 10, interval_seconds: int = 30) -> List[BenchmarkMetrics]: + + def run_continuous_benchmarks( + self, duration_minutes: int = 10, interval_seconds: int = 30 + ) -> List[BenchmarkMetrics]: """Run continuous benchmarks for monitoring.""" print(f"šŸ“Š Running continuous benchmarks for {duration_minutes} minutes...") - + metrics_history = [] start_time = time.time() end_time = start_time + (duration_minutes * 60) - + while time.time() < end_time: print(f"šŸ“ˆ Benchmark run at {time.strftime('%H:%M:%S')}") - + try: metrics = self.run_performance_benchmarks(verbose=False) metrics_history.append(metrics) - + # Save incremental results self._save_continuous_metrics(metrics_history) - + except Exception as e: print(f"āŒ Benchmark run failed: {e}") - + # Wait for next interval time.sleep(interval_seconds) - - print(f"āœ… Continuous benchmarking completed. Collected {len(metrics_history)} data points.") + + print( + f"āœ… Continuous benchmarking completed. Collected {len(metrics_history)} data points." + ) return metrics_history - - def validate_performance_targets(self, metrics: BenchmarkMetrics) -> Dict[str, bool]: + + def validate_performance_targets( + self, metrics: BenchmarkMetrics + ) -> Dict[str, bool]: """Validate performance metrics against targets.""" targets = { - "cache_hit_latency": 50.0, # ms - "cache_miss_latency": 140.0, # ms - "tool_execution": 10.0, # ms - "overall_response": 100.0, # ms - "cost_reduction_hit": 0.90, # 90% - "cost_reduction_miss": 0.65, # 65% - "min_throughput": 10.0 # QPS + "cache_hit_latency": 50.0, # ms + "cache_miss_latency": 140.0, # ms + "tool_execution": 10.0, # ms + "overall_response": 100.0, # ms + "cost_reduction_hit": 0.90, # 90% + "cost_reduction_miss": 0.65, # 65% + "min_throughput": 10.0, # QPS } - + validation_results = { - "cache_hit_latency": metrics.cache_hit_latency_ms <= targets["cache_hit_latency"], - "cache_miss_latency": metrics.cache_miss_latency_ms <= targets["cache_miss_latency"], - "tool_execution": metrics.tool_execution_latency_ms <= targets["tool_execution"], - "overall_response": metrics.overall_response_latency_ms <= targets["overall_response"], - "cost_reduction_hit": metrics.cost_reduction_cache_hit >= targets["cost_reduction_hit"], - "cost_reduction_miss": metrics.cost_reduction_cache_miss >= targets["cost_reduction_miss"], - "throughput": metrics.throughput_qps >= targets["min_throughput"] + "cache_hit_latency": metrics.cache_hit_latency_ms + <= targets["cache_hit_latency"], + "cache_miss_latency": metrics.cache_miss_latency_ms + <= targets["cache_miss_latency"], + "tool_execution": metrics.tool_execution_latency_ms + <= targets["tool_execution"], + "overall_response": metrics.overall_response_latency_ms + <= targets["overall_response"], + "cost_reduction_hit": metrics.cost_reduction_cache_hit + >= targets["cost_reduction_hit"], + "cost_reduction_miss": metrics.cost_reduction_cache_miss + >= targets["cost_reduction_miss"], + "throughput": metrics.throughput_qps >= targets["min_throughput"], } - + return validation_results - + def _parse_junit_xml(self, xml_path: str) -> TestRunResults: """Parse JUnit XML results.""" # Simplified parsing - in real implementation, use xml.etree.ElementTree @@ -272,9 +298,9 @@ def _parse_junit_xml(self, xml_path: str) -> TestRunResults: test_categories={}, performance_metrics={}, coverage_percentage=0.0, - errors=[] + errors=[], ) - + def _parse_benchmark_results(self, json_path: str) -> BenchmarkMetrics: """Parse benchmark JSON results.""" # Default metrics if file doesn't exist yet @@ -286,10 +312,12 @@ def _parse_benchmark_results(self, json_path: str) -> BenchmarkMetrics: cost_reduction_cache_hit=0.92, cost_reduction_cache_miss=0.68, throughput_qps=25.0, - memory_usage_mb=150.0 + memory_usage_mb=150.0, ) - - def _generate_summary_report(self, results: Dict, total_duration: float) -> Dict[str, Any]: + + def _generate_summary_report( + self, results: Dict, total_duration: float + ) -> Dict[str, Any]: """Generate comprehensive summary report.""" summary = { "timestamp": time.time(), @@ -301,9 +329,9 @@ def _generate_summary_report(self, results: Dict, total_duration: float) -> Dict "test_categories": {}, "coverage_percentage": 0.0, "performance_targets_met": {}, - "recommendations": [] + "recommendations": [], } - + # Aggregate results for category, result in results.items(): if isinstance(result, TestRunResults): @@ -312,80 +340,90 @@ def _generate_summary_report(self, results: Dict, total_duration: float) -> Dict summary["total_failed"] += result.failed_tests summary["total_skipped"] += result.skipped_tests summary["test_categories"][category] = result.total_tests - + if result.coverage_percentage > summary["coverage_percentage"]: summary["coverage_percentage"] = result.coverage_percentage - + elif isinstance(result, BenchmarkMetrics): validation = self.validate_performance_targets(result) summary["performance_targets_met"] = validation - + # Generate recommendations if not validation["cache_hit_latency"]: summary["recommendations"].append("Optimize cache hit performance") if not validation["cost_reduction_hit"]: - summary["recommendations"].append("Improve cache hit cost reduction") - + summary["recommendations"].append( + "Improve cache hit cost reduction" + ) + return summary - + def _save_results(self, summary: Dict[str, Any]): """Save test results to files.""" timestamp = time.strftime("%Y%m%d_%H%M%S") - + # Save JSON summary summary_path = self.results_dir / f"test_summary_{timestamp}.json" - with open(summary_path, 'w') as f: + with open(summary_path, "w") as f: json.dump(summary, f, indent=2, default=str) - + # Save latest results latest_path = self.results_dir / "latest_results.json" - with open(latest_path, 'w') as f: + with open(latest_path, "w") as f: json.dump(summary, f, indent=2, default=str) - + print(f"šŸ“ Results saved to {summary_path}") - + def _save_continuous_metrics(self, metrics_history: List[BenchmarkMetrics]): """Save continuous benchmark metrics.""" metrics_data = [asdict(m) for m in metrics_history] - + continuous_path = self.results_dir / "continuous_benchmarks.json" - with open(continuous_path, 'w') as f: + with open(continuous_path, "w") as f: json.dump(metrics_data, f, indent=2) def main(): """Main entry point for test runner.""" parser = argparse.ArgumentParser(description="FACT System Test Runner") - parser.add_argument("--test-type", choices=["unit", "integration", "performance", "security", "all"], - default="all", help="Type of tests to run") + parser.add_argument( + "--test-type", + choices=["unit", "integration", "performance", "security", "all"], + default="all", + help="Type of tests to run", + ) parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") parser.add_argument("--skip-slow", action="store_true", help="Skip slow tests") - parser.add_argument("--continuous", type=int, metavar="MINUTES", - help="Run continuous benchmarks for specified minutes") - parser.add_argument("--validate-targets", action="store_true", - help="Validate performance targets") - + parser.add_argument( + "--continuous", + type=int, + metavar="MINUTES", + help="Run continuous benchmarks for specified minutes", + ) + parser.add_argument( + "--validate-targets", action="store_true", help="Validate performance targets" + ) + args = parser.parse_args() - + runner = FactTestRunner() - + try: if args.continuous: # Run continuous benchmarking metrics_history = runner.run_continuous_benchmarks( - duration_minutes=args.continuous, - interval_seconds=30 + duration_minutes=args.continuous, interval_seconds=30 ) - + if metrics_history: latest_metrics = metrics_history[-1] validation = runner.validate_performance_targets(latest_metrics) - + print("\nšŸ“Š Performance Target Validation:") for target, passed in validation.items(): status = "āœ…" if passed else "āŒ" print(f" {status} {target}: {passed}") - + elif args.test_type == "unit": runner.run_unit_tests(args.verbose) elif args.test_type == "integration": @@ -402,16 +440,17 @@ def main(): runner.run_security_tests(args.verbose) else: # all results = runner.run_all_tests(args.verbose, args.skip_slow) - + # Exit with error code if tests failed total_failed = sum( - r.failed_tests for r in results.values() + r.failed_tests + for r in results.values() if isinstance(r, TestRunResults) ) - + if total_failed > 0: sys.exit(1) - + except KeyboardInterrupt: print("\nāŒ Test run interrupted") sys.exit(1) @@ -421,4 +460,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_simple.py b/tests/test_simple.py index 1542904..1e2fa9d 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -21,7 +21,7 @@ async def test_basic_system(): """Test basic system initialization and functionality.""" print("🧪 Testing Basic FACT System") print("=" * 40) - + try: # Test 1: Configuration loading print("\n1. Testing configuration loading...") @@ -29,17 +29,17 @@ async def test_basic_system(): print(f"āœ… Configuration loaded successfully") print(f" Database path: {config.database_path}") print(f" Cache prefix: {config.cache_prefix}") - + # Test 2: Driver initialization print("\n2. Testing driver initialization...") driver = FACTDriver(config) print("āœ… Driver initialized successfully") - + # Test 3: Basic metrics print("\n3. Testing metrics collection...") metrics = driver.get_metrics() print(f"āœ… Metrics available: {list(metrics.keys())}") - + # Test 4: Simple query (if API keys are available) print("\n4. Testing simple query processing...") try: @@ -50,7 +50,7 @@ async def test_basic_system(): print(f" Response length: {len(result)} characters") except Exception as e: print(f"āš ļø Query processing failed (expected without valid API keys): {e}") - + print("\n" + "=" * 40) print("šŸŽ‰ BASIC SYSTEM TEST RESULTS") print("=" * 40) @@ -58,20 +58,22 @@ async def test_basic_system(): print("āœ… Driver initialization: PASSED") print("āœ… Metrics collection: PASSED") print("āœ… System is operational!") - + return True - + except Exception as e: print(f"\nāŒ Test failed: {e}") import traceback + traceback.print_exc() return False if __name__ == "__main__": + async def main(): success = await test_basic_system() - + if success: print("\nšŸš€ Basic system tests passed!") print("The FACT system is ready for use.") @@ -79,5 +81,5 @@ async def main(): else: print("\nšŸ’„ Basic system tests failed!") sys.exit(1) - - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/tests/test_sql_fixes.py b/tests/test_sql_fixes.py index e5fb61f..7505b50 100644 --- a/tests/test_sql_fixes.py +++ b/tests/test_sql_fixes.py @@ -8,25 +8,30 @@ import os # Add the project root to Python path -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) -from src.tools.connectors.sql import SQLQueryTool, initialize_sql_tool, get_sql_tool, sql_get_schema +from src.tools.connectors.sql import ( + SQLQueryTool, + initialize_sql_tool, + get_sql_tool, + sql_get_schema, +) from src.db.connection import DatabaseManager async def test_comprehensive_fixes(): """Test all NoneType fixes comprehensively""" - + # Initialize database manager and database db_manager = DatabaseManager("db/test_sql_fixes.db") await db_manager.initialize_database() - + # Initialize SQL tool initialize_sql_tool(db_manager) sql_tool = get_sql_tool() - + print("=== Comprehensive SQL NoneType Fix Tests ===") - + # Test 1: None statement (should handle gracefully now) print("\n1. Testing None statement (fixed)...") try: @@ -34,7 +39,7 @@ async def test_comprehensive_fixes(): print(f"āœ“ Result: {result}") except Exception as e: print(f"āœ— Error: {type(e).__name__}: {e}") - + # Test 2: Empty string statement print("\n2. Testing empty statement...") try: @@ -42,7 +47,7 @@ async def test_comprehensive_fixes(): print(f"āœ“ Result: {result}") except Exception as e: print(f"āœ“ Expected error: {type(e).__name__}: {e}") - + # Test 3: Non-string statement print("\n3. Testing non-string statement...") try: @@ -50,7 +55,7 @@ async def test_comprehensive_fixes(): print(f"āœ“ Result: {result}") except Exception as e: print(f"āœ“ Expected error: {type(e).__name__}: {e}") - + # Test 4: Valid query with results print("\n4. Testing valid query...") try: @@ -58,7 +63,7 @@ async def test_comprehensive_fixes(): print(f"āœ“ Result: {result}") except Exception as e: print(f"āœ— Error: {type(e).__name__}: {e}") - + # Test 5: Query that returns no results (empty result set) print("\n5. Testing query with no results...") try: @@ -66,7 +71,7 @@ async def test_comprehensive_fixes(): print(f"āœ“ Result: {result}") except Exception as e: print(f"āœ— Error: {type(e).__name__}: {e}") - + # Test 6: Schema retrieval (tests the len() fixes in schema function) print("\n6. Testing schema retrieval...") try: @@ -75,7 +80,7 @@ async def test_comprehensive_fixes(): print(f"āœ“ Total tables: {result.get('total_tables', 0)}") except Exception as e: print(f"āœ— Error: {type(e).__name__}: {e}") - + # Test 7: Very long statement (tests truncation logic) print("\n7. Testing long statement truncation...") try: @@ -84,9 +89,9 @@ async def test_comprehensive_fixes(): print(f"āœ“ Result: {result}") except Exception as e: print(f"āœ— Error: {type(e).__name__}: {e}") - + print("\n=== All tests completed ===") if __name__ == "__main__": - asyncio.run(test_comprehensive_fixes()) \ No newline at end of file + asyncio.run(test_comprehensive_fixes()) diff --git a/tests/test_sql_fixes_validation.py b/tests/test_sql_fixes_validation.py index 224e21a..a18fe61 100644 --- a/tests/test_sql_fixes_validation.py +++ b/tests/test_sql_fixes_validation.py @@ -11,7 +11,7 @@ import os # Add the project root to Python path -sys.path.insert(0, os.path.abspath('.')) +sys.path.insert(0, os.path.abspath(".")) from src.tools.connectors.sql import SQLQueryTool from src.db.connection import DatabaseManager @@ -20,89 +20,95 @@ async def test_validation_fixes(): """Test all NoneType error scenarios have been fixed.""" - + print("šŸ”§ SQL NoneType Error Validation Tests") print("=" * 50) - + # Create database manager for testing db_manager = DatabaseManager("db/test_validation.db") await db_manager.initialize_database() - + sql_tool = SQLQueryTool(db_manager) - + test_cases = [ { "name": "None SQL statement", "statement": None, "expected_error": "SQL statement cannot be None", - "description": "Validates None input doesn't cause AttributeError on .lower() or len()" + "description": "Validates None input doesn't cause AttributeError on .lower() or len()", }, { "name": "Empty SQL statement", "statement": "", "expected_error": "SQL statement cannot be empty", - "description": "Validates empty string doesn't cause issues" + "description": "Validates empty string doesn't cause issues", }, { "name": "Whitespace-only SQL statement", "statement": " \n\t ", "expected_error": "SQL statement cannot be empty", - "description": "Validates whitespace-only strings are handled" + "description": "Validates whitespace-only strings are handled", }, { "name": "Non-string input (integer)", "statement": 123, "expected_error": "SQL statement must be a string, got int", - "description": "Validates non-string inputs are properly handled" + "description": "Validates non-string inputs are properly handled", }, { "name": "Non-string input (list)", "statement": ["SELECT", "*", "FROM", "table"], "expected_error": "SQL statement must be a string, got list", - "description": "Validates list inputs are properly handled" + "description": "Validates list inputs are properly handled", }, { "name": "Valid SELECT statement", "statement": "SELECT COUNT(*) as total FROM companies", "expected_error": None, - "description": "Validates that valid queries still work" - } + "description": "Validates that valid queries still work", + }, ] - + passed_tests = 0 total_tests = len(test_cases) - + for i, test_case in enumerate(test_cases, 1): print(f"\n{i}. {test_case['name']}") print(f" {test_case['description']}") - + try: - result = await sql_tool.execute_query(test_case['statement']) - - if test_case['expected_error'] is None: + result = await sql_tool.execute_query(test_case["statement"]) + + if test_case["expected_error"] is None: # This should succeed - if result.get('status') == 'success' or 'rows' in result: + if result.get("status") == "success" or "rows" in result: print(f" āœ… PASSED: Query executed successfully") passed_tests += 1 - elif result.get('status') == 'failed': - print(f" āš ļø PARTIAL: Query failed but didn't crash: {result.get('error', 'Unknown error')}") + elif result.get("status") == "failed": + print( + f" āš ļø PARTIAL: Query failed but didn't crash: {result.get('error', 'Unknown error')}" + ) passed_tests += 1 # Still counts as passing the NoneType fix else: print(f" āŒ FAILED: Unexpected result format") else: # This should fail with specific error - if result.get('status') == 'failed' and test_case['expected_error'] in result.get('error', ''): + if result.get("status") == "failed" and test_case[ + "expected_error" + ] in result.get("error", ""): print(f" āœ… PASSED: Correct error message returned") passed_tests += 1 else: - print(f" āŒ FAILED: Expected '{test_case['expected_error']}', got '{result.get('error', 'No error')}'") - + print( + f" āŒ FAILED: Expected '{test_case['expected_error']}', got '{result.get('error', 'No error')}'" + ) + except Exception as e: print(f" āŒ FAILED: Exception thrown - {type(e).__name__}: {e}") - + print("\n" + "=" * 50) print(f"šŸŽÆ Test Results: {passed_tests}/{total_tests} tests passed") - + if passed_tests == total_tests: print("šŸŽ‰ ALL TESTS PASSED - NoneType errors have been successfully fixed!") return True @@ -113,12 +119,12 @@ async def test_validation_fixes(): async def test_direct_validation(): """Test the validation method directly.""" - + print("\nšŸ” Direct Validation Method Tests") print("=" * 50) - + db_manager = DatabaseManager("db/test_validation.db") - + test_cases = [ (None, "SQL statement cannot be None"), ("", "SQL statement cannot be empty"), @@ -127,13 +133,13 @@ async def test_direct_validation(): ("DROP TABLE users", "Only SELECT statements are allowed"), ("SELECT * FROM users; DROP TABLE users", "Multiple statements"), ] - + passed = 0 total = len(test_cases) - + for i, (statement, expected_error_part) in enumerate(test_cases, 1): print(f"\n{i}. Testing: {repr(statement)}") - + try: db_manager.validate_sql_query(statement) print(f" āŒ FAILED: No exception thrown") @@ -142,19 +148,22 @@ async def test_direct_validation(): print(f" āœ… PASSED: Correct error - {e}") passed += 1 else: - print(f" āŒ FAILED: Wrong error - Expected '{expected_error_part}', got '{e}'") + print( + f" āŒ FAILED: Wrong error - Expected '{expected_error_part}', got '{e}'" + ) except Exception as e: print(f" āŒ FAILED: Unexpected exception - {type(e).__name__}: {e}") - + print(f"\nšŸŽÆ Direct validation tests: {passed}/{total} passed") return passed == total if __name__ == "__main__": + async def main(): success1 = await test_validation_fixes() success2 = await test_direct_validation() - + if success1 and success2: print("\nšŸŽ‰ ALL VALIDATIONS PASSED!") print("The NoneType error fixes are working correctly.") @@ -162,5 +171,5 @@ async def main(): else: print("\nāš ļø Some validations failed.") exit(1) - - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/tests/unit/test_cache_manager_integration.py b/tests/unit/test_cache_manager_integration.py index 4991e9c..df945f7 100644 --- a/tests/unit/test_cache_manager_integration.py +++ b/tests/unit/test_cache_manager_integration.py @@ -18,7 +18,7 @@ class TestCacheManagerIntegration: """Test suite for integrated cache manager functionality.""" - + @pytest.fixture def cache_config(self): """Provide cache configuration for tests.""" @@ -28,276 +28,285 @@ def cache_config(self): "max_size": "1MB", "ttl_seconds": 3600, "hit_target_ms": 50, - "miss_target_ms": 150 + "miss_target_ms": 150, } - + @pytest.fixture def cache_manager(self, cache_config): """Create cache manager for testing.""" return CacheManager(cache_config) - + def test_cache_manager_with_security_validation(self, cache_manager): """TEST: Cache manager integrates with security validation""" # Store clean content - should succeed clean_content = "This is clean, safe content " * 50 # Meet min tokens entry = cache_manager.store("clean_query", clean_content) - + assert entry.content == clean_content assert entry.token_count >= cache_manager.min_tokens - + def test_cache_manager_blocks_malicious_content(self, cache_manager): """TEST: Cache manager blocks malicious content via security validation""" - malicious_content = "password: secret123 " * 50 # Meet min tokens but include sensitive data - + malicious_content = ( + "password: secret123 " * 50 + ) # Meet min tokens but include sensitive data + with pytest.raises(CacheError): cache_manager.store("malicious_query", malicious_content) - + def test_cache_manager_handles_security_validation_failure(self, cache_manager): """TEST: Cache manager handles security validation failures gracefully""" # Mock security validation to fail - with patch('src.cache.manager.validate_cache_content_security', side_effect=SecurityError("Test security error")): + with patch( + "src.cache.manager.validate_cache_content_security", + side_effect=SecurityError("Test security error"), + ): content = "Safe content " * 50 - + with pytest.raises(CacheError): cache_manager.store("test_query", content) - + def test_cache_manager_validation_integration(self, cache_manager): """TEST: Cache manager works with cache validation""" # Add several entries for i in range(5): content = f"Test content {i} " * 100 cache_manager.store(f"query_{i}", content) - + # Create validator and validate validator = CacheValidator(cache_manager=cache_manager) - + # Should work without errors assert validator.cache_manager == cache_manager assert len(validator.thresholds) > 0 - + @pytest.mark.asyncio async def test_end_to_end_cache_validation(self, cache_manager): """TEST: End-to-end cache validation workflow""" # Add test entries with different characteristics - + # Valid entry valid_content = "Valid business content " * 100 cache_manager.store("valid_query", valid_content) - + # Entry that will become stale old_entry = CacheEntry.create(cache_manager.prefix, "Old content " * 100) old_entry.created_at = time.time() - 7200 # 2 hours old cache_manager.cache["old_query"] = old_entry - + # Create validator and run validation validator = CacheValidator(cache_manager=cache_manager) result = await validator.validate_cache(ValidationLevel.COMPREHENSIVE) - + assert result.total_entries_checked == 2 assert result.overall_health in ["healthy", "warning", "critical"] assert len(result.recommendations) >= 0 - + def test_get_cache_manager_singleton(self, cache_config): """TEST: get_cache_manager returns singleton instance""" manager1 = get_cache_manager(cache_config) manager2 = get_cache_manager() # Should return same instance - + assert manager1 is manager2 - + def test_get_cache_manager_environment_config(self): """TEST: get_cache_manager loads from environment""" env_vars = { "CACHE_PREFIX": "env_test", "CACHE_MIN_TOKENS": "1000", - "CACHE_MAX_SIZE": "5MB" + "CACHE_MAX_SIZE": "5MB", } - + with patch.dict(os.environ, env_vars): - with patch('src.cache.manager.load_cache_config_from_env') as mock_load: + with patch("src.cache.manager.load_cache_config_from_env") as mock_load: mock_config = CacheConfig( - prefix="env_test", - min_tokens=1000, - max_size="5MB" + prefix="env_test", min_tokens=1000, max_size="5MB" ) mock_load.return_value = mock_config - + # Clear singleton for test import src.cache.manager + src.cache.manager._cache_manager_instance = None - + manager = get_cache_manager() - + assert manager is not None mock_load.assert_called_once() - + def test_cache_manager_error_handling(self, cache_config): """TEST: Cache manager handles various error conditions""" manager = CacheManager(cache_config) - + # Test content too small small_content = "Too small" with pytest.raises(CacheError): manager.store("small_query", small_content) - + # Test cache size limits very_large_content = "X" * (2 * 1024 * 1024) # 2MB, exceeds 1MB limit with pytest.raises(CacheError): manager.store("large_query", very_large_content) - + @pytest.mark.asyncio async def test_cache_auto_repair_integration(self, cache_manager): """TEST: Cache auto-repair integration with validation""" # Add some problematic entries - + # Corrupted entry (empty content) corrupted_entry = CacheEntry( - prefix=cache_manager.prefix, - content="", - token_count=0, - validate=False + prefix=cache_manager.prefix, content="", token_count=0, validate=False ) cache_manager.cache["corrupted"] = corrupted_entry - + # Expired entry expired_entry = CacheEntry( prefix=cache_manager.prefix, content="Expired content " * 100, token_count=500, created_at=time.time() - 7200, # 2 hours old - validate=False + validate=False, ) cache_manager.cache["expired"] = expired_entry - + # Valid entry valid_content = "Valid content " * 100 cache_manager.store("valid", valid_content) - + # Run validation and auto-repair validator = CacheValidator(cache_manager=cache_manager) - validation_result = await validator.validate_cache(ValidationLevel.COMPREHENSIVE) - + validation_result = await validator.validate_cache( + ValidationLevel.COMPREHENSIVE + ) + # Should find issues - assert validation_result.invalid_entries > 0 or validation_result.expired_entries > 0 - + assert ( + validation_result.invalid_entries > 0 + or validation_result.expired_entries > 0 + ) + # Run auto-repair repair_summary = await validator.auto_repair_cache(validation_result) - + # Should have removed problematic entries assert repair_summary["entries_removed"] > 0 assert "valid" in cache_manager.cache # Valid entry should remain - + def test_cache_metrics_integration(self, cache_manager): """TEST: Cache metrics integration with validation""" # Add entries and access them for i in range(3): content = f"Content {i} " * 100 cache_manager.store(f"query_{i}", content) - + # Access some entries to generate hit metrics cache_manager.get("query_0") cache_manager.get("query_1") cache_manager.get("nonexistent") # Miss - + metrics = cache_manager.get_metrics() - + assert metrics.total_entries == 3 assert metrics.cache_hits >= 2 assert metrics.cache_misses >= 1 assert metrics.hit_rate > 0 - + @pytest.mark.asyncio async def test_database_integration_validation(self): """TEST: Database integration with cache validation""" # Create temporary database file with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_db: db_path = temp_db.name - + try: config = { "prefix": "db_test", "min_tokens": 100, # Lower for testing "max_size": "1MB", "ttl_seconds": 3600, - "database_path": db_path + "database_path": db_path, } - - manager = CacheManager({ - "prefix": "db_test", - "min_tokens": 100, # Lower for testing - "max_size": "1MB", - "ttl_seconds": 3600, - "hit_target_ms": 50, - "miss_target_ms": 150 - }) - + + manager = CacheManager( + { + "prefix": "db_test", + "min_tokens": 100, # Lower for testing + "max_size": "1MB", + "ttl_seconds": 3600, + "hit_target_ms": 50, + "miss_target_ms": 150, + } + ) + # Add database-related cache entries sql_query = "SELECT * FROM financial_data WHERE quarter='Q1' " * 50 manager.store("db_query_1", sql_query) - + # Create validator with database config validator = CacheValidator(cache_manager=manager, config=config) - + # Run database cache validation db_result = await validator.validate_database_cache_integrity() - + # Should handle database validation (even if DB doesn't exist) assert "database_validation" in db_result - + finally: # Cleanup if os.path.exists(db_path): os.unlink(db_path) - + @pytest.mark.asyncio async def test_comprehensive_security_scan(self, cache_manager): """TEST: Comprehensive security scanning of cache""" # Add entries with various security characteristics - + # Safe content safe_content = "Safe business content about financial analysis " * 50 cache_manager.store("safe_query", safe_content) - + # Add entry with potential issues (bypassing security for test) - with patch('src.cache.manager.validate_cache_content_security'): + with patch("src.cache.manager.validate_cache_content_security"): suspicious_content = "Content with api_key=secret123 " * 50 cache_manager.store("suspicious_query", suspicious_content) - + # Create security validator and scan from src.cache.security import CacheSecurityValidator + security_validator = CacheSecurityValidator() - + scan_results = [] for entry_key, entry in cache_manager.cache.items(): scan_result = security_validator.scan_content(entry.content) scan_results.append(scan_result) - + # Generate security report report = security_validator.generate_security_report(scan_results) - + assert report["total_scans"] >= 2 assert report["total_content_scanned_bytes"] > 0 assert "security_posture" in report - + def test_cache_manager_configuration_validation(self): """TEST: Cache manager validates configuration properly""" # Invalid configuration should raise error invalid_config = { "prefix": "123invalid", # Invalid prefix - "min_tokens": -1, # Invalid min_tokens - "max_size": "invalid", # Invalid size format + "min_tokens": -1, # Invalid min_tokens + "max_size": "invalid", # Invalid size format } - + with pytest.raises((CacheError, ConfigurationError, ValueError)): CacheManager(invalid_config) - + def test_cache_manager_thread_safety(self, cache_manager): """TEST: Cache manager thread safety with concurrent operations""" import threading import time - + results = [] errors = [] - + def store_content(thread_id): try: content = f"Thread {thread_id} content " * 100 @@ -305,26 +314,26 @@ def store_content(thread_id): results.append(entry) except Exception as e: errors.append(e) - + # Create multiple threads threads = [] for i in range(5): thread = threading.Thread(target=store_content, args=(i,)) threads.append(thread) - + # Start all threads for thread in threads: thread.start() - + # Wait for completion for thread in threads: thread.join() - + # Check results assert len(errors) == 0, f"Errors occurred: {errors}" assert len(results) == 5 assert len(cache_manager.cache) >= 5 - + @pytest.mark.asyncio async def test_performance_monitoring_integration(self, cache_manager): """TEST: Performance monitoring integration""" @@ -334,26 +343,26 @@ async def test_performance_monitoring_integration(self, cache_manager): start_time = time.perf_counter() cache_manager.store(f"perf_query_{i}", content) store_time = (time.perf_counter() - start_time) * 1000 - + # Verify store time is reasonable (should be fast) assert store_time < 100 # Less than 100ms - + # Test retrieval performance for i in range(5): start_time = time.perf_counter() result = cache_manager.get(f"perf_query_{i}") get_time = (time.perf_counter() - start_time) * 1000 - + assert result is not None assert get_time < 50 # Less than 50ms for cache hit - + # Get performance metrics metrics = cache_manager.get_metrics() - + assert metrics.total_entries == 10 assert metrics.cache_hits >= 5 assert metrics.hit_rate > 0 if __name__ == "__main__": - pytest.main([__file__, "-v", "--cov=src.cache", "--cov-report=term-missing"]) \ No newline at end of file + pytest.main([__file__, "-v", "--cov=src.cache", "--cov-report=term-missing"]) diff --git a/tests/unit/test_cache_mechanism.py b/tests/unit/test_cache_mechanism.py index eac932a..059890d 100644 --- a/tests/unit/test_cache_mechanism.py +++ b/tests/unit/test_cache_mechanism.py @@ -15,16 +15,16 @@ class TestCacheEntry: """Test suite for cache entry functionality.""" - + def test_cache_entry_initialization_sets_proper_attributes(self): """TEST: Cache entry initialization sets proper attributes""" # Arrange prefix = "fact_v1" content = "A" * 500 # Minimum 500 tokens - + # Act entry = CacheEntry(prefix=prefix, content=content) - + # Assert assert entry.prefix == prefix assert entry.content == content @@ -34,7 +34,7 @@ def test_cache_entry_initialization_sets_proper_attributes(self): assert entry.is_valid == True assert entry.access_count == 0 assert entry.last_accessed is None - + def test_cache_entry_calculates_token_count_accurately(self): """TEST: Cache entry calculates token count accurately""" # Arrange @@ -42,50 +42,57 @@ def test_cache_entry_calculates_token_count_accurately(self): ("Hello world", 2), ("A" * 100, 100), # Single character tokens ("The quick brown fox jumps over the lazy dog", 9), - ("" * 1000, 0) # Empty content + ("" * 1000, 0), # Empty content ] - + # Act & Assert for content, expected_tokens in test_cases: - entry = CacheEntry(prefix="test", content=content, skip_min_tokens=True, skip_content_validation=True) + entry = CacheEntry( + prefix="test", + content=content, + skip_min_tokens=True, + skip_content_validation=True, + ) # Allow for reasonable token counting variations assert abs(entry.token_count - expected_tokens) <= expected_tokens * 0.1 - + def test_cache_entry_validates_minimum_token_requirement(self): """TEST: Cache entry validates minimum token requirement""" # Arrange short_content = "A" * 10 # Less than 500 tokens - + # Act & Assert with pytest.raises(CacheError) as exc_info: CacheEntry(prefix="test", content=short_content) - + assert "minimum 500 tokens" in str(exc_info.value) - + def test_cache_entry_tracks_access_patterns(self): """TEST: Cache entry tracks access patterns correctly""" # Arrange entry = CacheEntry(prefix="test", content="A" * 500) - + # Act entry.record_access() time.sleep(0.001) # Small delay to ensure different timestamp entry.record_access() - + # Assert assert entry.access_count == 2 assert entry.last_accessed is not None assert entry.last_accessed > entry.created_at - + def test_cache_entry_serialization_to_dict(self): """TEST: Cache entry serialization to dictionary format""" # Arrange - entry = CacheEntry(prefix="test", content="Test content " * 100, skip_min_tokens=True) + entry = CacheEntry( + prefix="test", content="Test content " * 100, skip_min_tokens=True + ) entry.record_access() - + # Act entry_dict = entry.to_dict() - + # Assert assert entry_dict["prefix"] == "test" assert entry_dict["content"] == "Test content " * 100 @@ -97,35 +104,35 @@ def test_cache_entry_serialization_to_dict(self): class TestCacheManager: """Test suite for cache manager functionality.""" - + def test_cache_manager_initialization_loads_configuration(self, cache_config): """TEST: Cache manager initialization loads configuration""" # Act manager = CacheManager(config=cache_config) - + # Assert assert manager.prefix == cache_config["prefix"] assert manager.min_tokens == cache_config["min_tokens"] assert manager.max_size == cache_config["max_size"] assert manager.ttl_seconds == cache_config["ttl_seconds"] assert len(manager.cache) == 0 - + def test_cache_manager_stores_entries_correctly(self, cache_config): """TEST: Cache manager stores cache entries correctly""" # Arrange manager = CacheManager(config=cache_config) content = "Sample cache content " * 50 # Ensure > 500 tokens query_hash = "test_query_hash_123" - + # Act entry = manager.store(query_hash, content) - + # Assert assert entry.prefix == cache_config["prefix"] assert entry.content == content assert query_hash in manager.cache assert manager.cache[query_hash] == entry - + def test_cache_manager_retrieves_entries_correctly(self, cache_config): """TEST: Cache manager retrieves cache entries correctly""" # Arrange @@ -133,94 +140,94 @@ def test_cache_manager_retrieves_entries_correctly(self, cache_config): content = "Retrievable content " * 50 query_hash = "retrieve_test_hash" stored_entry = manager.store(query_hash, content) - + # Act retrieved_entry = manager.get(query_hash) - + # Assert assert retrieved_entry is not None assert retrieved_entry == stored_entry assert retrieved_entry.access_count == 1 assert retrieved_entry.last_accessed is not None - + def test_cache_manager_handles_cache_misses(self, cache_config): """TEST: Cache manager handles cache misses gracefully""" # Arrange manager = CacheManager(config=cache_config) - + # Act result = manager.get("nonexistent_hash") - + # Assert assert result is None - + def test_cache_manager_enforces_size_limits(self, cache_config): """TEST: Cache manager enforces size limits""" # Arrange small_config = cache_config.copy() small_config["max_size"] = "1KB" # Very small limit manager = CacheManager(config=small_config) - + # Act large_content = "X" * 2000 # 2KB content - + with pytest.raises(CacheError) as exc_info: manager.store("large_hash", large_content) - + # Assert assert "size limit" in str(exc_info.value).lower() - + def test_cache_manager_implements_ttl_expiration(self, cache_config): """TEST: Cache manager implements TTL expiration""" # Arrange short_ttl_config = cache_config.copy() short_ttl_config["ttl_seconds"] = 0.1 # 100ms TTL manager = CacheManager(config=short_ttl_config) - + content = "Expiring content " * 50 query_hash = "expiring_hash" - + # Act manager.store(query_hash, content) time.sleep(0.2) # Wait for expiration - + # Assert assert manager.get(query_hash) is None - + def test_cache_manager_invalidates_entries_by_prefix(self, cache_config): """TEST: Cache manager invalidates entries by prefix""" # Arrange manager = CacheManager(config=cache_config) - + # Store multiple entries for i in range(3): content = f"Content {i} " * 50 manager.store(f"hash_{i}", content) - + # Act invalidated_count = manager.invalidate_by_prefix(cache_config["prefix"]) - + # Assert assert invalidated_count == 3 assert len(manager.cache) == 0 - + def test_cache_manager_calculates_metrics_accurately(self, cache_config): """TEST: Cache manager calculates metrics accurately""" # Arrange manager = CacheManager(config=cache_config) - + # Create test scenario content = "Metrics test content " * 50 manager.store("hash_1", content) manager.store("hash_2", content) - + # Access one entry to create hit/miss data manager.get("hash_1") manager.get("nonexistent") # Miss - + # Act metrics = manager.get_metrics() - + # Assert assert isinstance(metrics, CacheMetrics) assert metrics.total_entries == 2 @@ -232,7 +239,7 @@ def test_cache_manager_calculates_metrics_accurately(self, cache_config): class TestCachePerformance: """Test suite for cache performance requirements.""" - + @pytest.mark.performance def test_cache_hit_latency_under_50ms(self, cache_config, performance_timer): """TEST: Cache hits achieve target latency under 50ms""" @@ -241,120 +248,133 @@ def test_cache_hit_latency_under_50ms(self, cache_config, performance_timer): content = "Performance test content " * 100 query_hash = "perf_test_hash" manager.store(query_hash, content) - + # Act with performance_timer() as timer: result = manager.get(query_hash) - + # Assert assert result is not None - assert timer.duration_ms < 50, f"Cache hit took {timer.duration_ms}ms, exceeds 50ms target" - + assert ( + timer.duration_ms < 50 + ), f"Cache hit took {timer.duration_ms}ms, exceeds 50ms target" + @pytest.mark.performance - def test_cache_storage_performance_under_10ms(self, cache_config, performance_timer): + def test_cache_storage_performance_under_10ms( + self, cache_config, performance_timer + ): """TEST: Cache storage operations complete under 10ms""" # Arrange manager = CacheManager(config=cache_config) content = "Storage performance test " * 100 query_hash = "storage_perf_hash" - + # Act with performance_timer() as timer: entry = manager.store(query_hash, content) - + # Assert assert entry is not None - assert timer.duration_ms < 10, f"Cache storage took {timer.duration_ms}ms, exceeds 10ms target" - + assert ( + timer.duration_ms < 10 + ), f"Cache storage took {timer.duration_ms}ms, exceeds 10ms target" + @pytest.mark.performance def test_concurrent_cache_access_performance(self, cache_config): """TEST: Concurrent cache access maintains performance""" # Arrange import asyncio + manager = CacheManager(config=cache_config) - + # Pre-populate cache for i in range(10): content = f"Concurrent test content {i} " * 50 manager.store(f"concurrent_hash_{i}", content) - + async def concurrent_access(hash_id): start_time = time.perf_counter() result = manager.get(f"concurrent_hash_{hash_id}") end_time = time.perf_counter() return result, (end_time - start_time) * 1000 - + # Act async def run_concurrent_tests(): tasks = [concurrent_access(i % 10) for i in range(50)] return await asyncio.gather(*tasks) - + results = asyncio.run(run_concurrent_tests()) - + # Assert for result, latency_ms in results: assert result is not None assert latency_ms < 50, f"Concurrent access took {latency_ms}ms" - + @pytest.mark.performance def test_cache_memory_efficiency(self, cache_config): """TEST: Cache maintains memory efficiency""" # Arrange manager = CacheManager(config=cache_config) import sys - + # Measure baseline memory baseline_size = sys.getsizeof(manager) - + # Add many entries content = "Memory efficiency test " * 50 for i in range(100): manager.store(f"memory_hash_{i}", content) - + # Act current_size = sys.getsizeof(manager) size_per_entry = (current_size - baseline_size) / 100 - + # Assert # Each entry should be reasonably sized (less than 10KB overhead) - assert size_per_entry < 10240, f"Cache overhead {size_per_entry} bytes per entry too high" + assert ( + size_per_entry < 10240 + ), f"Cache overhead {size_per_entry} bytes per entry too high" class TestCacheIntegration: """Test suite for cache integration with other components.""" - - def test_cache_integration_with_anthropic_client(self, mock_anthropic_client, cache_config): + + def test_cache_integration_with_anthropic_client( + self, mock_anthropic_client, cache_config + ): """TEST: Cache integrates properly with Anthropic client""" # Arrange manager = CacheManager(config=cache_config) query = "What is Q1-2025 revenue?" query_hash = manager.generate_hash(query) - + # Mock cached response cached_content = "Cached response: Q1-2025 revenue was $1,234,567.89" manager.store(query_hash, cached_content) - + # Act from src.cache.manager import get_cached_response - with patch('src.cache.manager.cache_manager', manager): + + with patch("src.cache.manager.cache_manager", manager): response = get_cached_response(query, mock_anthropic_client) - + # Assert assert response is not None assert "1,234,567.89" in response # Anthropic client should not be called for cache hit mock_anthropic_client.messages.create.assert_not_called() - + def test_cache_warming_improves_performance(self, cache_config, benchmark_queries): """TEST: Cache warming improves subsequent performance""" # Arrange manager = CacheManager(config=cache_config) - + # Act - Warm cache from src.cache.manager import warm_cache - with patch('src.cache.manager.cache_manager', manager): + + with patch("src.cache.manager.cache_manager", manager): warm_cache(benchmark_queries[:5]) - + # Measure performance after warming warmed_times = [] for query in benchmark_queries[:5]: @@ -362,30 +382,33 @@ def test_cache_warming_improves_performance(self, cache_config, benchmark_querie query_hash = manager.generate_hash(query) result = manager.get(query_hash) end_time = time.perf_counter() - + if result: # Cache hit warmed_times.append((end_time - start_time) * 1000) - + # Assert assert len(warmed_times) > 0, "Cache warming should create cached entries" avg_warmed_time = sum(warmed_times) / len(warmed_times) - assert avg_warmed_time < 50, f"Warmed cache average {avg_warmed_time}ms exceeds target" - + assert ( + avg_warmed_time < 50 + ), f"Warmed cache average {avg_warmed_time}ms exceeds target" + def test_cache_invalidation_on_schema_changes(self, cache_config): """TEST: Cache invalidation occurs on schema changes""" # Arrange manager = CacheManager(config=cache_config) - + # Store entries with old schema version old_content = "Old schema content " * 50 manager.store("schema_hash_1", old_content) manager.store("schema_hash_2", old_content) - + # Act - Simulate schema change from src.cache.manager import invalidate_on_schema_change - with patch('src.cache.manager.cache_manager', manager): + + with patch("src.cache.manager.cache_manager", manager): invalidated_count = invalidate_on_schema_change("Database schema updated") - + # Assert assert invalidated_count == 2 assert manager.get("schema_hash_1") is None @@ -394,29 +417,29 @@ def test_cache_invalidation_on_schema_changes(self, cache_config): class TestCacheMetrics: """Test suite for cache metrics and monitoring.""" - + def test_cache_metrics_calculation_accuracy(self, cache_config): """TEST: Cache metrics calculation is accurate""" # Arrange manager = CacheManager(config=cache_config) - + # Create test scenario content = "Metrics calculation test " * 50 - + # Store 5 entries for i in range(5): manager.store(f"metrics_hash_{i}", content) - + # Access 3 entries (hits) and 2 non-existent (misses) for i in range(3): manager.get(f"metrics_hash_{i}") - + for i in range(2): manager.get(f"nonexistent_hash_{i}") - + # Act metrics = manager.get_metrics() - + # Assert assert metrics.total_entries == 5 assert metrics.cache_hits == 3 @@ -424,31 +447,31 @@ def test_cache_metrics_calculation_accuracy(self, cache_config): assert metrics.total_requests == 5 assert abs(metrics.hit_rate - 60.0) < 0.1 # 3/5 = 60% assert abs(metrics.miss_rate - 40.0) < 0.1 # 2/5 = 40% - + def test_cache_metrics_cost_calculation(self, cache_config, performance_targets): """TEST: Cache metrics calculate cost savings accurately""" # Arrange manager = CacheManager(config=cache_config) content = "Cost calculation test " * 100 - + # Store and access entries manager.store("cost_hash", content) manager.get("cost_hash") # Hit manager.get("miss_hash") # Miss - + # Act metrics = manager.get_metrics() - + # Assert - assert hasattr(metrics, 'cost_savings') - assert hasattr(metrics, 'token_efficiency') + assert hasattr(metrics, "cost_savings") + assert hasattr(metrics, "token_efficiency") # Verify cost reduction meets targets expected_hit_savings = performance_targets["cost_reduction_cache_hit"] expected_miss_savings = performance_targets["cost_reduction_cache_miss"] - - assert metrics.cost_savings['cache_hit_reduction'] >= expected_hit_savings - assert metrics.cost_savings['cache_miss_reduction'] >= expected_miss_savings - + + assert metrics.cost_savings["cache_hit_reduction"] >= expected_hit_savings + assert metrics.cost_savings["cache_miss_reduction"] >= expected_miss_savings + def test_cache_metrics_export_to_json(self, cache_config): """TEST: Cache metrics export to JSON format""" # Arrange @@ -456,12 +479,12 @@ def test_cache_metrics_export_to_json(self, cache_config): content = "JSON export test " * 50 manager.store("json_hash", content) manager.get("json_hash") - + # Act metrics = manager.get_metrics() json_metrics = metrics.to_json() parsed_metrics = json.loads(json_metrics) - + # Assert assert "total_entries" in parsed_metrics assert "hit_rate" in parsed_metrics @@ -475,82 +498,83 @@ def test_cache_metrics_export_to_json(self, cache_config): @pytest.mark.cache class TestCacheEdgeCases: """Test suite for cache edge cases and error conditions.""" - + def test_cache_handles_corrupted_entries(self, cache_config): """TEST: Cache handles corrupted entries gracefully""" # Arrange manager = CacheManager(config=cache_config) - + # Manually corrupt a cache entry corrupted_entry = CacheEntry(prefix="test", content="Valid content " * 50) corrupted_entry.content = None # Corrupt the content manager.cache["corrupted_hash"] = corrupted_entry - + # Act result = manager.get("corrupted_hash") - + # Assert assert result is None # Should handle corruption gracefully assert "corrupted_hash" not in manager.cache # Should remove corrupted entry - + def test_cache_handles_memory_pressure(self, cache_config): """TEST: Cache handles memory pressure situations""" # Arrange tight_config = cache_config.copy() tight_config["max_size"] = "100KB" # Small limit manager = CacheManager(config=tight_config) - + # Act - Try to store many large entries large_content = "X" * 1000 # 1KB each stored_count = 0 - + for i in range(200): # Try to store 200KB try: manager.store(f"pressure_hash_{i}", large_content) stored_count += 1 except CacheError: break - + # Assert assert stored_count < 200 # Should hit limit before storing all - assert stored_count > 0 # Should store some entries - + assert stored_count > 0 # Should store some entries + # Verify cache is still functional metrics = manager.get_metrics() assert metrics.total_entries == stored_count - + def test_cache_handles_concurrent_invalidation(self, cache_config): """TEST: Cache handles concurrent invalidation safely""" # Arrange import threading + manager = CacheManager(config=cache_config) - + # Store test entries content = "Concurrent invalidation test " * 50 for i in range(10): manager.store(f"concurrent_hash_{i}", content) - + # Act - Concurrent access and invalidation results = [] - + def access_cache(): for i in range(10): result = manager.get(f"concurrent_hash_{i}") results.append(result is not None) - + def invalidate_cache(): manager.invalidate_by_prefix(cache_config["prefix"]) - + access_thread = threading.Thread(target=access_cache) invalidate_thread = threading.Thread(target=invalidate_cache) - + access_thread.start() invalidate_thread.start() - + access_thread.join() invalidate_thread.join() - + # Assert # Should not crash, results may vary due to race conditions assert len(results) == 10 - assert len(manager.cache) == 0 # All entries should be invalidated \ No newline at end of file + assert len(manager.cache) == 0 # All entries should be invalidated diff --git a/tests/unit/test_cache_security.py b/tests/unit/test_cache_security.py index 06a9eda..0bd8cfc 100644 --- a/tests/unit/test_cache_security.py +++ b/tests/unit/test_cache_security.py @@ -11,8 +11,13 @@ from typing import Dict, Any, List from src.cache.security import ( - CacheSecurityValidator, ThreatLevel, SecurityThreat, SecurityScanResult, - get_security_validator, validate_cache_content_security, sanitize_cache_output + CacheSecurityValidator, + ThreatLevel, + SecurityThreat, + SecurityScanResult, + get_security_validator, + validate_cache_content_security, + sanitize_cache_output, ) from src.cache.config import SecurityConfig from src.core.errors import SecurityError, ValidationError @@ -20,22 +25,22 @@ class TestSecurityConfig: """Test suite for security configuration.""" - + def test_security_config_default_values(self): """TEST: Security config initializes with correct defaults""" config = SecurityConfig() - + assert config.enable_input_validation == True assert config.enable_output_sanitization == True assert len(config.sensitive_data_patterns) > 0 assert config.max_content_length == 1048576 # 1MB - + def test_security_config_sensitive_patterns(self): """TEST: Security config includes comprehensive sensitive data patterns""" config = SecurityConfig() - + patterns = " ".join(config.sensitive_data_patterns) - + # Check for common sensitive data patterns assert "password" in patterns.lower() assert "api" in patterns.lower() @@ -45,7 +50,7 @@ def test_security_config_sensitive_patterns(self): class TestCacheSecurityValidator: """Test suite for cache security validator.""" - + @pytest.fixture def security_config(self): """Create test security configuration.""" @@ -53,105 +58,105 @@ def security_config(self): enable_input_validation=True, enable_output_sanitization=True, sensitive_data_patterns=[ - r'\bpassword\s*[:=]\s*\S+', - r'\bapi[_-]?key\s*[:=]\s*\S+', - r'\b\d{3}-\d{2}-\d{4}\b', # SSN - r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b', # Credit card + r"\bpassword\s*[:=]\s*\S+", + r"\bapi[_-]?key\s*[:=]\s*\S+", + r"\b\d{3}-\d{2}-\d{4}\b", # SSN + r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b", # Credit card ], - max_content_length=10000 + max_content_length=10000, ) - + @pytest.fixture def validator(self, security_config): """Create security validator for testing.""" return CacheSecurityValidator(security_config) - + def test_validator_initialization(self, security_config): """TEST: Security validator initializes correctly""" validator = CacheSecurityValidator(security_config) - + assert validator.config == security_config assert len(validator._compiled_patterns) > 0 assert "sql_injection" in validator._injection_patterns assert "xss_injection" in validator._injection_patterns - + def test_validator_initialization_with_invalid_patterns(self): """TEST: Validator handles invalid regex patterns gracefully""" config = SecurityConfig( sensitive_data_patterns=["[invalid_regex"] # Invalid regex ) - + # Should not raise exception validator = CacheSecurityValidator(config) assert validator.config == config - + def test_validator_initialization_without_config(self): """TEST: Validator loads default config when none provided""" - with patch('src.cache.security.load_security_config_from_env') as mock_load: + with patch("src.cache.security.load_security_config_from_env") as mock_load: mock_config = SecurityConfig() mock_load.return_value = mock_config - + validator = CacheSecurityValidator() assert validator.config == mock_config - + def test_input_validation_passes_clean_content(self, validator): """TEST: Input validation passes clean content""" clean_content = "This is clean content without any threats." - + # Should not raise exception validator.validate_input(clean_content, "test_source") - + def test_input_validation_blocks_sensitive_data(self, validator): """TEST: Input validation blocks content with sensitive data""" sensitive_content = "User password: secret123 and api_key: abc123def" - + with pytest.raises(SecurityError): validator.validate_input(sensitive_content, "test_source") - + def test_input_validation_blocks_sql_injection(self, validator): """TEST: Input validation blocks SQL injection attempts""" sql_injection = "SELECT * FROM users; DROP TABLE users; --" - + with pytest.raises(SecurityError): validator.validate_input(sql_injection, "test_source") - + def test_input_validation_blocks_xss_injection(self, validator): """TEST: Input validation blocks XSS injection attempts""" xss_injection = "" - + with pytest.raises(SecurityError): validator.validate_input(xss_injection, "test_source") - + def test_input_validation_blocks_command_injection(self, validator): """TEST: Input validation blocks command injection attempts""" command_injection = "normal content; rm -rf / ;" - + with pytest.raises(SecurityError): validator.validate_input(command_injection, "test_source") - + def test_input_validation_content_too_large(self, validator): """TEST: Input validation blocks oversized content""" large_content = "X" * (validator.config.max_content_length + 1) - + with pytest.raises(ValidationError): validator.validate_input(large_content, "test_source") - + def test_input_validation_disabled(self, validator): """TEST: Input validation can be disabled""" validator.config.enable_input_validation = False sensitive_content = "password: secret123" - + # Should not raise exception when disabled validator.validate_input(sensitive_content, "test_source") - + def test_input_validation_high_threats_allowed(self, validator): """TEST: Input validation allows limited high threats""" # Single high threat should be allowed with warning content_with_single_threat = "Some content with " - + # Should not raise exception for single high threat validator.validate_input(content_with_single_threat, "test_source") - + def test_input_validation_multiple_high_threats_blocked(self, validator): """TEST: Input validation blocks multiple high threats""" content_with_multiple_threats = """ @@ -160,240 +165,260 @@ def test_input_validation_multiple_high_threats_blocked(self, validator): javascript:alert('xss') """ - + with pytest.raises(SecurityError): validator.validate_input(content_with_multiple_threats, "test_source") - + def test_output_sanitization_removes_scripts(self, validator): """TEST: Output sanitization removes script tags""" malicious_output = "Safe content more content" - + sanitized = validator.sanitize_output(malicious_output) - + assert "" - + sanitized = validator.sanitize_output(malicious_content) - + # Should return original content when disabled assert sanitized == malicious_content - + def test_output_sanitization_handles_errors(self, validator): """TEST: Output sanitization handles errors gracefully""" # Mock re.sub to raise exception - with patch('re.sub', side_effect=Exception("Regex error")): + with patch("re.sub", side_effect=Exception("Regex error")): content = "test content" sanitized = validator.sanitize_output(content) - + # Should return original content on error assert sanitized == content - + def test_content_scan_detects_sensitive_data(self, validator): """TEST: Content scan detects sensitive data patterns""" sensitive_content = "User SSN: 123-45-6789 and password: secret" - + result = validator.scan_content(sensitive_content) - + assert result.threats_detected > 0 assert result.overall_risk_level == ThreatLevel.CRITICAL assert "sensitive_data" in result.threat_breakdown assert len(result.threats) > 0 assert any(threat.threat_type == "sensitive_data" for threat in result.threats) - + def test_content_scan_detects_sql_injection(self, validator): """TEST: Content scan detects SQL injection patterns""" sql_injection = "'; DROP TABLE users; --" - + result = validator.scan_content(sql_injection) - + assert result.threats_detected > 0 assert "sql_injection" in result.threat_breakdown assert any(threat.threat_type == "sql_injection" for threat in result.threats) - + def test_content_scan_detects_xss_injection(self, validator): """TEST: Content scan detects XSS injection patterns""" xss_content = "" - + result = validator.scan_content(xss_content) - + assert result.threats_detected > 0 assert "xss_injection" in result.threat_breakdown - + def test_content_scan_detects_command_injection(self, validator): """TEST: Content scan detects command injection patterns""" command_injection = "file.txt; rm -rf /" - + result = validator.scan_content(command_injection) - + assert result.threats_detected > 0 assert "command_injection" in result.threat_breakdown - + def test_content_scan_detects_path_traversal(self, validator): """TEST: Content scan detects path traversal patterns""" path_traversal = "../../../etc/passwd" - + result = validator.scan_content(path_traversal) - + assert result.threats_detected > 0 assert "path_traversal" in result.threat_breakdown - + def test_content_scan_detects_suspicious_encoding(self, validator): """TEST: Content scan detects suspicious encoding patterns""" # Excessive URL encoding suspicious_content = "data=" + "%20%20%20" * 10 # Lots of encoded spaces - + result = validator.scan_content(suspicious_content) - + assert "suspicious_encoding" in result.threat_breakdown - + def test_content_scan_detects_unicode_bypass(self, validator): """TEST: Content scan detects Unicode bypass attempts""" unicode_content = "\\u0041\\u0042\\u0043" * 5 # Excessive Unicode - + result = validator.scan_content(unicode_content) - + assert "unicode_bypass" in result.threat_breakdown - + def test_content_scan_detects_base64_payload(self, validator): """TEST: Content scan detects suspicious base64 content""" import base64 - + # Create base64 encoded script malicious_script = "script>alert('xss') 0 - + def test_content_scan_error_handling(self, validator): """TEST: Content scan handles errors gracefully""" # Mock pattern matching to raise exception - with patch.object(validator, '_compiled_patterns', {'bad_pattern': Exception("Test error")}): + with patch.object( + validator, "_compiled_patterns", {"bad_pattern": Exception("Test error")} + ): result = validator.scan_content("test content") - + # Should handle error and return result assert result.scanned_content_length > 0 assert "scan_error" in result.threat_breakdown - + def test_threat_level_determination_sql_injection(self, validator): """TEST: Threat level determination for SQL injection""" # Critical SQL injection - critical_level = validator._determine_injection_threat_level("sql_injection", ["DROP TABLE"]) + critical_level = validator._determine_injection_threat_level( + "sql_injection", ["DROP TABLE"] + ) assert critical_level == ThreatLevel.CRITICAL - + # High SQL injection - high_level = validator._determine_injection_threat_level("sql_injection", ["SELECT * FROM"]) + high_level = validator._determine_injection_threat_level( + "sql_injection", ["SELECT * FROM"] + ) assert high_level == ThreatLevel.HIGH - + def test_threat_level_determination_xss_injection(self, validator): """TEST: Threat level determination for XSS injection""" # Critical XSS - critical_level = validator._determine_injection_threat_level("xss_injection", ["Safe content" - + sanitized = sanitize_cache_output(malicious_content) - + assert "