diff --git a/examples/arcade-dev/01_basic_integration/basic_arcade_client.py b/examples/arcade-dev/01_basic_integration/basic_arcade_client.py
index 38f92c5..9266c4e 100644
--- a/examples/arcade-dev/01_basic_integration/basic_arcade_client.py
+++ b/examples/arcade-dev/01_basic_integration/basic_arcade_client.py
@@ -38,6 +38,7 @@
# Import official arcade SDK
try:
from arcadepy import Arcade
+
ARCADE_SDK_AVAILABLE = True
except ImportError:
ARCADE_SDK_AVAILABLE = False
@@ -52,6 +53,7 @@
@dataclass
class ArcadeConfig:
"""Configuration for Arcade.dev API client."""
+
api_key: str
user_id: str = "demo@example.com"
timeout: int = 30
@@ -61,65 +63,82 @@ class ArcadeConfig:
class MockArcadeResponse:
"""Mock response class for demo mode."""
+
def __init__(self, data: Dict[str, Any]):
- self.id = data.get('id', 'demo_response_123')
- self.status = data.get('status', 'success')
- self.result = data.get('result', {})
+ self.id = data.get("id", "demo_response_123")
+ self.status = data.get("status", "success")
+ self.result = data.get("result", {})
self.data = data
-
+
def __getattr__(self, name):
return self.data.get(name)
class BasicArcadeClient:
"""Basic Arcade.dev API client with FACT integration using official SDK."""
-
- def __init__(self, config: ArcadeConfig, cache_manager: Optional[CacheManager] = None):
+
+ def __init__(
+ self, config: ArcadeConfig, cache_manager: Optional[CacheManager] = None
+ ):
self.config = config
self.cache_manager = cache_manager
self.logger = logging.getLogger(__name__)
self.client: Optional[Arcade] = None
-
+
# Mock data for demo mode
self._demo_tools = [
{
"name": "Math.Sqrt",
"description": "Calculate square root of a number",
"category": "mathematics",
- "input_schema": {"type": "object", "properties": {"a": {"type": "number"}}}
+ "input_schema": {
+ "type": "object",
+ "properties": {"a": {"type": "number"}},
+ },
},
{
- "name": "Google.ListEmails",
+ "name": "Google.ListEmails",
"description": "List emails from Gmail",
"category": "email",
- "input_schema": {"type": "object", "properties": {"n_emails": {"type": "integer"}}}
+ "input_schema": {
+ "type": "object",
+ "properties": {"n_emails": {"type": "integer"}},
+ },
},
{
"name": "Slack.PostMessage",
"description": "Post message to Slack channel",
- "category": "messaging",
- "input_schema": {"type": "object", "properties": {"channel": {"type": "string"}, "message": {"type": "string"}}}
- }
+ "category": "messaging",
+ "input_schema": {
+ "type": "object",
+ "properties": {
+ "channel": {"type": "string"},
+ "message": {"type": "string"},
+ },
+ },
+ },
]
-
+
async def __aenter__(self):
"""Async context manager entry."""
await self.connect()
return self
-
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.disconnect()
-
+
async def connect(self):
"""Establish connection to Arcade.dev API."""
if self.config.demo_mode:
self.logger.info("Running in demo mode - using mock responses")
return
-
+
if not ARCADE_SDK_AVAILABLE:
- raise RuntimeError("arcadepy SDK not available. Install with: pip install arcadepy")
-
+ raise RuntimeError(
+ "arcadepy SDK not available. Install with: pip install arcadepy"
+ )
+
try:
# Initialize the official Arcade client
self.client = Arcade(api_key=self.config.api_key)
@@ -127,19 +146,21 @@ async def connect(self):
except Exception as e:
self.logger.error(f"Failed to connect to Arcade.dev: {e}")
raise
-
+
async def disconnect(self):
"""Close connection to Arcade.dev API."""
if self.config.demo_mode:
self.logger.info("Demo mode - no connection to close")
return
-
+
if self.client:
# The official SDK doesn't require explicit disconnection
self.client = None
self.logger.info("Disconnected from Arcade.dev API")
-
- async def _execute_with_cache_and_retry(self, operation_name: str, operation_func, *args, **kwargs):
+
+ async def _execute_with_cache_and_retry(
+ self, operation_name: str, operation_func, *args, **kwargs
+ ):
"""Execute operation with caching and retry logic."""
# Check cache first
cache_key = f"arcade:{operation_name}:{hash(str(args + tuple(kwargs.items())))}"
@@ -149,38 +170,44 @@ async def _execute_with_cache_and_retry(self, operation_name: str, operation_fun
self.logger.debug(f"Cache hit for {operation_name}")
# Return cached result, assuming it's stored as JSON string
try:
- if hasattr(cached_result, 'content'):
+ if hasattr(cached_result, "content"):
return json.loads(cached_result.content)
else:
return cached_result
except (json.JSONDecodeError, AttributeError):
pass # Fall through to actual operation
-
+
# Execute operation with retries
last_exception = None
for attempt in range(self.config.max_retries):
try:
result = await operation_func(*args, **kwargs)
-
+
# Cache successful results
if self.cache_manager and result:
try:
- cache_data = result if isinstance(result, dict) else result.__dict__
+ cache_data = (
+ result if isinstance(result, dict) else result.__dict__
+ )
self.cache_manager.store(cache_key, json.dumps(cache_data))
except Exception as e:
self.logger.warning(f"Failed to cache result: {e}")
-
+
return result
-
+
except Exception as e:
last_exception = e
- self.logger.warning(f"Operation {operation_name} attempt {attempt + 1} failed: {e}")
+ self.logger.warning(
+ f"Operation {operation_name} attempt {attempt + 1} failed: {e}"
+ )
if attempt == self.config.max_retries - 1:
break
- await asyncio.sleep(2 ** attempt) # Exponential backoff
-
- raise last_exception or RuntimeError(f"All retry attempts failed for {operation_name}")
-
+ await asyncio.sleep(2**attempt) # Exponential backoff
+
+ raise last_exception or RuntimeError(
+ f"All retry attempts failed for {operation_name}"
+ )
+
async def health_check(self) -> Dict[str, Any]:
"""Check API health and connectivity."""
if self.config.demo_mode:
@@ -191,12 +218,12 @@ async def health_check(self) -> Dict[str, Any]:
"services": {
"auth": "healthy",
"tools": "healthy",
- "database": "healthy"
+ "database": "healthy",
},
- "_demo_mode": True
+ "_demo_mode": True,
}
return result
-
+
# For real API, we'll try to list tools as a health check
try:
await self.list_tools()
@@ -204,18 +231,15 @@ async def health_check(self) -> Dict[str, Any]:
"status": "healthy",
"version": "1.4.0",
"timestamp": "2025-05-25T19:27:00Z",
- "services": {
- "auth": "healthy",
- "tools": "healthy"
- }
+ "services": {"auth": "healthy", "tools": "healthy"},
}
except Exception as e:
return {
"status": "unhealthy",
"error": str(e),
- "timestamp": "2025-05-25T19:27:00Z"
+ "timestamp": "2025-05-25T19:27:00Z",
}
-
+
async def get_user_info(self) -> Dict[str, Any]:
"""Get current user information."""
if self.config.demo_mode:
@@ -226,26 +250,26 @@ async def get_user_info(self) -> Dict[str, Any]:
"tools_available": 100,
"tools_used": 5,
"quota_remaining": 95,
- "_demo_mode": True
+ "_demo_mode": True,
}
-
+
# The arcadepy SDK doesn't have a direct user info endpoint in the examples
# So we'll return basic info based on the configuration
return {
"user_id": self.config.user_id,
"api_key_status": "valid" if self.client else "invalid",
- "sdk_version": "1.4.0"
+ "sdk_version": "1.4.0",
}
-
+
async def list_tools(self) -> List[Dict[str, Any]]:
"""List available tools."""
if self.config.demo_mode:
return {
"tools": self._demo_tools,
"count": len(self._demo_tools),
- "_demo_mode": True
+ "_demo_mode": True,
}
-
+
async def _list_tools():
# The arcadepy SDK doesn't provide a direct list tools method in the examples
# This would need to be implemented based on the actual API
@@ -253,53 +277,55 @@ async def _list_tools():
return {
"tools": [],
"count": 0,
- "message": "Tool listing requires specific API implementation"
+ "message": "Tool listing requires specific API implementation",
}
-
+
return await self._execute_with_cache_and_retry("list_tools", _list_tools)
-
- async def execute_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]:
+
+ async def execute_tool(
+ self, tool_name: str, tool_input: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Execute a tool using Arcade.dev."""
if self.config.demo_mode:
return self._generate_mock_tool_execution(tool_name, tool_input)
-
+
async def _execute_tool():
try:
response = self.client.tools.execute(
- tool_name=tool_name,
- input=tool_input,
- user_id=self.config.user_id
+ tool_name=tool_name, input=tool_input, user_id=self.config.user_id
)
-
+
# Convert response to dict format
result = {
"id": response.id,
- "status": getattr(response, 'status', 'completed'),
+ "status": getattr(response, "status", "completed"),
"tool_name": tool_name,
"input": tool_input,
- "result": getattr(response, 'result', None),
- "execution_time_ms": getattr(response, 'execution_time_ms', None)
+ "result": getattr(response, "result", None),
+ "execution_time_ms": getattr(response, "execution_time_ms", None),
}
-
+
return result
-
+
except Exception as e:
self.logger.error(f"Tool execution failed: {e}")
return {
"status": "failed",
"tool_name": tool_name,
"input": tool_input,
- "error": str(e)
+ "error": str(e),
}
-
+
return await self._execute_with_cache_and_retry("execute_tool", _execute_tool)
-
- def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, Any]) -> Dict[str, Any]:
+
+ def _generate_mock_tool_execution(
+ self, tool_name: str, tool_input: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Generate a mock tool execution response."""
# Generate realistic mock responses based on the tool
if tool_name == "Math.Sqrt":
- number = float(tool_input.get('a', 16))
- result = number ** 0.5
+ number = float(tool_input.get("a", 16))
+ result = number**0.5
return {
"_demo_mode": True,
"_demo_timestamp": "2025-05-25T19:27:00Z",
@@ -308,13 +334,17 @@ def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, An
"status": "completed",
"result": {"value": result, "input": number},
"input": tool_input,
- "execution_time_ms": 150
+ "execution_time_ms": 150,
}
-
+
elif tool_name == "Google.ListEmails":
- n_emails = tool_input.get('n_emails', 5)
+ n_emails = tool_input.get("n_emails", 5)
emails = [
- {"id": f"email_{i}", "subject": f"Demo Email {i}", "from": f"sender{i}@example.com"}
+ {
+ "id": f"email_{i}",
+ "subject": f"Demo Email {i}",
+ "from": f"sender{i}@example.com",
+ }
for i in range(1, min(n_emails + 1, 11)) # Cap at 10 emails
]
return {
@@ -325,9 +355,9 @@ def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, An
"status": "completed",
"result": {"emails": emails, "count": len(emails)},
"input": tool_input,
- "execution_time_ms": 1200
+ "execution_time_ms": 1200,
}
-
+
elif tool_name == "Slack.PostMessage":
return {
"_demo_mode": True,
@@ -336,14 +366,14 @@ def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, An
"tool_name": tool_name,
"status": "completed",
"result": {
- "message_ts": "1640995200.000100",
- "channel": tool_input.get('channel', '#general'),
- "message_id": "demo_msg_789"
+ "message_ts": "1640995200.000100",
+ "channel": tool_input.get("channel", "#general"),
+ "message_id": "demo_msg_789",
},
"input": tool_input,
- "execution_time_ms": 800
+ "execution_time_ms": 800,
}
-
+
# Default tool execution response
return {
"_demo_mode": True,
@@ -353,7 +383,7 @@ def _generate_mock_tool_execution(self, tool_name: str, tool_input: Dict[str, An
"status": "completed",
"result": {"message": f"Demo execution of {tool_name}", "data": tool_input},
"input": tool_input,
- "execution_time_ms": 500
+ "execution_time_ms": 500,
}
@@ -362,31 +392,31 @@ async def main():
# Configure logging
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
print("š® Arcade.dev Python SDK Integration Example")
print("=" * 50)
-
+
# Load configuration from environment
api_key = os.getenv("ARCADE_API_KEY", "")
user_id = os.getenv("ARCADE_USER_ID", "demo@example.com")
-
+
# Enable demo mode if no API key or if API key looks like a placeholder
demo_mode = (
- not bool(api_key.strip()) or
- api_key.strip() in ["your_api_key", "demo_key", "placeholder"] or
- len(api_key.strip()) < 10 # Real API keys are typically longer
+ not bool(api_key.strip())
+ or api_key.strip() in ["your_api_key", "demo_key", "placeholder"]
+ or len(api_key.strip()) < 10 # Real API keys are typically longer
)
-
+
config = ArcadeConfig(
api_key=api_key if not demo_mode else "demo_key",
user_id=user_id,
timeout=int(os.getenv("ARCADE_TIMEOUT", "30")),
max_retries=int(os.getenv("ARCADE_MAX_RETRIES", "3")),
- demo_mode=demo_mode
+ demo_mode=demo_mode,
)
-
+
if demo_mode:
print("š Demo Mode: No API key found - using mock responses")
print("š” To use real API: Set ARCADE_API_KEY environment variable")
@@ -397,7 +427,7 @@ async def main():
print(f"š Using API key: {api_key[:10]}...")
print(f"š¤ User ID: {user_id}")
print()
-
+
# Initialize cache manager with config
cache_config = {
"prefix": "arcade_demo",
@@ -405,84 +435,92 @@ async def main():
"max_size": "10MB",
"ttl_seconds": 3600,
"hit_target_ms": 30,
- "miss_target_ms": 120
+ "miss_target_ms": 120,
}
cache_manager = CacheManager(cache_config)
-
+
# Demonstrate basic API usage
async with BasicArcadeClient(config, cache_manager) as client:
try:
# Health check
print("š Checking API health...")
health = await client.health_check()
- status_icon = "š" if health.get('_demo_mode') else "ā
"
+ status_icon = "š" if health.get("_demo_mode") else "ā
"
print(f"{status_icon} API Status: {health.get('status', 'unknown')}")
- if health.get('_demo_mode'):
+ if health.get("_demo_mode"):
print(" (Demo response)")
- elif health.get('status') == 'unhealthy':
+ elif health.get("status") == "unhealthy":
print(f" Error: {health.get('error', 'Unknown error')}")
-
+
# User info
print("\nš¤ Getting user information...")
user_info = await client.get_user_info()
- user_icon = "š" if user_info.get('_demo_mode') else "ā
"
+ user_icon = "š" if user_info.get("_demo_mode") else "ā
"
print(f"{user_icon} User: {user_info.get('user_id', 'unknown')}")
- if user_info.get('tools_available'):
+ if user_info.get("tools_available"):
print(f" Tools available: {user_info.get('tools_available', 'N/A')}")
print(f" Quota remaining: {user_info.get('quota_remaining', 'N/A')}")
- if user_info.get('_demo_mode'):
+ if user_info.get("_demo_mode"):
print(" (Demo response)")
-
+
# List available tools
print("\nš ļø Listing available tools...")
tools_response = await client.list_tools()
- tools_icon = "š" if tools_response.get('_demo_mode') else "ā
"
- available_tools = tools_response.get('tools', [])
+ tools_icon = "š" if tools_response.get("_demo_mode") else "ā
"
+ available_tools = tools_response.get("tools", [])
print(f"{tools_icon} Found {len(available_tools)} tools")
- if tools_response.get('_demo_mode'):
+ if tools_response.get("_demo_mode"):
print(" (Demo response)")
-
+
# Display available tools
if available_tools:
print("\nš Available Tools:")
for i, tool in enumerate(available_tools[:5], 1): # Show first 5 tools
- print(f" {i}. {tool.get('name', 'Unknown')} - {tool.get('description', 'No description')}")
- if tool.get('category'):
+ print(
+ f" {i}. {tool.get('name', 'Unknown')} - {tool.get('description', 'No description')}"
+ )
+ if tool.get("category"):
print(f" Category: {tool.get('category')}")
-
+
# Tool execution examples
print("\nš§ Tool Execution Examples:")
-
+
# Example 1: Math.Sqrt
print("\n1. Math.Sqrt - Calculate square root of 625...")
execution = await client.execute_tool("Math.Sqrt", {"a": 625})
- exec_icon = "š" if execution.get('_demo_mode') else "ā
"
+ exec_icon = "š" if execution.get("_demo_mode") else "ā
"
print(f"{exec_icon} Status: {execution.get('status', 'unknown')}")
- if execution.get('result'):
- result_val = execution['result'].get('value') if isinstance(execution['result'], dict) else execution['result']
+ if execution.get("result"):
+ result_val = (
+ execution["result"].get("value")
+ if isinstance(execution["result"], dict)
+ else execution["result"]
+ )
print(f" Result: ā625 = {result_val}")
print(f" Execution time: {execution.get('execution_time_ms', 'N/A')}ms")
- if execution.get('_demo_mode'):
+ if execution.get("_demo_mode"):
print(" (Demo response)")
-
+
# Example 2: Google.ListEmails
print("\n2. Google.ListEmails - List 3 emails...")
execution = await client.execute_tool("Google.ListEmails", {"n_emails": 3})
- exec_icon = "š" if execution.get('_demo_mode') else "ā
"
+ exec_icon = "š" if execution.get("_demo_mode") else "ā
"
print(f"{exec_icon} Status: {execution.get('status', 'unknown')}")
- if execution.get('result') and isinstance(execution['result'], dict):
- emails = execution['result'].get('emails', [])
+ if execution.get("result") and isinstance(execution["result"], dict):
+ emails = execution["result"].get("emails", [])
print(f" Found {len(emails)} emails")
for i, email in enumerate(emails[:2], 1): # Show first 2
- print(f" {i}. {email.get('subject', 'No subject')} from {email.get('from', 'Unknown')}")
+ print(
+ f" {i}. {email.get('subject', 'No subject')} from {email.get('from', 'Unknown')}"
+ )
print(f" Execution time: {execution.get('execution_time_ms', 'N/A')}ms")
- if execution.get('_demo_mode'):
+ if execution.get("_demo_mode"):
print(" (Demo response)")
-
+
except Exception as e:
print(f"ā Error: {e}")
return 1
-
+
print("\n" + "=" * 50)
if demo_mode:
print("š Basic integration example completed successfully in demo mode!")
@@ -496,4 +534,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/02_tool_registration/register_fact_tools.py b/examples/arcade-dev/02_tool_registration/register_fact_tools.py
index 32e653d..6413d03 100644
--- a/examples/arcade-dev/02_tool_registration/register_fact_tools.py
+++ b/examples/arcade-dev/02_tool_registration/register_fact_tools.py
@@ -33,6 +33,7 @@
# Import official arcade SDK
try:
from arcadepy import Arcade
+
ARCADE_SDK_AVAILABLE = True
except ImportError:
ARCADE_SDK_AVAILABLE = False
@@ -50,6 +51,7 @@
@dataclass
class ToolRegistrationConfig:
"""Configuration for tool registration with Arcade.dev."""
+
arcade_api_key: str
arcade_base_url: str = "https://api.arcade.dev"
workspace_id: Optional[str] = None
@@ -62,188 +64,194 @@ class ToolRegistrationConfig:
class FactToolRegistrar:
"""
Manages registration of FACT tools with Arcade.dev platform.
-
+
Handles schema generation, authentication setup, permission configuration,
and batch registration operations.
"""
-
+
def __init__(self, config: ToolRegistrationConfig):
"""
Initialize tool registrar with configuration.
-
+
Args:
config: Registration configuration
"""
self.config = config
-
+
# Create arcade config from tool registration config
arcade_config = ArcadeConfig(
api_key=config.arcade_api_key,
user_id=config.user_id,
timeout=config.default_timeout,
- demo_mode=config.demo_mode
+ demo_mode=config.demo_mode,
)
-
+
self.arcade_client = BasicArcadeClient(arcade_config)
self.tool_registry = get_tool_registry()
self.auth_manager = AuthorizationManager()
self.logger = logging.getLogger(__name__)
-
+
# Track registration results
self.registration_results: List[Dict[str, Any]] = []
-
+
async def connect(self) -> None:
"""Establish connection to Arcade.dev platform."""
try:
await self.arcade_client.connect()
self.logger.info("Successfully connected to Arcade.dev platform")
-
+
# Verify workspace access if specified
if self.config.workspace_id and not self.config.demo_mode:
await self._verify_workspace_access()
-
+
except Exception as e:
self.logger.error(f"Failed to connect to Arcade.dev: {e}")
raise
-
+
async def register_all_tools(self) -> Dict[str, Any]:
"""
Register all FACT tools from the registry with Arcade.dev.
-
+
Returns:
Registration summary with results for each tool
"""
tool_names = self.tool_registry.list_tools()
self.logger.info(f"Registering {len(tool_names)} tools with Arcade.dev")
-
+
registration_summary = {
"total_tools": len(tool_names),
"successful": 0,
"failed": 0,
"skipped": 0,
- "results": []
+ "results": [],
}
-
+
for tool_name in tool_names:
try:
result = await self.register_single_tool(tool_name)
registration_summary["results"].append(result)
-
+
if result["status"] == "success":
registration_summary["successful"] += 1
elif result["status"] == "failed":
registration_summary["failed"] += 1
else:
registration_summary["skipped"] += 1
-
+
except Exception as e:
error_result = {
"tool_name": tool_name,
"status": "failed",
"error": str(e),
- "arcade_tool_id": None
+ "arcade_tool_id": None,
}
registration_summary["results"].append(error_result)
registration_summary["failed"] += 1
-
+
self.logger.error(f"Failed to register tool {tool_name}: {e}")
-
+
self.logger.info(
f"Tool registration completed: {registration_summary['successful']} successful, "
f"{registration_summary['failed']} failed, {registration_summary['skipped']} skipped"
)
-
+
return registration_summary
-
+
async def register_single_tool(self, tool_name: str) -> Dict[str, Any]:
"""
Register a single FACT tool with Arcade.dev.
-
+
Args:
tool_name: Name of the tool to register
-
+
Returns:
Registration result dictionary
"""
try:
# Get tool definition from FACT registry
tool_definition = self.tool_registry.get_tool(tool_name)
-
+
# Check if tool already exists on Arcade.dev
existing_tool = await self._check_existing_tool(tool_name)
if existing_tool:
self.logger.info(f"Tool {tool_name} already exists, updating...")
- return await self._update_existing_tool(tool_name, tool_definition, existing_tool)
-
+ return await self._update_existing_tool(
+ tool_name, tool_definition, existing_tool
+ )
+
# Prepare tool definition for Arcade.dev
- arcade_tool_definition = self._prepare_arcade_tool_definition(tool_definition)
-
+ arcade_tool_definition = self._prepare_arcade_tool_definition(
+ tool_definition
+ )
+
# Register tool with Arcade.dev
- registration_result = await self._register_tool_with_arcade(arcade_tool_definition)
-
+ registration_result = await self._register_tool_with_arcade(
+ arcade_tool_definition
+ )
+
# Set up permissions and authentication
if tool_definition.requires_auth:
await self._setup_tool_permissions(
- registration_result.get("id"),
- tool_name,
- tool_definition
+ registration_result.get("id"), tool_name, tool_definition
)
-
+
result = {
"tool_name": tool_name,
"status": "success",
"arcade_tool_id": registration_result.get("id"),
"version": tool_definition.version,
"requires_auth": tool_definition.requires_auth,
- "message": "Tool registered successfully"
+ "message": "Tool registered successfully",
}
-
+
self.logger.info(f"Successfully registered tool: {tool_name}")
return result
-
+
except Exception as e:
self.logger.error(f"Failed to register tool {tool_name}: {e}")
return {
"tool_name": tool_name,
"status": "failed",
"error": str(e),
- "arcade_tool_id": None
+ "arcade_tool_id": None,
}
-
- async def update_tool_permissions(self, tool_name: str, permissions: Dict[str, Any]) -> bool:
+
+ async def update_tool_permissions(
+ self, tool_name: str, permissions: Dict[str, Any]
+ ) -> bool:
"""
Update permissions for a registered tool.
-
+
Args:
tool_name: Name of the tool
permissions: Permission configuration
-
+
Returns:
True if update was successful
"""
try:
# Get tool info from Arcade.dev
tool_info = await self._get_tool_info(tool_name)
-
+
if not tool_info:
raise ValueError(f"Tool {tool_name} not found on Arcade.dev")
-
+
tool_id = tool_info.get("id")
-
+
# Update permissions via Arcade.dev API
await self._update_tool_permissions(tool_id, permissions)
-
+
self.logger.info(f"Updated permissions for tool: {tool_name}")
return True
-
+
except Exception as e:
self.logger.error(f"Failed to update permissions for {tool_name}: {e}")
return False
-
+
async def list_registered_tools(self) -> List[Dict[str, Any]]:
"""
List all tools registered on Arcade.dev platform.
-
+
Returns:
List of registered tools with metadata
"""
@@ -257,7 +265,7 @@ async def list_registered_tools(self) -> List[Dict[str, Any]]:
"version": "1.0.0",
"status": "active",
"requires_auth": True,
- "_demo_mode": True
+ "_demo_mode": True,
},
{
"id": "demo_tool_data_transform",
@@ -265,40 +273,42 @@ async def list_registered_tools(self) -> List[Dict[str, Any]]:
"version": "1.0.0",
"status": "active",
"requires_auth": False,
- "_demo_mode": True
- }
+ "_demo_mode": True,
+ },
]
tools = demo_tools
else:
# For real mode, would use actual API
tools = []
-
+
# Enhance with FACT registry information
enhanced_tools = []
for tool in tools:
tool_name = tool.get("name")
enhanced_tool = dict(tool)
-
+
# Add FACT registry info if available
try:
fact_tool = self.tool_registry.get_tool(tool_name)
if fact_tool:
enhanced_tool["fact_version"] = fact_tool.version
- enhanced_tool["fact_created_at"] = getattr(fact_tool, 'created_at', None)
+ enhanced_tool["fact_created_at"] = getattr(
+ fact_tool, "created_at", None
+ )
enhanced_tool["fact_timeout"] = fact_tool.timeout_seconds
except:
enhanced_tool["fact_version"] = None
enhanced_tool["fact_created_at"] = None
enhanced_tool["fact_timeout"] = None
-
+
enhanced_tools.append(enhanced_tool)
-
+
return enhanced_tools
-
+
except Exception as e:
self.logger.error(f"Failed to list registered tools: {e}")
return []
-
+
async def close(self) -> None:
"""Close connections and cleanup resources."""
try:
@@ -306,14 +316,14 @@ async def close(self) -> None:
self.logger.info("Tool registrar closed successfully")
except Exception as e:
self.logger.warning(f"Error closing tool registrar: {e}")
-
+
def _prepare_arcade_tool_definition(self, tool_definition) -> Dict[str, Any]:
"""
Convert FACT tool definition to Arcade.dev format.
-
+
Args:
tool_definition: FACT tool definition
-
+
Returns:
Arcade.dev compatible tool definition
"""
@@ -323,7 +333,9 @@ def _prepare_arcade_tool_definition(self, tool_definition) -> Dict[str, Any]:
"parameters": {
"type": "object",
"properties": tool_definition.parameters,
- "required": self._extract_required_parameters(tool_definition.parameters)
+ "required": self._extract_required_parameters(
+ tool_definition.parameters
+ ),
},
"version": tool_definition.version,
"requires_auth": tool_definition.requires_auth,
@@ -331,17 +343,21 @@ def _prepare_arcade_tool_definition(self, tool_definition) -> Dict[str, Any]:
"metadata": {
"created_by": "FACT",
"framework_version": "1.0.0",
- "category": tool_definition.name.split(".")[0] if "." in tool_definition.name else "general",
- "workspace_id": self.config.workspace_id
- }
+ "category": (
+ tool_definition.name.split(".")[0]
+ if "." in tool_definition.name
+ else "general"
+ ),
+ "workspace_id": self.config.workspace_id,
+ },
}
-
+
# Add workspace context if specified
if self.config.workspace_id:
arcade_definition["workspace_id"] = self.config.workspace_id
-
+
return arcade_definition
-
+
def _extract_required_parameters(self, parameters: Dict[str, Any]) -> List[str]:
"""Extract required parameter names from parameter schema."""
required = []
@@ -351,7 +367,7 @@ def _extract_required_parameters(self, parameters: Dict[str, Any]) -> List[str]:
if "default" not in param_schema and param_schema.get("required", True):
required.append(param_name)
return required
-
+
async def _check_existing_tool(self, tool_name: str) -> Optional[Dict[str, Any]]:
"""Check if tool already exists on Arcade.dev."""
try:
@@ -362,7 +378,7 @@ async def _check_existing_tool(self, tool_name: str) -> Optional[Dict[str, Any]]
"id": f"demo_tool_{tool_name.replace('.', '_').lower()}",
"name": tool_name,
"version": "1.0.0",
- "status": "active"
+ "status": "active",
}
return None
else:
@@ -371,28 +387,34 @@ async def _check_existing_tool(self, tool_name: str) -> Optional[Dict[str, Any]]
return None
except Exception:
return None
-
- async def _update_existing_tool(self, tool_name: str, tool_definition, existing_tool: Dict[str, Any]) -> Dict[str, Any]:
+
+ async def _update_existing_tool(
+ self, tool_name: str, tool_definition, existing_tool: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Update an existing tool with new definition."""
try:
# Compare versions to decide if update is needed
existing_version = existing_tool.get("version", "0.0.0")
new_version = tool_definition.version
-
+
if self._is_newer_version(new_version, existing_version):
# Update tool definition
- arcade_definition = self._prepare_arcade_tool_definition(tool_definition)
+ arcade_definition = self._prepare_arcade_tool_definition(
+ tool_definition
+ )
# Note: Actual update implementation would depend on Arcade.dev API
# For now, we'll delete and re-register
await self._delete_tool(tool_name)
- registration_result = await self._register_tool_with_arcade(arcade_definition)
-
+ registration_result = await self._register_tool_with_arcade(
+ arcade_definition
+ )
+
return {
"tool_name": tool_name,
"status": "success",
"arcade_tool_id": registration_result.get("id"),
"version": new_version,
- "message": f"Tool updated from {existing_version} to {new_version}"
+ "message": f"Tool updated from {existing_version} to {new_version}",
}
else:
return {
@@ -400,43 +422,45 @@ async def _update_existing_tool(self, tool_name: str, tool_definition, existing_
"status": "skipped",
"arcade_tool_id": existing_tool.get("id"),
"version": existing_version,
- "message": "Tool version is up to date"
+ "message": "Tool version is up to date",
}
-
+
except Exception as e:
raise Exception(f"Failed to update existing tool: {e}")
-
+
def _is_newer_version(self, new_version: str, existing_version: str) -> bool:
"""Compare version strings to determine if new version is newer."""
+
def version_tuple(version: str) -> tuple:
try:
- return tuple(map(int, version.split('.')))
+ return tuple(map(int, version.split(".")))
except ValueError:
return (0, 0, 0)
-
+
return version_tuple(new_version) > version_tuple(existing_version)
-
- async def _setup_tool_permissions(self, tool_id: str, tool_name: str, tool_definition) -> None:
+
+ async def _setup_tool_permissions(
+ self, tool_id: str, tool_name: str, tool_definition
+ ) -> None:
"""Set up authentication and permissions for a tool."""
try:
# Configure tool permissions based on FACT requirements
permissions_config = {
"require_authentication": tool_definition.requires_auth,
"allowed_scopes": ["tool:execute"],
- "rate_limit": {
- "calls_per_minute": 60,
- "calls_per_hour": 1000
- },
- "audit_logging": True
+ "rate_limit": {"calls_per_minute": 60, "calls_per_hour": 1000},
+ "audit_logging": True,
}
-
+
# Apply permissions via Arcade.dev API
await self._update_tool_permissions(tool_id, permissions_config)
-
+
except Exception as e:
self.logger.warning(f"Failed to setup permissions for {tool_name}: {e}")
-
- async def _update_tool_permissions(self, tool_id: str, permissions: Dict[str, Any]) -> None:
+
+ async def _update_tool_permissions(
+ self, tool_id: str, permissions: Dict[str, Any]
+ ) -> None:
"""Update tool permissions via Arcade.dev API."""
if self.config.demo_mode:
self.logger.info(f"Demo: Updated permissions for tool {tool_id}")
@@ -444,17 +468,21 @@ async def _update_tool_permissions(self, tool_id: str, permissions: Dict[str, An
# Note: This would use actual Arcade.dev permissions API
# Implementation depends on their specific API structure
self.logger.debug(f"Updating permissions for tool {tool_id}: {permissions}")
-
+
async def _verify_workspace_access(self) -> None:
"""Verify access to the specified workspace."""
try:
# Verify workspace access
# Implementation would depend on Arcade.dev workspace API
- self.logger.info(f"Verified access to workspace: {self.config.workspace_id}")
+ self.logger.info(
+ f"Verified access to workspace: {self.config.workspace_id}"
+ )
except Exception as e:
raise Exception(f"Workspace access verification failed: {e}")
- async def _register_tool_with_arcade(self, tool_definition: Dict[str, Any]) -> Dict[str, Any]:
+ async def _register_tool_with_arcade(
+ self, tool_definition: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Register a tool with Arcade.dev platform."""
if self.config.demo_mode:
# Generate mock registration response
@@ -465,7 +493,7 @@ async def _register_tool_with_arcade(self, tool_definition: Dict[str, Any]) -> D
"status": "registered",
"version": tool_definition["version"],
"created_at": "2025-05-25T19:33:00Z",
- "_demo_mode": True
+ "_demo_mode": True,
}
else:
# For real mode, would use actual Arcade.dev API
@@ -475,7 +503,7 @@ async def _register_tool_with_arcade(self, tool_definition: Dict[str, Any]) -> D
"id": f"real_tool_{tool_definition['name'].replace('.', '_').lower()}",
"name": tool_definition["name"],
"status": "registered",
- "version": tool_definition["version"]
+ "version": tool_definition["version"],
}
async def _get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]:
@@ -487,7 +515,7 @@ async def _get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]:
"name": tool_name,
"version": "1.0.0",
"status": "active",
- "_demo_mode": True
+ "_demo_mode": True,
}
else:
# For real mode, would query actual API
@@ -513,29 +541,28 @@ async def _delete_tool(self, tool_name: str) -> bool:
"type": "string",
"description": "Text content to process",
"minLength": 1,
- "maxLength": 10000
+ "maxLength": 10000,
},
"analysis_type": {
"type": "string",
"description": "Type of analysis to perform",
"enum": ["sentiment", "keywords", "summary", "entities"],
- "default": "summary"
- }
+ "default": "summary",
+ },
},
requires_auth=True,
- timeout_seconds=30
+ timeout_seconds=30,
)
-def process_text_content(content: str, analysis_type: str = "summary") -> Dict[str, Any]:
+def process_text_content(
+ content: str, analysis_type: str = "summary"
+) -> Dict[str, Any]:
"""Process text content with specified analysis type."""
# Mock implementation for example
return {
"content_length": len(content),
"analysis_type": analysis_type,
"result": f"Processed {len(content)} characters with {analysis_type} analysis",
- "metadata": {
- "processing_time_ms": 150,
- "word_count": len(content.split())
- }
+ "metadata": {"processing_time_ms": 150, "word_count": len(content.split())},
}
@@ -543,37 +570,36 @@ def process_text_content(content: str, analysis_type: str = "summary") -> Dict[s
name="Data_Transform",
description="Transform data between different formats and structures",
parameters={
- "data": {
- "type": "object",
- "description": "Data to transform"
- },
+ "data": {"type": "object", "description": "Data to transform"},
"target_format": {
"type": "string",
"description": "Target format for transformation",
- "enum": ["json", "csv", "xml", "yaml"]
+ "enum": ["json", "csv", "xml", "yaml"],
},
"options": {
"type": "object",
"description": "Transformation options",
"required": False,
- "default": {}
- }
+ "default": {},
+ },
},
requires_auth=False,
- timeout_seconds=60
+ timeout_seconds=60,
)
-def transform_data(data: Dict[str, Any], target_format: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
+def transform_data(
+ data: Dict[str, Any], target_format: str, options: Dict[str, Any] = None
+) -> Dict[str, Any]:
"""Transform data to specified format."""
if options is None:
options = {}
-
+
# Mock implementation for example
return {
"original_format": "object",
"target_format": target_format,
"transformed": True,
"options_applied": options,
- "transformation_id": "tx_123456"
+ "transformation_id": "tx_123456",
}
@@ -586,70 +612,77 @@ async def main():
# Configure logging
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
print("š FACT Tool Registration with Arcade.dev Example")
print("=" * 55)
-
+
# Example tools are automatically registered by @Tool decorator
print("\nš§ Checking registered FACT tools...")
registry = get_tool_registry()
print(f"✅ Found {len(registry.list_tools())} registered tools")
-
+
# Load configuration from environment
arcade_api_key = os.getenv("ARCADE_API_KEY", "")
demo_mode = not arcade_api_key or not ARCADE_SDK_AVAILABLE
-
+
if demo_mode:
print("š Running in DEMO MODE")
print(" - No ARCADE_API_KEY found or arcadepy SDK not available")
print(" - Will simulate tool registration with mock responses")
print(" - To use real API: install arcadepy and set ARCADE_API_KEY")
print()
-
+
config = ToolRegistrationConfig(
arcade_api_key=arcade_api_key or "demo_key",
arcade_base_url=os.getenv("ARCADE_BASE_URL", "https://api.arcade.dev"),
workspace_id=os.getenv("ARCADE_WORKSPACE_ID"),
default_timeout=int(os.getenv("ARCADE_TIMEOUT", "30")),
- requires_auth_by_default=os.getenv("ARCADE_REQUIRE_AUTH", "true").lower() == "true",
+ requires_auth_by_default=os.getenv("ARCADE_REQUIRE_AUTH", "true").lower()
+ == "true",
demo_mode=demo_mode,
- user_id=os.getenv("ARCADE_USER_ID", "demo@example.com")
+ user_id=os.getenv("ARCADE_USER_ID", "demo@example.com"),
)
-
+
# Initialize tool registrar
registrar = FactToolRegistrar(config)
-
+
try:
# Connect to Arcade.dev
print("\nš Connecting to Arcade.dev platform...")
await registrar.connect()
print("✅ Connected successfully")
-
+
# List existing tools
print("\nš Listing currently registered tools...")
existing_tools = await registrar.list_registered_tools()
print(f"✅ Found {len(existing_tools)} existing tools")
-
+
# Register all FACT tools
print("\nš§ Registering FACT tools with Arcade.dev...")
registration_summary = await registrar.register_all_tools()
-
+
# Display results
print("\nš Registration Summary:")
print(f" Total tools: {registration_summary['total_tools']}")
print(f" Successful: {registration_summary['successful']}")
print(f" Failed: {registration_summary['failed']}")
print(f" Skipped: {registration_summary['skipped']}")
-
+
# Show detailed results
- if registration_summary['results']:
+ if registration_summary["results"]:
print("\nš Detailed Results:")
- for result in registration_summary['results']:
- status_icon = "✅" if result['status'] == 'success' else "❌" if result['status'] == 'failed' else "⏭️"
- print(f" {status_icon} {result['tool_name']}: {result.get('message', result['status'])}")
-
+ for result in registration_summary["results"]:
+ status_icon = (
+ "✅"
+ if result["status"] == "success"
+ else "❌" if result["status"] == "failed" else "⏭️"
+ )
+ print(
+ f" {status_icon} {result['tool_name']}: {result.get('message', result['status'])}"
+ )
+
# Demonstrate permission updates
print("\nš Demonstrating permission updates...")
success = await registrar.update_tool_permissions(
@@ -657,30 +690,32 @@ async def main():
{
"require_authentication": True,
"allowed_scopes": ["tool:execute", "data:read"],
- "rate_limit": {"calls_per_minute": 30}
- }
+ "rate_limit": {"calls_per_minute": 30},
+ },
)
print(f"✅ Permission update {'successful' if success else 'failed'}")
-
+
# Final tool list
print("\nš Final tool list:")
final_tools = await registrar.list_registered_tools()
for tool in final_tools:
auth_status = "š" if tool.get("requires_auth") else "š"
- print(f" {auth_status} {tool.get('name')} (v{tool.get('version', 'unknown')})")
-
+ print(
+ f" {auth_status} {tool.get('name')} (v{tool.get('version', 'unknown')})"
+ )
+
except Exception as e:
print(f"❌ Error during tool registration: {e}")
return 1
-
+
finally:
# Clean up
await registrar.close()
-
+
print("\nš Tool registration example completed successfully!")
return 0
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/03_intelligent_routing/hybrid_execution.py b/examples/arcade-dev/03_intelligent_routing/hybrid_execution.py
index 693c4b7..4f5e087 100644
--- a/examples/arcade-dev/03_intelligent_routing/hybrid_execution.py
+++ b/examples/arcade-dev/03_intelligent_routing/hybrid_execution.py
@@ -34,18 +34,22 @@
# Create alias for compatibility
ArcadeClient = BasicArcadeClient
+
# Define classes that might not exist in the actual FACT implementation
@dataclass
class ToolCall:
"""Tool call data structure."""
+
id: str
name: str
arguments: Dict[str, Any]
user_id: Optional[str] = None
+
@dataclass
class ToolResult:
"""Tool execution result."""
+
call_id: str
tool_name: str
success: bool
@@ -55,17 +59,18 @@ class ToolResult:
status_code: int = 200
metadata: Optional[Dict[str, Any]] = None
+
# Mock ToolExecutor for this example
class MockToolExecutor:
"""Mock tool executor for local execution."""
-
+
def __init__(self):
self.registry = {}
-
+
async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
"""Execute a tool call locally."""
start_time = time.time()
-
+
# Check if we have a registered function for this tool
tool_func = self.registry.get(tool_call.name)
if not tool_func:
@@ -75,14 +80,14 @@ async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
success=False,
error=f"Tool '{tool_call.name}' not found",
execution_time_ms=(time.time() - start_time) * 1000,
- status_code=404
+ status_code=404,
)
-
+
try:
# Execute the tool function
result = tool_func(**tool_call.arguments)
execution_time = (time.time() - start_time) * 1000
-
+
return ToolResult(
call_id=tool_call.id,
tool_name=tool_call.name,
@@ -90,9 +95,9 @@ async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
data=result,
execution_time_ms=execution_time,
status_code=200,
- metadata={"execution_mode": "local"}
+ metadata={"execution_mode": "local"},
)
-
+
except Exception as e:
execution_time = (time.time() - start_time) * 1000
return ToolResult(
@@ -101,19 +106,21 @@ async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
success=False,
error=str(e),
execution_time_ms=execution_time,
- status_code=500
+ status_code=500,
)
-
+
def register_tool(self, name: str, func):
"""Register a tool function."""
self.registry[name] = func
+
# Replace ToolExecutor with our mock
ToolExecutor = MockToolExecutor
class ExecutionMode(Enum):
"""Execution mode enumeration."""
+
LOCAL = "local"
REMOTE = "remote"
HYBRID = "hybrid"
@@ -123,6 +130,7 @@ class ExecutionMode(Enum):
@dataclass
class RoutingRule:
"""Rule for routing tool execution."""
+
tool_pattern: str # Tool name pattern (supports wildcards)
preferred_mode: ExecutionMode
conditions: Dict[str, Any] # Conditions for applying this rule
@@ -132,13 +140,14 @@ class RoutingRule:
@dataclass
class ExecutionMetrics:
"""Metrics for execution performance tracking."""
+
mode: ExecutionMode
execution_time_ms: float
success: bool
error_type: Optional[str] = None
tool_name: str = ""
timestamp: float = 0.0
-
+
def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.time()
@@ -147,17 +156,19 @@ def __post_init__(self):
class IntelligentRouter:
"""
Intelligent router for deciding between local and remote tool execution.
-
+
Makes routing decisions based on tool characteristics, performance history,
network conditions, and configurable rules.
"""
-
- def __init__(self,
- arcade_client: Optional[ArcadeClient] = None,
- cache_manager: Optional[CacheManager] = None):
+
+ def __init__(
+ self,
+ arcade_client: Optional[ArcadeClient] = None,
+ cache_manager: Optional[CacheManager] = None,
+ ):
"""
Initialize intelligent router.
-
+
Args:
arcade_client: Arcade.dev client for remote execution
cache_manager: Cache manager for performance data
@@ -167,52 +178,56 @@ def __init__(self,
self.local_executor = ToolExecutor()
self.metrics_collector = MetricsCollector()
self.logger = logging.getLogger(__name__)
-
+
# Routing configuration
self.routing_rules: List[RoutingRule] = []
self.execution_history: List[ExecutionMetrics] = []
self.performance_cache: Dict[str, Dict[str, float]] = {}
-
+
# Performance thresholds
self.local_timeout_threshold = 5.0 # seconds
self.remote_timeout_threshold = 30.0 # seconds
self.network_latency_threshold = 2.0 # seconds
-
+
# Initialize default routing rules
self._setup_default_routing_rules()
-
+
def add_routing_rule(self, rule: RoutingRule) -> None:
"""
Add a routing rule to the decision engine.
-
+
Args:
rule: Routing rule to add
"""
self.routing_rules.append(rule)
# Sort by priority (highest first)
self.routing_rules.sort(key=lambda r: r.priority, reverse=True)
-
- self.logger.info(f"Added routing rule for pattern '{rule.tool_pattern}' "
- f"with mode {rule.preferred_mode.value}")
-
+
+ self.logger.info(
+ f"Added routing rule for pattern '{rule.tool_pattern}' "
+ f"with mode {rule.preferred_mode.value}"
+ )
+
async def execute_tool(self, tool_call: ToolCall) -> ToolResult:
"""
Execute tool with intelligent routing decision.
-
+
Args:
tool_call: Tool call to execute
-
+
Returns:
Tool execution result
"""
start_time = time.time()
-
+
try:
# Determine optimal execution mode
execution_mode = await self._determine_execution_mode(tool_call)
-
- self.logger.info(f"Executing tool '{tool_call.name}' using {execution_mode.value} mode")
-
+
+ self.logger.info(
+ f"Executing tool '{tool_call.name}' using {execution_mode.value} mode"
+ )
+
# Execute based on determined mode
if execution_mode == ExecutionMode.LOCAL:
result = await self._execute_locally(tool_call)
@@ -222,34 +237,38 @@ async def execute_tool(self, tool_call: ToolCall) -> ToolResult:
result = await self._execute_hybrid(tool_call)
else: # AUTO mode
result = await self._execute_auto(tool_call)
-
+
# Record successful execution metrics
execution_time = (time.time() - start_time) * 1000
metrics = ExecutionMetrics(
mode=execution_mode,
execution_time_ms=execution_time,
success=result.success,
- tool_name=tool_call.name
+ tool_name=tool_call.name,
)
-
+
await self._record_execution_metrics(metrics)
-
+
return result
-
+
except Exception as e:
execution_time = (time.time() - start_time) * 1000
-
+
# Record failed execution metrics
metrics = ExecutionMetrics(
- mode=execution_mode if 'execution_mode' in locals() else ExecutionMode.LOCAL,
+ mode=(
+ execution_mode
+ if "execution_mode" in locals()
+ else ExecutionMode.LOCAL
+ ),
execution_time_ms=execution_time,
success=False,
error_type=type(e).__name__,
- tool_name=tool_call.name
+ tool_name=tool_call.name,
)
-
+
await self._record_execution_metrics(metrics)
-
+
# Create error result
return ToolResult(
call_id=tool_call.id,
@@ -257,16 +276,16 @@ async def execute_tool(self, tool_call: ToolCall) -> ToolResult:
success=False,
error=str(e),
execution_time_ms=execution_time,
- status_code=500
+ status_code=500,
)
-
+
async def _determine_execution_mode(self, tool_call: ToolCall) -> ExecutionMode:
"""
Determine the optimal execution mode for a tool call.
-
+
Args:
tool_call: Tool call to analyze
-
+
Returns:
Optimal execution mode
"""
@@ -274,27 +293,31 @@ async def _determine_execution_mode(self, tool_call: ToolCall) -> ExecutionMode:
for rule in self.routing_rules:
if self._matches_rule(tool_call.name, rule):
if await self._evaluate_rule_conditions(tool_call, rule):
- self.logger.debug(f"Tool '{tool_call.name}' matched rule: {rule.preferred_mode.value}")
+ self.logger.debug(
+ f"Tool '{tool_call.name}' matched rule: {rule.preferred_mode.value}"
+ )
return rule.preferred_mode
-
+
# Performance-based decision
performance_score = await self._calculate_performance_score(tool_call.name)
-
+
# Network condition check
network_latency = await self._check_network_latency()
-
+
# Tool complexity analysis
complexity_score = self._analyze_tool_complexity(tool_call)
-
+
# Make decision based on weighted factors
decision_score = (
- performance_score * 0.4 +
- (1.0 - min(network_latency / self.network_latency_threshold, 1.0)) * 0.3 +
- complexity_score * 0.3
+ performance_score * 0.4
+ + (1.0 - min(network_latency / self.network_latency_threshold, 1.0)) * 0.3
+ + complexity_score * 0.3
)
-
- self.logger.debug(f"Decision score for '{tool_call.name}': {decision_score:.3f}")
-
+
+ self.logger.debug(
+ f"Decision score for '{tool_call.name}': {decision_score:.3f}"
+ )
+
# Decision thresholds
if decision_score > 0.7:
return ExecutionMode.LOCAL
@@ -302,7 +325,7 @@ async def _determine_execution_mode(self, tool_call: ToolCall) -> ExecutionMode:
return ExecutionMode.REMOTE
else:
return ExecutionMode.HYBRID
-
+
async def _execute_locally(self, tool_call: ToolCall) -> ToolResult:
"""Execute tool locally using FACT executor."""
try:
@@ -314,23 +337,24 @@ async def _execute_locally(self, tool_call: ToolCall) -> ToolResult:
self.logger.info("Falling back to remote execution")
return await self._execute_remotely(tool_call)
raise
-
+
async def _execute_remotely(self, tool_call: ToolCall) -> ToolResult:
"""Execute tool remotely via Arcade.dev."""
if not self.arcade_client:
- raise RuntimeError("Remote execution not available: Arcade client not configured")
-
+ raise RuntimeError(
+ "Remote execution not available: Arcade client not configured"
+ )
+
try:
start_time = time.time()
-
+
# Execute via Arcade.dev
arcade_result = await self.arcade_client.execute_tool(
- tool_name=tool_call.name,
- tool_input=tool_call.arguments
+ tool_name=tool_call.name, tool_input=tool_call.arguments
)
-
+
execution_time = (time.time() - start_time) * 1000
-
+
# Convert Arcade result to ToolResult format
success = arcade_result.get("status") != "failed"
return ToolResult(
@@ -344,35 +368,34 @@ async def _execute_remotely(self, tool_call: ToolCall) -> ToolResult:
metadata={
"execution_mode": "remote",
"arcade_execution_time": arcade_result.get("execution_time_ms"),
- "user_id": tool_call.user_id
- }
+ "user_id": tool_call.user_id,
+ },
)
-
+
except Exception as e:
self.logger.warning(f"Remote execution failed for '{tool_call.name}': {e}")
# Fallback to local if possible
self.logger.info("Falling back to local execution")
return await self._execute_locally(tool_call)
-
+
async def _execute_hybrid(self, tool_call: ToolCall) -> ToolResult:
"""
Execute tool using hybrid approach (race between local and remote).
-
+
Returns the first successful result, cancels the other.
"""
self.logger.info(f"Executing '{tool_call.name}' in hybrid mode")
-
+
# Create tasks for both execution modes
local_task = asyncio.create_task(self._execute_locally(tool_call))
remote_task = asyncio.create_task(self._execute_remotely(tool_call))
-
+
try:
# Wait for first completion
done, pending = await asyncio.wait(
- [local_task, remote_task],
- return_when=asyncio.FIRST_COMPLETED
+ [local_task, remote_task], return_when=asyncio.FIRST_COMPLETED
)
-
+
# Cancel pending tasks
for task in pending:
task.cancel()
@@ -380,24 +403,24 @@ async def _execute_hybrid(self, tool_call: ToolCall) -> ToolResult:
await task
except asyncio.CancelledError:
pass
-
+
# Get result from completed task
completed_task = done.pop()
result = await completed_task
-
+
# Determine which mode completed first
winning_mode = "local" if completed_task == local_task else "remote"
-
+
# Add hybrid execution metadata
if result.metadata is None:
result.metadata = {}
result.metadata["execution_mode"] = "hybrid"
result.metadata["winning_mode"] = winning_mode
-
+
self.logger.info(f"Hybrid execution completed: {winning_mode} mode won")
-
+
return result
-
+
except Exception as e:
# Cancel any remaining tasks
for task in [local_task, remote_task]:
@@ -407,21 +430,21 @@ async def _execute_hybrid(self, tool_call: ToolCall) -> ToolResult:
await task
except asyncio.CancelledError:
pass
-
+
raise RuntimeError(f"Hybrid execution failed: {e}")
-
+
async def _execute_auto(self, tool_call: ToolCall) -> ToolResult:
"""
Execute tool in auto mode with intelligent retry and fallback.
"""
# Start with performance-based decision
performance_data = await self._get_performance_data(tool_call.name)
-
+
if performance_data:
# Use historical performance to decide
- local_avg = performance_data.get("local_avg_ms", float('inf'))
- remote_avg = performance_data.get("remote_avg_ms", float('inf'))
-
+ local_avg = performance_data.get("local_avg_ms", float("inf"))
+ remote_avg = performance_data.get("remote_avg_ms", float("inf"))
+
if local_avg < remote_avg * 1.5: # Local is significantly faster
primary_mode = ExecutionMode.LOCAL
fallback_mode = ExecutionMode.REMOTE
@@ -432,7 +455,7 @@ async def _execute_auto(self, tool_call: ToolCall) -> ToolResult:
# No historical data, default to local first
primary_mode = ExecutionMode.LOCAL
fallback_mode = ExecutionMode.REMOTE
-
+
# Try primary mode first
try:
if primary_mode == ExecutionMode.LOCAL:
@@ -441,7 +464,7 @@ async def _execute_auto(self, tool_call: ToolCall) -> ToolResult:
return await self._execute_remotely(tool_call)
except Exception as e:
self.logger.warning(f"Primary mode ({primary_mode.value}) failed: {e}")
-
+
# Try fallback mode
try:
if fallback_mode == ExecutionMode.LOCAL:
@@ -449,18 +472,25 @@ async def _execute_auto(self, tool_call: ToolCall) -> ToolResult:
else:
return await self._execute_remotely(tool_call)
except Exception as fallback_error:
- self.logger.error(f"Fallback mode ({fallback_mode.value}) also failed: {fallback_error}")
- raise RuntimeError(f"Both execution modes failed: {e}, {fallback_error}")
-
+ self.logger.error(
+ f"Fallback mode ({fallback_mode.value}) also failed: {fallback_error}"
+ )
+ raise RuntimeError(
+ f"Both execution modes failed: {e}, {fallback_error}"
+ )
+
def _matches_rule(self, tool_name: str, rule: RoutingRule) -> bool:
"""Check if tool name matches routing rule pattern."""
import fnmatch
+
return fnmatch.fnmatch(tool_name, rule.tool_pattern)
-
- async def _evaluate_rule_conditions(self, tool_call: ToolCall, rule: RoutingRule) -> bool:
+
+ async def _evaluate_rule_conditions(
+ self, tool_call: ToolCall, rule: RoutingRule
+ ) -> bool:
"""Evaluate whether rule conditions are met."""
conditions = rule.conditions
-
+
# Check user-based conditions
if "user_id" in conditions:
allowed_users = conditions["user_id"]
@@ -469,86 +499,87 @@ async def _evaluate_rule_conditions(self, tool_call: ToolCall, rule: RoutingRule
return False
elif tool_call.user_id != allowed_users:
return False
-
+
# Check argument-based conditions
if "argument_size" in conditions:
max_size = conditions["argument_size"]
args_size = len(str(tool_call.arguments))
if args_size > max_size:
return False
-
+
# Check time-based conditions
if "time_window" in conditions:
import datetime
+
time_window = conditions["time_window"]
current_hour = datetime.datetime.now().hour
-
+
if isinstance(time_window, dict):
start_hour = time_window.get("start", 0)
end_hour = time_window.get("end", 23)
if not (start_hour <= current_hour <= end_hour):
return False
-
+
return True
-
+
async def _calculate_performance_score(self, tool_name: str) -> float:
"""Calculate performance score for a tool (0.0 to 1.0, higher is better for local)."""
if tool_name in self.performance_cache:
data = self.performance_cache[tool_name]
local_avg = data.get("local_avg_ms", 1000)
remote_avg = data.get("remote_avg_ms", 2000)
-
+
# Score based on relative performance
if remote_avg > 0:
return min(remote_avg / (local_avg + remote_avg), 1.0)
-
+
# Default score if no performance data
return 0.5
-
+
async def _check_network_latency(self) -> float:
"""Check current network latency to Arcade.dev."""
if not self.arcade_client:
- return float('inf')
-
+ return float("inf")
+
try:
start_time = time.time()
# Simple ping-like check - use health check instead
await self.arcade_client.health_check()
return time.time() - start_time
except:
- return float('inf')
-
+ return float("inf")
+
def _analyze_tool_complexity(self, tool_call: ToolCall) -> float:
"""Analyze tool complexity to inform routing decision."""
complexity_score = 0.0
-
+
# Argument complexity
args_str = str(tool_call.arguments)
if len(args_str) > 1000:
complexity_score += 0.3
elif len(args_str) > 100:
complexity_score += 0.1
-
+
# Tool name patterns indicating complexity
complex_patterns = ["AI.", "ML.", "Analysis.", "Transform."]
simple_patterns = ["Util.", "Helper.", "Cache."]
-
+
for pattern in complex_patterns:
if tool_call.name.startswith(pattern):
complexity_score += 0.4
break
-
+
for pattern in simple_patterns:
if tool_call.name.startswith(pattern):
complexity_score -= 0.2
break
-
+
return max(0.0, min(1.0, complexity_score))
-
+
async def _record_execution_metrics(self, metrics: ExecutionMetrics) -> None:
"""Record execution metrics for future routing decisions."""
self.execution_history.append(metrics)
-
+
# Update performance cache
tool_name = metrics.tool_name
if tool_name not in self.performance_cache:
@@ -556,47 +587,53 @@ async def _record_execution_metrics(self, metrics: ExecutionMetrics) -> None:
"local_avg_ms": 0.0,
"remote_avg_ms": 0.0,
"local_count": 0,
- "remote_count": 0
+ "remote_count": 0,
}
-
+
cache_data = self.performance_cache[tool_name]
-
+
if metrics.mode == ExecutionMode.LOCAL and metrics.success:
count = cache_data["local_count"]
avg = cache_data["local_avg_ms"]
- cache_data["local_avg_ms"] = (avg * count + metrics.execution_time_ms) / (count + 1)
+ cache_data["local_avg_ms"] = (avg * count + metrics.execution_time_ms) / (
+ count + 1
+ )
cache_data["local_count"] = count + 1
elif metrics.mode == ExecutionMode.REMOTE and metrics.success:
count = cache_data["remote_count"]
avg = cache_data["remote_avg_ms"]
- cache_data["remote_avg_ms"] = (avg * count + metrics.execution_time_ms) / (count + 1)
+ cache_data["remote_avg_ms"] = (avg * count + metrics.execution_time_ms) / (
+ count + 1
+ )
cache_data["remote_count"] = count + 1
-
+
# Record with metrics collector
self.metrics_collector.record_tool_execution(
tool_name=metrics.tool_name,
success=metrics.success,
execution_time=metrics.execution_time_ms,
- metadata={"execution_mode": metrics.mode.value}
+ metadata={"execution_mode": metrics.mode.value},
)
-
+
# Cache performance data if cache manager available
if self.cache_manager:
cache_key = f"routing:performance:{tool_name}"
# CacheManager uses store method, not set
try:
import json
+
self.cache_manager.store(cache_key, json.dumps(cache_data))
except Exception as e:
import structlog
+
logger = structlog.get_logger(__name__)
logger.warning(f"Failed to cache performance data: {e}")
-
+
async def _get_performance_data(self, tool_name: str) -> Optional[Dict[str, float]]:
"""Get cached performance data for a tool."""
if tool_name in self.performance_cache:
return self.performance_cache[tool_name]
-
+
if self.cache_manager:
cache_key = f"routing:performance:{tool_name}"
try:
@@ -604,50 +641,60 @@ async def _get_performance_data(self, tool_name: str) -> Optional[Dict[str, floa
cached_entry = self.cache_manager.retrieve(cache_key)
if cached_entry:
import json
+
cached_data = json.loads(cached_entry.content)
self.performance_cache[tool_name] = cached_data
return cached_data
except Exception as e:
import structlog
+
logger = structlog.get_logger(__name__)
logger.warning(f"Failed to retrieve cached performance data: {e}")
-
+
return None
-
+
def _setup_default_routing_rules(self) -> None:
"""Set up default routing rules."""
# High-performance tools should run locally
- self.add_routing_rule(RoutingRule(
- tool_pattern="Cache.*",
- preferred_mode=ExecutionMode.LOCAL,
- conditions={"argument_size": 1000}, # Small arguments only
- priority=100
- ))
-
+ self.add_routing_rule(
+ RoutingRule(
+ tool_pattern="Cache.*",
+ preferred_mode=ExecutionMode.LOCAL,
+ conditions={"argument_size": 1000}, # Small arguments only
+ priority=100,
+ )
+ )
+
# AI/ML tools prefer remote execution (more resources)
- self.add_routing_rule(RoutingRule(
- tool_pattern="AI.*",
- preferred_mode=ExecutionMode.REMOTE,
- conditions={},
- priority=90
- ))
-
+ self.add_routing_rule(
+ RoutingRule(
+ tool_pattern="AI.*",
+ preferred_mode=ExecutionMode.REMOTE,
+ conditions={},
+ priority=90,
+ )
+ )
+
# Analysis tools use hybrid approach
- self.add_routing_rule(RoutingRule(
- tool_pattern="Analysis.*",
- preferred_mode=ExecutionMode.HYBRID,
- conditions={},
- priority=80
- ))
-
+ self.add_routing_rule(
+ RoutingRule(
+ tool_pattern="Analysis.*",
+ preferred_mode=ExecutionMode.HYBRID,
+ conditions={},
+ priority=80,
+ )
+ )
+
# Utility tools run locally
- self.add_routing_rule(RoutingRule(
- tool_pattern="Util.*",
- preferred_mode=ExecutionMode.LOCAL,
- conditions={},
- priority=70
- ))
-
+ self.add_routing_rule(
+ RoutingRule(
+ tool_pattern="Util.*",
+ preferred_mode=ExecutionMode.LOCAL,
+ conditions={},
+ priority=70,
+ )
+ )
+
def get_performance_summary(self) -> Dict[str, Any]:
"""Get performance summary for all tools."""
summary = {
@@ -655,27 +702,27 @@ def get_performance_summary(self) -> Dict[str, Any]:
"tools_tracked": len(self.performance_cache),
"mode_distribution": {},
"average_execution_times": {},
- "success_rates": {}
+ "success_rates": {},
}
-
+
# Calculate mode distribution
mode_counts = {}
for metrics in self.execution_history:
mode = metrics.mode.value
mode_counts[mode] = mode_counts.get(mode, 0) + 1
-
+
total_executions = len(self.execution_history)
if total_executions > 0:
for mode, count in mode_counts.items():
summary["mode_distribution"][mode] = count / total_executions
-
+
# Calculate average execution times and success rates
for tool_name, cache_data in self.performance_cache.items():
summary["average_execution_times"][tool_name] = {
"local_ms": cache_data.get("local_avg_ms", 0),
- "remote_ms": cache_data.get("remote_avg_ms", 0)
+ "remote_ms": cache_data.get("remote_avg_ms", 0),
}
-
+
return summary
@@ -683,10 +730,8 @@ def get_performance_summary(self) -> Dict[str, Any]:
@Tool(
name="Cache_FastLookup",
description="Fast cache lookup operation",
- parameters={
- "key": {"type": "string", "description": "Cache key to lookup"}
- },
- timeout_seconds=5
+ parameters={"key": {"type": "string", "description": "Cache key to lookup"}},
+ timeout_seconds=5,
)
def fast_cache_lookup(key: str) -> Dict[str, Any]:
"""Fast local cache operation."""
@@ -699,9 +744,9 @@ def fast_cache_lookup(key: str) -> Dict[str, Any]:
description="Complex AI analysis requiring significant resources",
parameters={
"data": {"type": "object", "description": "Data to analyze"},
- "model_type": {"type": "string", "description": "AI model to use"}
+ "model_type": {"type": "string", "description": "AI model to use"},
},
- timeout_seconds=60
+ timeout_seconds=60,
)
def complex_ai_analysis(data: Dict[str, Any], model_type: str) -> Dict[str, Any]:
"""Complex AI analysis operation."""
@@ -709,24 +754,22 @@ def complex_ai_analysis(data: Dict[str, Any], model_type: str) -> Dict[str, Any]
return {
"analysis_result": f"AI analysis with {model_type}",
"processed_items": len(data.get("items", [])),
- "confidence": 0.95
+ "confidence": 0.95,
}
@Tool(
name="Util_StringHelper",
description="Simple string utility function",
- parameters={
- "text": {"type": "string", "description": "Text to process"}
- },
- timeout_seconds=10
+ parameters={"text": {"type": "string", "description": "Text to process"}},
+ timeout_seconds=10,
)
def string_helper(text: str) -> Dict[str, Any]:
"""Simple string processing utility."""
return {
"length": len(text),
"word_count": len(text.split()),
- "uppercase": text.upper()
+ "uppercase": text.upper(),
}
@@ -735,12 +778,12 @@ async def main():
# Configure logging
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
print("š§ Intelligent Routing and Hybrid Execution Example")
print("=" * 53)
-
+
# Initialize components
cache_config = {
"prefix": "intelligent_routing",
@@ -748,10 +791,10 @@ async def main():
"max_size": "10MB",
"ttl_seconds": 3600,
"hit_target_ms": 48,
- "miss_target_ms": 140
+ "miss_target_ms": 140,
}
cache_manager = CacheManager(cache_config)
-
+
# Initialize Arcade client if configured
arcade_client = None
arcade_api_key = os.getenv("ARCADE_API_KEY")
@@ -759,7 +802,7 @@ async def main():
config = ArcadeConfig(
api_key=arcade_api_key,
user_id=os.getenv("ARCADE_USER_ID", "demo@example.com"),
- demo_mode=False
+ demo_mode=False,
)
arcade_client = ArcadeClient(config, cache_manager)
try:
@@ -773,50 +816,52 @@ async def main():
print("ā¹ļø No Arcade API key configured, using demo mode")
# Create a demo mode client for testing
config = ArcadeConfig(
- api_key="demo_key",
- user_id="demo@example.com",
- demo_mode=True
+ api_key="demo_key", user_id="demo@example.com", demo_mode=True
)
arcade_client = ArcadeClient(config, cache_manager)
await arcade_client.connect()
print("ā
Using demo mode for testing intelligent routing")
-
+
# Initialize intelligent router
router = IntelligentRouter(arcade_client, cache_manager)
-
+
# Register tools with the local executor
router.local_executor.register_tool("Cache_FastLookup", fast_cache_lookup)
router.local_executor.register_tool("AI_ComplexAnalysis", complex_ai_analysis)
router.local_executor.register_tool("Util_StringHelper", string_helper)
-
+
# Add custom routing rules
print("\nš Setting up custom routing rules...")
-
- router.add_routing_rule(RoutingRule(
- tool_pattern="Cache.*",
- preferred_mode=ExecutionMode.LOCAL,
- conditions={"argument_size": 500},
- priority=100
- ))
-
- router.add_routing_rule(RoutingRule(
- tool_pattern="AI.*",
- preferred_mode=ExecutionMode.REMOTE if arcade_client else ExecutionMode.LOCAL,
- conditions={},
- priority=90
- ))
-
+
+ router.add_routing_rule(
+ RoutingRule(
+ tool_pattern="Cache.*",
+ preferred_mode=ExecutionMode.LOCAL,
+ conditions={"argument_size": 500},
+ priority=100,
+ )
+ )
+
+ router.add_routing_rule(
+ RoutingRule(
+ tool_pattern="AI.*",
+ preferred_mode=(
+ ExecutionMode.REMOTE if arcade_client else ExecutionMode.LOCAL
+ ),
+ conditions={},
+ priority=90,
+ )
+ )
+
print("ā
Custom routing rules configured")
-
+
# Test different execution scenarios
test_cases = [
{
"name": "Fast cache lookup",
"tool_call": ToolCall(
- id="test_1",
- name="Cache_FastLookup",
- arguments={"key": "user_123"}
- )
+ id="test_1", name="Cache_FastLookup", arguments={"key": "user_123"}
+ ),
},
{
"name": "Complex AI analysis",
@@ -825,38 +870,42 @@ async def main():
name="AI_ComplexAnalysis",
arguments={
"data": {"items": list(range(100))},
- "model_type": "neural_network"
- }
- )
+ "model_type": "neural_network",
+ },
+ ),
},
{
"name": "Simple string utility",
"tool_call": ToolCall(
id="test_3",
name="Util_StringHelper",
- arguments={"text": "Hello, intelligent routing world!"}
- )
- }
+ arguments={"text": "Hello, intelligent routing world!"},
+ ),
+ },
]
-
+
print("\nš Executing test cases with intelligent routing...")
-
+
for i, test_case in enumerate(test_cases, 1):
print(f"\n{i}. {test_case['name']}")
- print("-" * (len(test_case['name']) + 3))
-
+ print("-" * (len(test_case["name"]) + 3))
+
try:
start_time = time.time()
- result = await router.execute_tool(test_case['tool_call'])
+ result = await router.execute_tool(test_case["tool_call"])
total_time = (time.time() - start_time) * 1000
-
+
if result.success:
execution_mode = result.metadata.get("execution_mode", "unknown")
winning_mode = result.metadata.get("winning_mode", "")
-
- print(f"ā
Success ({execution_mode} mode{f', {winning_mode} won' if winning_mode else ''})")
- print(f" Execution time: {result.execution_time_ms:.1f}ms (total: {total_time:.1f}ms)")
-
+
+ print(
+ f"ā
Success ({execution_mode} mode{f', {winning_mode} won' if winning_mode else ''})"
+ )
+ print(
+ f" Execution time: {result.execution_time_ms:.1f}ms (total: {total_time:.1f}ms)"
+ )
+
if result.data:
data_preview = str(result.data)[:100]
if len(str(result.data)) > 100:
@@ -864,42 +913,42 @@ async def main():
print(f" Result: {data_preview}")
else:
print(f"ā Failed: {result.error}")
-
+
except Exception as e:
print(f"ā Error: {e}")
-
+
# Performance summary
print("\nš Performance Summary")
print("-" * 20)
-
+
summary = router.get_performance_summary()
print(f"Total executions: {summary['total_executions']}")
print(f"Tools tracked: {summary['tools_tracked']}")
-
- if summary['mode_distribution']:
+
+ if summary["mode_distribution"]:
print("\nExecution mode distribution:")
- for mode, percentage in summary['mode_distribution'].items():
+ for mode, percentage in summary["mode_distribution"].items():
print(f" {mode}: {percentage:.1%}")
-
- if summary['average_execution_times']:
+
+ if summary["average_execution_times"]:
print("\nAverage execution times by tool:")
- for tool_name, times in summary['average_execution_times'].items():
- local_time = times['local_ms']
- remote_time = times['remote_ms']
+ for tool_name, times in summary["average_execution_times"].items():
+ local_time = times["local_ms"]
+ remote_time = times["remote_ms"]
print(f" {tool_name}:")
if local_time > 0:
print(f" Local: {local_time:.1f}ms")
if remote_time > 0:
print(f" Remote: {remote_time:.1f}ms")
-
+
# Cleanup
if arcade_client:
await arcade_client.disconnect()
-
+
print("\nš Intelligent routing example completed successfully!")
return 0
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/04_error_handling/resilient_execution.py b/examples/arcade-dev/04_error_handling/resilient_execution.py
index ec3eec8..09d26ec 100644
--- a/examples/arcade-dev/04_error_handling/resilient_execution.py
+++ b/examples/arcade-dev/04_error_handling/resilient_execution.py
@@ -33,83 +33,104 @@
# Mock ToolExecutor for demo mode
class ToolExecutor:
"""Mock ToolExecutor for demo purposes."""
+
def __init__(self):
self.logger = logging.getLogger(__name__)
-
- async def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
+
+ async def execute_tool(
+ self, tool_name: str, arguments: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Mock tool execution."""
return {"status": "mock_execution", "tool": tool_name, "args": arguments}
+
try:
from src.tools.decorators import Tool
except ImportError:
+
def Tool(*args, **kwargs):
"""Mock Tool decorator for demo mode."""
+
def decorator(func):
- func._tool_name = kwargs.get('name', func.__name__)
- func._tool_description = kwargs.get('description', '')
- func._tool_parameters = kwargs.get('parameters', {})
+ func._tool_name = kwargs.get("name", func.__name__)
+ func._tool_description = kwargs.get("description", "")
+ func._tool_parameters = kwargs.get("parameters", {})
return func
+
return decorator
+
try:
from src.core.errors import *
except ImportError:
# Define error classes if they don't exist
class ToolExecutionError(Exception):
"""Tool execution error."""
+
pass
class ToolValidationError(Exception):
"""Tool validation error."""
+
pass
class ToolNotFoundError(Exception):
"""Tool not found error."""
+
pass
class UnauthorizedError(Exception):
"""Unauthorized error."""
+
pass
class SecurityError(Exception):
"""Security error."""
+
pass
+
try:
from src.monitoring.metrics import MetricsCollector
except ImportError:
+
class MetricsCollector:
"""Mock metrics collector for demo mode."""
+
def __init__(self):
self.metrics = {}
-
+
def record(self, name, value, tags=None):
"""Record a metric."""
pass
-
+
def increment(self, name, tags=None):
"""Increment a counter."""
pass
+
from src.cache.manager import CacheManager
# Import the BasicArcadeClient from the basic integration example
sys.path.insert(0, str(Path(__file__).parent.parent / "01_basic_integration"))
from basic_arcade_client import BasicArcadeClient, ArcadeConfig
+
# Define classes for error handling demonstration
@dataclass
class ToolCall:
"""Tool call data structure."""
+
id: str
name: str
arguments: Dict[str, Any]
user_id: Optional[str] = None
+
@dataclass
class ToolResult:
"""Tool execution result."""
+
call_id: str
tool_name: str
success: bool
@@ -119,13 +140,16 @@ class ToolResult:
status_code: int = 200
metadata: Optional[Dict[str, Any]] = None
+
class FinalRetryError(Exception):
"""Final retry error when all retries are exhausted."""
+
pass
class ErrorType(Enum):
"""Classification of error types for different handling strategies."""
+
NETWORK_ERROR = "network_error"
AUTHENTICATION_ERROR = "authentication_error"
VALIDATION_ERROR = "validation_error"
@@ -138,6 +162,7 @@ class ErrorType(Enum):
class RetryStrategy(Enum):
"""Retry strategy enumeration."""
+
NO_RETRY = "no_retry"
LINEAR_BACKOFF = "linear_backoff"
EXPONENTIAL_BACKOFF = "exponential_backoff"
@@ -147,6 +172,7 @@ class RetryStrategy(Enum):
@dataclass
class ErrorContext:
"""Context information for error handling decisions."""
+
error: Exception
error_type: ErrorType
tool_name: str
@@ -159,6 +185,7 @@ class ErrorContext:
@dataclass
class RetryConfig:
"""Configuration for retry behavior."""
+
strategy: RetryStrategy
max_attempts: int = 3
base_delay: float = 1.0
@@ -171,6 +198,7 @@ class RetryConfig:
@dataclass
class CircuitBreakerConfig:
"""Configuration for circuit breaker pattern."""
+
failure_threshold: int = 5
success_threshold: int = 3
timeout_seconds: float = 60.0
@@ -179,6 +207,7 @@ class CircuitBreakerConfig:
class CircuitState(Enum):
"""Circuit breaker states."""
+
CLOSED = "closed"
OPEN = "open"
HALF_OPEN = "half_open"
@@ -187,11 +216,11 @@ class CircuitState(Enum):
class CircuitBreaker:
"""
Circuit breaker implementation for preventing cascading failures.
-
+
Tracks failure rates and opens the circuit when threshold is exceeded,
allowing the system to fail fast and recover gracefully.
"""
-
+
def __init__(self, config: CircuitBreakerConfig, name: str = "default"):
"""Initialize circuit breaker."""
self.config = config
@@ -202,24 +231,30 @@ def __init__(self, config: CircuitBreakerConfig, name: str = "default"):
self.last_failure_time = 0.0
self.failure_history: List[float] = []
self.logger = logging.getLogger(f"{__name__}.{name}")
-
+
async def call(self, func: Callable, *args, **kwargs) -> Any:
"""Execute function through circuit breaker protection."""
if self.state == CircuitState.OPEN:
if self._should_attempt_reset():
self.state = CircuitState.HALF_OPEN
- self.logger.info(f"Circuit breaker '{self.name}' entering half-open state")
+ self.logger.info(
+ f"Circuit breaker '{self.name}' entering half-open state"
+ )
else:
raise FinalRetryError(f"Circuit breaker '{self.name}' is open")
-
+
try:
- result = await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs)
+ result = (
+ await func(*args, **kwargs)
+ if asyncio.iscoroutinefunction(func)
+ else func(*args, **kwargs)
+ )
await self._on_success()
return result
except Exception as e:
await self._on_failure()
raise
-
+
async def _on_success(self) -> None:
"""Handle successful execution."""
if self.state == CircuitState.HALF_OPEN:
@@ -233,28 +268,32 @@ async def _on_success(self) -> None:
elif self.state == CircuitState.CLOSED:
# Reset failure count on success
self.failure_count = max(0, self.failure_count - 1)
-
+
async def _on_failure(self) -> None:
"""Handle failed execution."""
current_time = time.time()
self.failure_count += 1
self.last_failure_time = current_time
self.failure_history.append(current_time)
-
+
# Clean old failures outside monitoring window
window_start = current_time - self.config.monitoring_window
self.failure_history = [t for t in self.failure_history if t > window_start]
-
+
if self.state == CircuitState.CLOSED:
if len(self.failure_history) >= self.config.failure_threshold:
self.state = CircuitState.OPEN
self.success_count = 0
- self.logger.warning(f"Circuit breaker '{self.name}' opened due to failure threshold")
+ self.logger.warning(
+ f"Circuit breaker '{self.name}' opened due to failure threshold"
+ )
elif self.state == CircuitState.HALF_OPEN:
self.state = CircuitState.OPEN
self.success_count = 0
- self.logger.warning(f"Circuit breaker '{self.name}' reopened after half-open failure")
-
+ self.logger.warning(
+ f"Circuit breaker '{self.name}' reopened after half-open failure"
+ )
+
def _should_attempt_reset(self) -> bool:
"""Check if circuit should attempt reset."""
return time.time() - self.last_failure_time >= self.config.timeout_seconds
@@ -263,16 +302,18 @@ def _should_attempt_reset(self) -> bool:
class ResilientExecutor:
"""
Resilient tool executor with comprehensive error handling.
-
+
Provides intelligent retry strategies, circuit breaking, graceful degradation,
and detailed error classification for production workloads.
"""
-
- def __init__(self,
- arcade_client: Optional[BasicArcadeClient] = None,
- cache_manager: Optional[CacheManager] = None,
- retry_config: Optional[RetryConfig] = None,
- circuit_config: Optional[CircuitBreakerConfig] = None):
+
+ def __init__(
+ self,
+ arcade_client: Optional[BasicArcadeClient] = None,
+ cache_manager: Optional[CacheManager] = None,
+ retry_config: Optional[RetryConfig] = None,
+ circuit_config: Optional[CircuitBreakerConfig] = None,
+ ):
"""Initialize resilient executor."""
self.logger = logging.getLogger(__name__)
self.arcade_client = arcade_client
@@ -280,40 +321,42 @@ def __init__(self,
try:
self.local_executor = ToolExecutor()
except Exception as e:
- self.logger.warning(f"Could not initialize ToolExecutor: {e}. Using mock executor.")
+ self.logger.warning(
+ f"Could not initialize ToolExecutor: {e}. Using mock executor."
+ )
self.local_executor = None
self.metrics_collector = MetricsCollector()
-
+
# Configuration
self.retry_config = retry_config or RetryConfig(
strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
max_attempts=3,
base_delay=1.0,
- max_delay=30.0
+ max_delay=30.0,
)
-
+
# Circuit breakers for different components
circuit_config = circuit_config or CircuitBreakerConfig()
self.local_circuit_breaker = CircuitBreaker(circuit_config, "local_executor")
self.remote_circuit_breaker = CircuitBreaker(circuit_config, "remote_executor")
-
+
# Error tracking
self.error_history: List[ErrorContext] = []
self.error_patterns: Dict[str, List[ErrorType]] = {}
-
+
# Degradation flags
self.degraded_mode = False
self.degradation_start_time = 0.0
-
+
async def execute_tool_resilient(self, tool_call: ToolCall) -> ToolResult:
"""Execute a tool with comprehensive error handling and resilience."""
start_time = time.time()
attempt_count = 0
last_exception = None
-
+
while attempt_count < self.retry_config.max_attempts:
attempt_count += 1
-
+
try:
# Try circuit breaker protected execution
if self.arcade_client:
@@ -324,21 +367,21 @@ async def execute_tool_resilient(self, tool_call: ToolCall) -> ToolResult:
result = await self.local_circuit_breaker.call(
self._execute_local_tool, tool_call
)
-
+
execution_time = (time.time() - start_time) * 1000
-
+
return ToolResult(
call_id=tool_call.id,
tool_name=tool_call.name,
success=True,
data=result,
- execution_time_ms=execution_time
+ execution_time_ms=execution_time,
)
-
+
except Exception as e:
last_exception = e
error_type = self._classify_error(e)
-
+
# Create error context
error_context = ErrorContext(
error=e,
@@ -346,62 +389,61 @@ async def execute_tool_resilient(self, tool_call: ToolCall) -> ToolResult:
tool_name=tool_call.name,
attempt_count=attempt_count,
total_elapsed_time=time.time() - start_time,
- user_id=tool_call.user_id
+ user_id=tool_call.user_id,
)
-
+
# Track error
self.error_history.append(error_context)
-
+
# Check if we should retry
if not self._should_retry(error_context):
break
-
+
# Calculate retry delay
delay = self._calculate_retry_delay(attempt_count)
self.logger.warning(
f"Tool {tool_call.name} failed (attempt {attempt_count}): {e}. "
f"Retrying in {delay:.1f}s..."
)
-
+
await asyncio.sleep(delay)
-
+
# All retries exhausted
execution_time = (time.time() - start_time) * 1000
-
+
return ToolResult(
call_id=tool_call.id,
tool_name=tool_call.name,
success=False,
error=str(last_exception),
- execution_time_ms=execution_time
+ execution_time_ms=execution_time,
)
-
+
async def _execute_remote_tool(self, tool_call: ToolCall) -> Dict[str, Any]:
"""Execute tool using remote Arcade client."""
if not self.arcade_client:
raise ToolExecutionError("No arcade client available")
-
+
result = await self.arcade_client.execute_tool(
- tool_call.name,
- tool_call.arguments
+ tool_call.name, tool_call.arguments
)
return result
-
+
async def _execute_local_tool(self, tool_call: ToolCall) -> Dict[str, Any]:
"""Execute tool using local executor."""
if not self.local_executor:
raise ToolExecutionError("No local executor available")
-
+
# For demo purposes, use the unreliable_service function
if tool_call.name == "Test_UnreliableService":
return unreliable_service(**tool_call.arguments)
-
+
raise ToolNotFoundError(f"Tool {tool_call.name} not found")
-
+
def _classify_error(self, error: Exception) -> ErrorType:
"""Classify error type for handling strategy."""
error_str = str(error).lower()
-
+
if "network" in error_str or "connection" in error_str:
return ErrorType.NETWORK_ERROR
elif "timeout" in error_str:
@@ -416,23 +458,23 @@ def _classify_error(self, error: Exception) -> ErrorType:
return ErrorType.SERVER_ERROR
else:
return ErrorType.UNKNOWN_ERROR
-
+
def _should_retry(self, error_context: ErrorContext) -> bool:
"""Determine if error should be retried."""
# Don't retry validation errors
if error_context.error_type == ErrorType.VALIDATION_ERROR:
return False
-
+
# Don't retry authentication errors
if error_context.error_type == ErrorType.AUTHENTICATION_ERROR:
return False
-
+
# Don't retry if max attempts reached
if error_context.attempt_count >= self.retry_config.max_attempts:
return False
-
+
return True
-
+
def _calculate_retry_delay(self, attempt: int) -> float:
"""Calculate retry delay based on strategy."""
if self.retry_config.strategy == RetryStrategy.NO_RETRY:
@@ -440,14 +482,16 @@ def _calculate_retry_delay(self, attempt: int) -> float:
elif self.retry_config.strategy == RetryStrategy.LINEAR_BACKOFF:
delay = self.retry_config.base_delay * attempt
elif self.retry_config.strategy == RetryStrategy.EXPONENTIAL_BACKOFF:
- delay = self.retry_config.base_delay * (self.retry_config.backoff_multiplier ** (attempt - 1))
+ delay = self.retry_config.base_delay * (
+ self.retry_config.backoff_multiplier ** (attempt - 1)
+ )
else:
delay = self.retry_config.base_delay
-
+
# Apply jitter if enabled
if self.retry_config.jitter:
- delay *= (0.5 + random.random() * 0.5)
-
+ delay *= 0.5 + random.random() * 0.5
+
# Cap at max delay
return min(delay, self.retry_config.max_delay)
@@ -457,88 +501,101 @@ def _calculate_retry_delay(self, attempt: int) -> float:
name="Test_UnreliableService",
description="Unreliable service for testing error handling",
parameters={
- "failure_rate": {"type": "number", "description": "Probability of failure (0.0-1.0)", "default": 0.3},
- "delay_ms": {"type": "integer", "description": "Processing delay in milliseconds", "default": 100}
- }
+ "failure_rate": {
+ "type": "number",
+ "description": "Probability of failure (0.0-1.0)",
+ "default": 0.3,
+ },
+ "delay_ms": {
+ "type": "integer",
+ "description": "Processing delay in milliseconds",
+ "default": 100,
+ },
+ },
)
-def unreliable_service(failure_rate: float = 0.3, delay_ms: int = 100) -> Dict[str, Any]:
+def unreliable_service(
+ failure_rate: float = 0.3, delay_ms: int = 100
+) -> Dict[str, Any]:
"""Simulated unreliable service for testing."""
import time
import random
-
+
# Simulate processing delay
time.sleep(delay_ms / 1000.0)
-
+
# Random failure
if random.random() < failure_rate:
error_types = [
"Network connection failed",
- "Service temporarily unavailable",
+ "Service temporarily unavailable",
"Request timeout",
"Rate limit exceeded",
- "Internal server error"
+ "Internal server error",
]
raise Exception(random.choice(error_types))
-
+
return {
"success": True,
"result": f"Service completed successfully after {delay_ms}ms delay",
- "timestamp": time.time()
+ "timestamp": time.time(),
}
async def demonstrate_error_handling():
"""Demonstrate resilient error handling capabilities."""
-
+
# Configure logging
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
-
+
logger.info("Starting resilient error handling demonstration")
-
+
# Initialize components from environment
arcade_api_key = os.getenv("ARCADE_API_KEY")
cache_config = {
"redis_url": os.getenv("REDIS_URL", "redis://localhost:6379"),
- "default_ttl": 3600
+ "default_ttl": 3600,
}
-
+
# Create resilient executor with custom configuration
retry_config = RetryConfig(
strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
max_attempts=5,
base_delay=0.5,
max_delay=10.0,
- jitter=True
+ jitter=True,
)
-
+
circuit_config = CircuitBreakerConfig(
- failure_threshold=3,
- success_threshold=2,
- timeout_seconds=30.0
+ failure_threshold=3, success_threshold=2, timeout_seconds=30.0
)
-
+
# Initialize clients with demo mode support
demo_mode = (
- not bool(arcade_api_key.strip()) if arcade_api_key else True or
- arcade_api_key.strip() in ["your_api_key", "demo_key", "placeholder"] if arcade_api_key else True or
- len(arcade_api_key.strip()) < 10 if arcade_api_key else True # Real API keys are typically longer
+ not bool(arcade_api_key.strip())
+ if arcade_api_key
+ else (
+ True
+ or arcade_api_key.strip() in ["your_api_key", "demo_key", "placeholder"]
+ if arcade_api_key
+ else True or len(arcade_api_key.strip()) < 10 if arcade_api_key else True
+ ) # Real API keys are typically longer
)
-
+
arcade_config = ArcadeConfig(
api_key=arcade_api_key if not demo_mode else "demo_key",
user_id=os.getenv("ARCADE_USER_ID", "demo@example.com"),
- demo_mode=demo_mode
+ demo_mode=demo_mode,
)
-
+
arcade_client = None
if arcade_api_key or demo_mode:
arcade_client = BasicArcadeClient(arcade_config)
await arcade_client.connect()
-
+
cache_manager = None
try:
if os.getenv("REDIS_URL"):
@@ -549,94 +606,94 @@ async def demonstrate_error_handling():
"prefix": "resilient_demo",
"min_tokens": 1,
"max_size": "1MB",
- "ttl_seconds": 300
+ "ttl_seconds": 300,
}
cache_manager = CacheManager(basic_cache_config)
except Exception as e:
- logger.warning(f"Cache manager initialization failed: {e}. Continuing without cache.")
-
+ logger.warning(
+ f"Cache manager initialization failed: {e}. Continuing without cache."
+ )
+
executor = ResilientExecutor(
arcade_client=arcade_client,
cache_manager=cache_manager,
retry_config=retry_config,
- circuit_config=circuit_config
+ circuit_config=circuit_config,
)
-
+
if demo_mode:
- logger.info("š Running in demo mode - using mock responses and local error simulation")
+ logger.info(
+ "š Running in demo mode - using mock responses and local error simulation"
+ )
else:
logger.info("š Using real API key for Arcade.dev integration")
-
+
# Test scenarios
test_scenarios = [
# Low failure rate - should succeed quickly
{"failure_rate": 0.1, "delay_ms": 50, "description": "Low failure rate"},
-
# Medium failure rate - should retry and eventually succeed
{"failure_rate": 0.5, "delay_ms": 100, "description": "Medium failure rate"},
-
# High failure rate - should eventually fail gracefully
{"failure_rate": 0.9, "delay_ms": 200, "description": "High failure rate"},
-
# Network timeout simulation
- {"failure_rate": 1.0, "delay_ms": 5000, "description": "Timeout simulation"}
+ {"failure_rate": 1.0, "delay_ms": 5000, "description": "Timeout simulation"},
]
-
+
for i, scenario in enumerate(test_scenarios, 1):
logger.info(f"\n--- Test Scenario {i}: {scenario['description']} ---")
-
+
tool_call = ToolCall(
id=f"test_{i}",
name="Test_UnreliableService",
arguments=scenario,
- user_id="demo_user"
+ user_id="demo_user",
)
-
+
start_time = time.time()
-
+
try:
- logger.info(f"Executing unreliable service with {scenario['failure_rate']*100:.0f}% failure rate")
-
+ logger.info(
+ f"Executing unreliable service with {scenario['failure_rate']*100:.0f}% failure rate"
+ )
+
# Use the resilient executor
result = await executor.execute_tool_resilient(tool_call)
execution_time = (time.time() - start_time) * 1000
-
+
if result.success:
logger.info(f"ā Success in {execution_time:.0f}ms")
logger.info(f"Result: {result.data}")
else:
logger.error(f"ā Failed after all retry attempts: {result.error}")
logger.info("Applying graceful degradation...")
-
+
# Demonstrate graceful degradation
fallback_result = {
"status": "degraded",
"message": "Service temporarily unavailable, using cached or default response",
- "fallback_data": {"placeholder": "default_value"}
+ "fallback_data": {"placeholder": "default_value"},
}
logger.info(f"Fallback result: {fallback_result}")
-
+
except Exception as e:
logger.error(f"Unexpected error in scenario {i}: {e}")
logger.debug(f"Exception traceback: {traceback.format_exc()}")
-
+
# Brief pause between scenarios
await asyncio.sleep(1)
-
+
# Demonstrate health status monitoring
logger.info("\n--- System Health Status ---")
health_status = {
- "circuit_breakers": {
- "local": "closed",
- "remote": "closed"
- },
+ "circuit_breakers": {"local": "closed", "remote": "closed"},
"degraded_mode": False,
"error_rate": 0.2,
- "recent_errors": 3
+ "recent_errors": 3,
}
-
+
logger.info(f"Health Status: {health_status}")
-
+
logger.info("Error handling demonstration completed")
@@ -645,13 +702,15 @@ async def demonstrate_error_handling():
print("š”ļø FACT SDK - Resilient Error Handling Example")
print("=" * 50)
print()
- print("This example demonstrates comprehensive error handling for Arcade.dev integration:")
+ print(
+ "This example demonstrates comprehensive error handling for Arcade.dev integration:"
+ )
print("⢠Error classification and routing")
print("⢠Intelligent retry strategies")
print("⢠Circuit breaker patterns")
print("⢠Graceful degradation")
print("⢠Comprehensive logging and monitoring")
print()
-
+
# Run demonstration
- asyncio.run(demonstrate_error_handling())
\ No newline at end of file
+ asyncio.run(demonstrate_error_handling())
diff --git a/examples/arcade-dev/05_cache_integration/cached_arcade_client.py b/examples/arcade-dev/05_cache_integration/cached_arcade_client.py
index 15dd915..bf18588 100644
--- a/examples/arcade-dev/05_cache_integration/cached_arcade_client.py
+++ b/examples/arcade-dev/05_cache_integration/cached_arcade_client.py
@@ -36,17 +36,18 @@
@dataclass
class CacheStrategy:
"""Cache strategy configuration."""
+
strategy_name: str
ttl_seconds: int
max_entries: int = 1000
compression_enabled: bool = False
encryption_enabled: bool = False
invalidation_pattern: str = "time" # time, lru, custom
-
+
# Performance thresholds
hit_rate_threshold: float = 0.8
response_time_threshold_ms: float = 100
-
+
# Advanced features
prefetch_enabled: bool = False
background_refresh_enabled: bool = False
@@ -56,670 +57,734 @@ class CacheStrategy:
@dataclass
class CachePerformanceMetrics:
"""Cache performance tracking."""
+
hits: int = 0
misses: int = 0
sets: int = 0
evictions: int = 0
errors: int = 0
-
+
# Timing metrics
total_hit_time_ms: float = 0
total_miss_time_ms: float = 0
total_set_time_ms: float = 0
-
+
# Size metrics
cache_size_bytes: int = 0
avg_entry_size_bytes: float = 0
-
+
# Advanced metrics
prefetch_hits: int = 0
background_refreshes: int = 0
-
+
def get_hit_rate(self) -> float:
total_requests = self.hits + self.misses
return (self.hits / total_requests) if total_requests > 0 else 0.0
-
+
def get_avg_hit_time_ms(self) -> float:
return (self.total_hit_time_ms / self.hits) if self.hits > 0 else 0.0
-
+
def get_avg_miss_time_ms(self) -> float:
return (self.total_miss_time_ms / self.misses) if self.misses > 0 else 0.0
class HybridCacheManager:
"""Advanced cache manager with multiple cache levels and strategies."""
-
- def __init__(self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]):
+
+ def __init__(
+ self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]
+ ):
self.primary_cache = primary_cache
self.strategies = strategies
self.logger = logging.getLogger(f"{__name__}.HybridCacheManager")
-
+
# In-memory cache for hot data
self.memory_cache: Dict[str, Dict[str, Any]] = {}
self.memory_cache_access_times: Dict[str, datetime] = {}
self.memory_cache_max_size = 100
-
+
# Performance tracking
self.metrics: Dict[str, CachePerformanceMetrics] = {
strategy_name: CachePerformanceMetrics()
for strategy_name in strategies.keys()
}
-
+
# Cache key prefetching
self.prefetch_queue: List[str] = []
self.prefetch_task: Optional[asyncio.Task] = None
-
+
async def get(self, key: str, strategy_name: str = "default") -> Optional[Any]:
"""Get value with hybrid caching strategy."""
start_time = time.time()
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Level 1: Memory cache (fastest)
if strategy.multi_level_caching and key in self.memory_cache:
self.memory_cache_access_times[key] = datetime.now(timezone.utc)
-
+
cache_data = self.memory_cache[key]
if self._is_cache_entry_valid(cache_data):
metrics.hits += 1
metrics.total_hit_time_ms += (time.time() - start_time) * 1000
self.logger.debug(f"Memory cache hit for key: {key}")
- return cache_data['data']
+ return cache_data["data"]
else:
# Remove expired entry
del self.memory_cache[key]
del self.memory_cache_access_times[key]
-
+
# Level 2: Primary cache (persistent)
cache_entry = self.primary_cache.get(key)
-
- if cache_entry and self._is_cache_entry_valid({'cached_at': cache_entry.created_at, 'ttl': self.strategies[strategy_name].ttl_seconds}):
+
+ if cache_entry and self._is_cache_entry_valid(
+ {
+ "cached_at": cache_entry.created_at,
+ "ttl": self.strategies[strategy_name].ttl_seconds,
+ }
+ ):
# Cache hit - extract data from cache entry content
try:
cache_data = json.loads(cache_entry.content)
metrics.hits += 1
metrics.total_hit_time_ms += (time.time() - start_time) * 1000
-
+
# Promote to memory cache if multi-level enabled
if strategy.multi_level_caching:
await self._promote_to_memory_cache(key, cache_data, strategy)
-
+
self.logger.debug(f"Primary cache hit for key: {key}")
- return cache_data.get('data', cache_data)
+ return cache_data.get("data", cache_data)
except json.JSONDecodeError:
self.logger.warning(f"Invalid JSON in cache entry for key: {key}")
-
+
# Cache miss
metrics.misses += 1
metrics.total_miss_time_ms += (time.time() - start_time) * 1000
self.logger.debug(f"Cache miss for key: {key}")
-
+
# Schedule prefetch if enabled
if strategy.prefetch_enabled:
await self._schedule_prefetch(key, strategy_name)
-
+
return None
-
+
except Exception as e:
metrics.errors += 1
self.logger.error(f"Cache get error for key {key}: {e}")
return None
-
- async def set(self, key: str, value: Any, strategy_name: str = "default",
- custom_ttl: int = None) -> bool:
+
+ async def set(
+ self,
+ key: str,
+ value: Any,
+ strategy_name: str = "default",
+ custom_ttl: int = None,
+ ) -> bool:
"""Set value with caching strategy."""
start_time = time.time()
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
ttl = custom_ttl or strategy.ttl_seconds
-
+
# Prepare cache data with metadata
cache_data = {
- 'data': value,
- 'cached_at': time.time(),
- 'ttl': ttl,
- 'strategy': strategy_name,
- 'access_count': 0,
- 'size_bytes': len(json.dumps(value, default=str))
+ "data": value,
+ "cached_at": time.time(),
+ "ttl": ttl,
+ "strategy": strategy_name,
+ "access_count": 0,
+ "size_bytes": len(json.dumps(value, default=str)),
}
-
+
# Apply compression if enabled
if strategy.compression_enabled:
cache_data = await self._compress_cache_data(cache_data)
-
+
# Apply encryption if enabled
if strategy.encryption_enabled:
cache_data = await self._encrypt_cache_data(cache_data)
-
+
# Store in primary cache using store method
try:
- cache_entry = self.primary_cache.store(key, json.dumps(cache_data, default=str))
+ cache_entry = self.primary_cache.store(
+ key, json.dumps(cache_data, default=str)
+ )
success = True
except Exception as e:
self.logger.error(f"Primary cache store failed: {e}")
success = False
-
+
if success:
metrics.sets += 1
metrics.total_set_time_ms += (time.time() - start_time) * 1000
- metrics.cache_size_bytes += cache_data.get('size_bytes', 0)
-
+ metrics.cache_size_bytes += cache_data.get("size_bytes", 0)
+
# Update average entry size
if metrics.sets > 0:
- metrics.avg_entry_size_bytes = metrics.cache_size_bytes / metrics.sets
-
+ metrics.avg_entry_size_bytes = (
+ metrics.cache_size_bytes / metrics.sets
+ )
+
# Set in memory cache if multi-level enabled
if strategy.multi_level_caching:
await self._set_memory_cache(key, cache_data, strategy)
-
+
self.logger.debug(f"Cache set successful for key: {key}")
-
+
# Schedule background refresh if enabled
if strategy.background_refresh_enabled:
await self._schedule_background_refresh(key, strategy_name, ttl)
-
+
return True
else:
metrics.errors += 1
return False
-
+
except Exception as e:
metrics.errors += 1
self.logger.error(f"Cache set error for key {key}: {e}")
return False
-
+
async def invalidate(self, pattern: str, strategy_name: str = "default") -> int:
"""Invalidate cache entries matching pattern."""
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Invalidate from memory cache
invalidated_count = 0
-
+
if strategy.multi_level_caching:
keys_to_remove = []
for key in self.memory_cache.keys():
if self._matches_pattern(key, pattern):
keys_to_remove.append(key)
-
+
for key in keys_to_remove:
del self.memory_cache[key]
if key in self.memory_cache_access_times:
del self.memory_cache_access_times[key]
invalidated_count += 1
-
+
# Invalidate from primary cache
# Note: This is a simplified implementation
# In production, you'd need pattern-based invalidation in the cache manager
-
+
metrics.evictions += invalidated_count
- self.logger.info(f"Invalidated {invalidated_count} entries matching pattern: {pattern}")
+ self.logger.info(
+ f"Invalidated {invalidated_count} entries matching pattern: {pattern}"
+ )
return invalidated_count
-
+
except Exception as e:
metrics.errors += 1
self.logger.error(f"Cache invalidation error for pattern {pattern}: {e}")
return 0
-
+
async def optimize_cache(self, strategy_name: str = "default"):
"""Optimize cache performance based on usage patterns."""
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Check if optimization is needed
hit_rate = metrics.get_hit_rate()
-
+
if hit_rate < strategy.hit_rate_threshold:
- self.logger.warning(f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})")
-
+ self.logger.warning(
+ f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})"
+ )
+
# Optimization strategies
await self._optimize_memory_cache_size(strategy)
await self._optimize_ttl_values(strategy_name)
await self._cleanup_expired_entries(strategy_name)
-
- self.logger.info(f"Cache optimization completed for strategy: {strategy_name}")
-
+
+ self.logger.info(
+ f"Cache optimization completed for strategy: {strategy_name}"
+ )
+
except Exception as e:
self.logger.error(f"Cache optimization error: {e}")
-
+
def _is_cache_entry_valid(self, cache_data: Dict[str, Any]) -> bool:
"""Check if cache entry is still valid."""
if not isinstance(cache_data, dict):
return True # Assume simple values are valid
-
- cached_at = cache_data.get('cached_at', 0)
- ttl = cache_data.get('ttl', 3600)
-
+
+ cached_at = cache_data.get("cached_at", 0)
+ ttl = cache_data.get("ttl", 3600)
+
return (time.time() - cached_at) < ttl
-
- async def _promote_to_memory_cache(self, key: str, cache_data: Dict[str, Any], strategy: CacheStrategy):
+
+ async def _promote_to_memory_cache(
+ self, key: str, cache_data: Dict[str, Any], strategy: CacheStrategy
+ ):
"""Promote frequently accessed data to memory cache."""
# Check if memory cache has space
if len(self.memory_cache) >= self.memory_cache_max_size:
await self._evict_lru_memory_cache()
-
+
self.memory_cache[key] = cache_data
self.memory_cache_access_times[key] = datetime.utcnow()
-
- async def _set_memory_cache(self, key: str, cache_data: Dict[str, Any], strategy: CacheStrategy):
+
+ async def _set_memory_cache(
+ self, key: str, cache_data: Dict[str, Any], strategy: CacheStrategy
+ ):
"""Set data in memory cache."""
if len(self.memory_cache) >= self.memory_cache_max_size:
await self._evict_lru_memory_cache()
-
+
self.memory_cache[key] = cache_data
self.memory_cache_access_times[key] = datetime.now(timezone.utc)
-
+
async def _evict_lru_memory_cache(self):
"""Evict least recently used entry from memory cache."""
if not self.memory_cache_access_times:
return
-
+
# Find LRU entry
- lru_key = min(self.memory_cache_access_times.keys(),
- key=lambda k: self.memory_cache_access_times[k])
-
+ lru_key = min(
+ self.memory_cache_access_times.keys(),
+ key=lambda k: self.memory_cache_access_times[k],
+ )
+
# Remove LRU entry
del self.memory_cache[lru_key]
del self.memory_cache_access_times[lru_key]
-
+
async def _schedule_prefetch(self, key: str, strategy_name: str):
"""Schedule prefetching for related keys."""
if key not in self.prefetch_queue:
self.prefetch_queue.append(key)
-
+
if not self.prefetch_task or self.prefetch_task.done():
self.prefetch_task = asyncio.create_task(self._process_prefetch_queue())
-
+
async def _process_prefetch_queue(self):
"""Process prefetch queue in background."""
while self.prefetch_queue:
key = self.prefetch_queue.pop(0)
# In a real implementation, you'd prefetch related data
await asyncio.sleep(0.1) # Simulate prefetch work
-
- async def _schedule_background_refresh(self, key: str, strategy_name: str, ttl: int):
+
+ async def _schedule_background_refresh(
+ self, key: str, strategy_name: str, ttl: int
+ ):
"""Schedule background refresh before TTL expires."""
refresh_time = ttl * 0.8 # Refresh at 80% of TTL
await asyncio.sleep(refresh_time)
-
+
# In a real implementation, you'd refresh the data from source
self.logger.debug(f"Background refresh triggered for key: {key}")
-
+
async def _compress_cache_data(self, cache_data: Dict[str, Any]) -> Dict[str, Any]:
"""Compress cache data to save space."""
# Simplified compression simulation
- cache_data['compressed'] = True
- cache_data['size_bytes'] = int(cache_data['size_bytes'] * 0.7) # 30% compression
+ cache_data["compressed"] = True
+ cache_data["size_bytes"] = int(
+ cache_data["size_bytes"] * 0.7
+ ) # 30% compression
return cache_data
-
+
async def _encrypt_cache_data(self, cache_data: Dict[str, Any]) -> Dict[str, Any]:
"""Encrypt sensitive cache data."""
# Simplified encryption simulation
- cache_data['encrypted'] = True
+ cache_data["encrypted"] = True
return cache_data
-
+
def _matches_pattern(self, key: str, pattern: str) -> bool:
"""Check if key matches invalidation pattern."""
# Simplified pattern matching
if pattern == "*":
return True
return pattern in key
-
+
async def _optimize_memory_cache_size(self, strategy: CacheStrategy):
"""Optimize memory cache size based on hit patterns."""
current_hit_rate = self.metrics[strategy.strategy_name].get_hit_rate()
-
+
if current_hit_rate < 0.5:
# Increase memory cache size
self.memory_cache_max_size = min(self.memory_cache_max_size * 2, 500)
elif current_hit_rate > 0.9:
# Decrease memory cache size
self.memory_cache_max_size = max(self.memory_cache_max_size // 2, 50)
-
- self.logger.debug(f"Optimized memory cache size to: {self.memory_cache_max_size}")
-
+
+ self.logger.debug(
+ f"Optimized memory cache size to: {self.memory_cache_max_size}"
+ )
+
async def _optimize_ttl_values(self, strategy_name: str):
"""Optimize TTL values based on access patterns."""
# In a real implementation, analyze access patterns and adjust TTLs
self.logger.debug(f"TTL optimization completed for strategy: {strategy_name}")
-
+
async def _cleanup_expired_entries(self, strategy_name: str):
"""Clean up expired entries from memory cache."""
expired_keys = []
-
+
for key, cache_data in self.memory_cache.items():
if not self._is_cache_entry_valid(cache_data):
expired_keys.append(key)
-
+
for key in expired_keys:
del self.memory_cache[key]
if key in self.memory_cache_access_times:
del self.memory_cache_access_times[key]
-
+
self.metrics[strategy_name].evictions += len(expired_keys)
self.logger.debug(f"Cleaned up {len(expired_keys)} expired entries")
-
+
def get_performance_report(self) -> Dict[str, Any]:
"""Generate comprehensive performance report."""
report = {
- 'timestamp': datetime.now(timezone.utc).isoformat(),
- 'strategies': {},
- 'memory_cache': {
- 'size': len(self.memory_cache),
- 'max_size': self.memory_cache_max_size,
- 'utilization': len(self.memory_cache) / self.memory_cache_max_size
- }
+ "timestamp": datetime.now(timezone.utc).isoformat(),
+ "strategies": {},
+ "memory_cache": {
+ "size": len(self.memory_cache),
+ "max_size": self.memory_cache_max_size,
+ "utilization": len(self.memory_cache) / self.memory_cache_max_size,
+ },
}
-
+
for strategy_name, metrics in self.metrics.items():
strategy_report = {
- 'hit_rate': metrics.get_hit_rate(),
- 'total_requests': metrics.hits + metrics.misses,
- 'cache_hits': metrics.hits,
- 'cache_misses': metrics.misses,
- 'cache_sets': metrics.sets,
- 'cache_errors': metrics.errors,
- 'avg_hit_time_ms': metrics.get_avg_hit_time_ms(),
- 'avg_miss_time_ms': metrics.get_avg_miss_time_ms(),
- 'cache_size_bytes': metrics.cache_size_bytes,
- 'avg_entry_size_bytes': metrics.avg_entry_size_bytes
+ "hit_rate": metrics.get_hit_rate(),
+ "total_requests": metrics.hits + metrics.misses,
+ "cache_hits": metrics.hits,
+ "cache_misses": metrics.misses,
+ "cache_sets": metrics.sets,
+ "cache_errors": metrics.errors,
+ "avg_hit_time_ms": metrics.get_avg_hit_time_ms(),
+ "avg_miss_time_ms": metrics.get_avg_miss_time_ms(),
+ "cache_size_bytes": metrics.cache_size_bytes,
+ "avg_entry_size_bytes": metrics.avg_entry_size_bytes,
}
-
- report['strategies'][strategy_name] = strategy_report
-
+
+ report["strategies"][strategy_name] = strategy_report
+
return report
class AdvancedArcadeClient:
"""Advanced Arcade.dev client with sophisticated caching."""
-
+
def __init__(self, api_key: str, cache_manager: CacheManager):
self.api_key = api_key
self.logger = logging.getLogger(__name__)
-
+
# Setup cache strategies
self.cache_strategies = {
"default": CacheStrategy(
strategy_name="default",
ttl_seconds=3600,
max_entries=1000,
- multi_level_caching=True
+ multi_level_caching=True,
),
"fast": CacheStrategy(
strategy_name="fast",
ttl_seconds=300,
max_entries=100,
multi_level_caching=True,
- prefetch_enabled=True
+ prefetch_enabled=True,
),
"persistent": CacheStrategy(
strategy_name="persistent",
ttl_seconds=86400, # 24 hours
max_entries=5000,
compression_enabled=True,
- background_refresh_enabled=True
+ background_refresh_enabled=True,
),
"secure": CacheStrategy(
strategy_name="secure",
ttl_seconds=1800,
max_entries=500,
encryption_enabled=True,
- multi_level_caching=False
- )
+ multi_level_caching=False,
+ ),
}
-
+
# Initialize hybrid cache manager
self.hybrid_cache = HybridCacheManager(cache_manager, self.cache_strategies)
-
+
def _generate_cache_key(self, operation: str, **kwargs) -> str:
"""Generate cache key for operation."""
params_str = json.dumps(kwargs, sort_keys=True, default=str)
params_hash = hashlib.sha256(params_str.encode()).hexdigest()[:16]
return f"arcade:{operation}:{params_hash}"
-
- async def cached_operation(self, operation: str, strategy: str = "default",
- force_refresh: bool = False, **kwargs) -> Dict[str, Any]:
+
+ async def cached_operation(
+ self,
+ operation: str,
+ strategy: str = "default",
+ force_refresh: bool = False,
+ **kwargs,
+ ) -> Dict[str, Any]:
"""Execute operation with advanced caching."""
cache_key = self._generate_cache_key(operation, **kwargs)
-
+
# Check cache first
if not force_refresh:
cached_result = await self.hybrid_cache.get(cache_key, strategy)
if cached_result:
self.logger.debug(f"Cache hit for {operation} with strategy {strategy}")
return cached_result
-
+
# Execute operation
start_time = time.time()
result = await self._execute_operation(operation, **kwargs)
execution_time = (time.time() - start_time) * 1000
-
+
# Add metadata
- result['_execution_time_ms'] = execution_time
- result['_cached'] = False
- result['_timestamp'] = datetime.now(timezone.utc).isoformat()
-
+ result["_execution_time_ms"] = execution_time
+ result["_cached"] = False
+ result["_timestamp"] = datetime.now(timezone.utc).isoformat()
+
# Cache result
await self.hybrid_cache.set(cache_key, result, strategy)
-
+
return result
-
+
async def _execute_operation(self, operation: str, **kwargs) -> Dict[str, Any]:
"""Execute the actual operation (mock implementation)."""
# Simulate different operation types
operation_times = {
- 'code_analysis': 2.0,
- 'test_generation': 3.0,
- 'documentation': 1.5,
- 'refactoring': 2.5
+ "code_analysis": 2.0,
+ "test_generation": 3.0,
+ "documentation": 1.5,
+ "refactoring": 2.5,
}
-
+
await asyncio.sleep(operation_times.get(operation, 1.0))
-
+
# Generate longer, more realistic responses to meet 500 token minimum
detailed_results = {
- 'code_analysis': {
- 'summary': f"Comprehensive code analysis completed for {operation}",
- 'complexity_score': 7.5,
- 'maintainability_index': 85.2,
- 'cyclomatic_complexity': 12,
- 'lines_of_code': 342,
- 'functions_analyzed': 15,
- 'classes_analyzed': 3,
- 'issues_found': [
- {'type': 'warning', 'line': 45, 'message': 'Consider breaking down this complex function'},
- {'type': 'info', 'line': 78, 'message': 'Variable naming could be more descriptive'},
- {'type': 'suggestion', 'line': 92, 'message': 'Consider using list comprehension here'}
+ "code_analysis": {
+ "summary": f"Comprehensive code analysis completed for {operation}",
+ "complexity_score": 7.5,
+ "maintainability_index": 85.2,
+ "cyclomatic_complexity": 12,
+ "lines_of_code": 342,
+ "functions_analyzed": 15,
+ "classes_analyzed": 3,
+ "issues_found": [
+ {
+ "type": "warning",
+ "line": 45,
+ "message": "Consider breaking down this complex function",
+ },
+ {
+ "type": "info",
+ "line": 78,
+ "message": "Variable naming could be more descriptive",
+ },
+ {
+ "type": "suggestion",
+ "line": 92,
+ "message": "Consider using list comprehension here",
+ },
],
- 'suggestions': [
- 'Implement proper error handling in the main processing loop',
- 'Add type hints to improve code readability and IDE support',
- 'Consider extracting utility functions to reduce code duplication',
- 'Implement comprehensive logging for better debugging capabilities'
+ "suggestions": [
+ "Implement proper error handling in the main processing loop",
+ "Add type hints to improve code readability and IDE support",
+ "Consider extracting utility functions to reduce code duplication",
+ "Implement comprehensive logging for better debugging capabilities",
],
- 'metrics': {
- 'code_coverage': 87.5,
- 'test_coverage': 92.1,
- 'documentation_coverage': 78.3,
- 'performance_score': 8.7
+ "metrics": {
+ "code_coverage": 87.5,
+ "test_coverage": 92.1,
+ "documentation_coverage": 78.3,
+ "performance_score": 8.7,
+ },
+ "dependencies": [
+ "requests",
+ "asyncio",
+ "json",
+ "pathlib",
+ "dataclasses",
+ ],
+ "security_analysis": {
+ "vulnerabilities_found": 0,
+ "security_score": 9.2,
+ "recommendations": [
+ "Update dependencies to latest versions",
+ "Implement input validation",
+ ],
},
- 'dependencies': ['requests', 'asyncio', 'json', 'pathlib', 'dataclasses'],
- 'security_analysis': {
- 'vulnerabilities_found': 0,
- 'security_score': 9.2,
- 'recommendations': ['Update dependencies to latest versions', 'Implement input validation']
- }
},
- 'test_generation': {
- 'summary': f"Automated test generation completed for {operation}",
- 'tests_generated': 23,
- 'coverage_improvement': 15.7,
- 'test_types': ['unit', 'integration', 'edge_cases'],
- 'generated_tests': [
+ "test_generation": {
+ "summary": f"Automated test generation completed for {operation}",
+ "tests_generated": 23,
+ "coverage_improvement": 15.7,
+ "test_types": ["unit", "integration", "edge_cases"],
+ "generated_tests": [
{
- 'name': 'test_basic_functionality',
- 'type': 'unit',
- 'description': 'Tests basic function behavior with valid inputs',
- 'expected_coverage': 85.2
+ "name": "test_basic_functionality",
+ "type": "unit",
+ "description": "Tests basic function behavior with valid inputs",
+ "expected_coverage": 85.2,
},
{
- 'name': 'test_edge_cases',
- 'type': 'edge_case',
- 'description': 'Tests function behavior with boundary conditions and edge cases',
- 'expected_coverage': 92.8
+ "name": "test_edge_cases",
+ "type": "edge_case",
+ "description": "Tests function behavior with boundary conditions and edge cases",
+ "expected_coverage": 92.8,
},
{
- 'name': 'test_error_handling',
- 'type': 'error',
- 'description': 'Tests proper error handling and exception management',
- 'expected_coverage': 78.5
- }
+ "name": "test_error_handling",
+ "type": "error",
+ "description": "Tests proper error handling and exception management",
+ "expected_coverage": 78.5,
+ },
],
- 'recommendations': [
- 'Add property-based tests for more comprehensive coverage',
- 'Include performance benchmarks for critical functions',
- 'Implement integration tests for external dependencies',
- 'Add mock tests for API interactions'
+ "recommendations": [
+ "Add property-based tests for more comprehensive coverage",
+ "Include performance benchmarks for critical functions",
+ "Implement integration tests for external dependencies",
+ "Add mock tests for API interactions",
],
- 'quality_metrics': {
- 'test_readability': 9.1,
- 'maintainability': 8.8,
- 'execution_speed': 7.9,
- 'reliability': 9.3
- }
+ "quality_metrics": {
+ "test_readability": 9.1,
+ "maintainability": 8.8,
+ "execution_speed": 7.9,
+ "reliability": 9.3,
+ },
},
- 'documentation': {
- 'summary': f"Documentation generation completed for {operation}",
- 'pages_generated': 12,
- 'sections_created': ['API Reference', 'User Guide', 'Examples', 'FAQ'],
- 'documentation_structure': {
- 'getting_started': 'Complete setup and installation guide',
- 'api_reference': 'Detailed API documentation with examples',
- 'tutorials': 'Step-by-step tutorials for common use cases',
- 'troubleshooting': 'Common issues and their solutions'
+ "documentation": {
+ "summary": f"Documentation generation completed for {operation}",
+ "pages_generated": 12,
+ "sections_created": ["API Reference", "User Guide", "Examples", "FAQ"],
+ "documentation_structure": {
+ "getting_started": "Complete setup and installation guide",
+ "api_reference": "Detailed API documentation with examples",
+ "tutorials": "Step-by-step tutorials for common use cases",
+ "troubleshooting": "Common issues and their solutions",
},
- 'content_analysis': {
- 'readability_score': 8.7,
- 'completeness': 91.2,
- 'accuracy': 94.8,
- 'usefulness': 8.9
+ "content_analysis": {
+ "readability_score": 8.7,
+ "completeness": 91.2,
+ "accuracy": 94.8,
+ "usefulness": 8.9,
},
- 'generated_content': [
- 'Function docstrings with parameter descriptions and return types',
- 'Class documentation with usage examples',
- 'Module-level documentation explaining purpose and structure',
- 'README sections with installation and basic usage instructions'
+ "generated_content": [
+ "Function docstrings with parameter descriptions and return types",
+ "Class documentation with usage examples",
+ "Module-level documentation explaining purpose and structure",
+ "README sections with installation and basic usage instructions",
+ ],
+ "improvements_suggested": [
+ "Add more code examples to illustrate complex concepts",
+ "Include visual diagrams for architecture overview",
+ "Expand troubleshooting section with common error scenarios",
+ "Add FAQ section based on user feedback",
],
- 'improvements_suggested': [
- 'Add more code examples to illustrate complex concepts',
- 'Include visual diagrams for architecture overview',
- 'Expand troubleshooting section with common error scenarios',
- 'Add FAQ section based on user feedback'
- ]
},
- 'refactoring': {
- 'summary': f"Code refactoring analysis completed for {operation}",
- 'refactoring_opportunities': 18,
- 'complexity_reduction': 23.5,
- 'performance_improvement': 12.8,
- 'suggested_changes': [
+ "refactoring": {
+ "summary": f"Code refactoring analysis completed for {operation}",
+ "refactoring_opportunities": 18,
+ "complexity_reduction": 23.5,
+ "performance_improvement": 12.8,
+ "suggested_changes": [
{
- 'type': 'extract_method',
- 'location': 'lines 45-67',
- 'description': 'Extract complex logic into separate method for better readability',
- 'impact': 'high'
+ "type": "extract_method",
+ "location": "lines 45-67",
+ "description": "Extract complex logic into separate method for better readability",
+ "impact": "high",
},
{
- 'type': 'remove_duplication',
- 'location': 'multiple locations',
- 'description': 'Consolidate duplicate code patterns into reusable functions',
- 'impact': 'medium'
+ "type": "remove_duplication",
+ "location": "multiple locations",
+ "description": "Consolidate duplicate code patterns into reusable functions",
+ "impact": "medium",
},
{
- 'type': 'simplify_conditionals',
- 'location': 'lines 89-102',
- 'description': 'Simplify nested conditional statements using early returns',
- 'impact': 'medium'
- }
+ "type": "simplify_conditionals",
+ "location": "lines 89-102",
+ "description": "Simplify nested conditional statements using early returns",
+ "impact": "medium",
+ },
],
- 'code_quality_improvements': {
- 'maintainability_score_increase': 15.3,
- 'readability_improvement': 22.7,
- 'testability_enhancement': 18.9,
- 'performance_optimization': 8.4
+ "code_quality_improvements": {
+ "maintainability_score_increase": 15.3,
+ "readability_improvement": 22.7,
+ "testability_enhancement": 18.9,
+ "performance_optimization": 8.4,
},
- 'design_patterns_suggested': [
- 'Strategy pattern for algorithm selection',
- 'Factory pattern for object creation',
- 'Observer pattern for event handling'
+ "design_patterns_suggested": [
+ "Strategy pattern for algorithm selection",
+ "Factory pattern for object creation",
+ "Observer pattern for event handling",
],
- 'best_practices_recommendations': [
- 'Implement proper dependency injection',
- 'Use configuration files for magic numbers',
- 'Add comprehensive error handling',
- 'Implement proper logging throughout the application'
- ]
- }
+ "best_practices_recommendations": [
+ "Implement proper dependency injection",
+ "Use configuration files for magic numbers",
+ "Add comprehensive error handling",
+ "Implement proper logging throughout the application",
+ ],
+ },
}
-
+
# Add padding to ensure we meet 500 token minimum for caching
- padding_text = " ".join([
- "This is additional padding content to ensure the response meets the minimum token requirement for caching.",
- "The cache system requires at least 500 tokens to store responses effectively.",
- "This padding includes various technical details and explanations about the operation.",
- "Performance metrics indicate optimal execution patterns with efficient resource utilization.",
- "The system maintains high availability and reliability standards throughout the operation lifecycle.",
- "Comprehensive monitoring and logging capabilities provide detailed insights into system behavior.",
- "Advanced error handling mechanisms ensure robust operation under various conditions.",
- "Security protocols maintain data integrity and prevent unauthorized access attempts.",
- "Scalability features allow the system to handle increased load and concurrent operations.",
- "Integration capabilities enable seamless interaction with external systems and APIs.",
- "Configuration management provides flexible deployment options across different environments.",
- "Testing frameworks ensure quality assurance and prevent regression issues.",
- "Documentation standards maintain clear and comprehensive technical specifications.",
- "Version control systems track changes and enable collaborative development workflows.",
- "Deployment pipelines automate the release process and ensure consistent deployments.",
- "Backup and recovery procedures protect against data loss and system failures.",
- "Performance optimization techniques improve response times and resource efficiency.",
- "User experience considerations guide interface design and interaction patterns.",
- "Accessibility standards ensure inclusive design for all user groups.",
- "Internationalization features support multiple languages and regional preferences."
- ])
-
+ padding_text = " ".join(
+ [
+ "This is additional padding content to ensure the response meets the minimum token requirement for caching.",
+ "The cache system requires at least 500 tokens to store responses effectively.",
+ "This padding includes various technical details and explanations about the operation.",
+ "Performance metrics indicate optimal execution patterns with efficient resource utilization.",
+ "The system maintains high availability and reliability standards throughout the operation lifecycle.",
+ "Comprehensive monitoring and logging capabilities provide detailed insights into system behavior.",
+ "Advanced error handling mechanisms ensure robust operation under various conditions.",
+ "Security protocols maintain data integrity and prevent unauthorized access attempts.",
+ "Scalability features allow the system to handle increased load and concurrent operations.",
+ "Integration capabilities enable seamless interaction with external systems and APIs.",
+ "Configuration management provides flexible deployment options across different environments.",
+ "Testing frameworks ensure quality assurance and prevent regression issues.",
+ "Documentation standards maintain clear and comprehensive technical specifications.",
+ "Version control systems track changes and enable collaborative development workflows.",
+ "Deployment pipelines automate the release process and ensure consistent deployments.",
+ "Backup and recovery procedures protect against data loss and system failures.",
+ "Performance optimization techniques improve response times and resource efficiency.",
+ "User experience considerations guide interface design and interaction patterns.",
+ "Accessibility standards ensure inclusive design for all user groups.",
+ "Internationalization features support multiple languages and regional preferences.",
+ ]
+ )
+
return {
- 'operation': operation,
- 'status': 'success',
- 'result': detailed_results.get(operation, f"Detailed result for {operation}"),
- 'parameters': kwargs,
- 'execution_metadata': {
- 'timestamp': datetime.now(timezone.utc).isoformat(),
- 'execution_id': hashlib.sha256(f"{operation}_{time.time()}".encode()).hexdigest()[:16],
- 'processing_time_seconds': operation_times.get(operation, 1.0),
- 'resource_usage': {
- 'memory_mb': 45.7,
- 'cpu_percent': 23.4
- }
+ "operation": operation,
+ "status": "success",
+ "result": detailed_results.get(
+ operation, f"Detailed result for {operation}"
+ ),
+ "parameters": kwargs,
+ "execution_metadata": {
+ "timestamp": datetime.now(timezone.utc).isoformat(),
+ "execution_id": hashlib.sha256(
+ f"{operation}_{time.time()}".encode()
+ ).hexdigest()[:16],
+ "processing_time_seconds": operation_times.get(operation, 1.0),
+ "resource_usage": {"memory_mb": 45.7, "cpu_percent": 23.4},
+ },
+ "padding_for_cache_minimum": padding_text,
+ "technical_notes": {
+ "cache_strategy": "Hybrid caching with multi-tier storage optimization",
+ "performance_profile": "Optimized for high-throughput and low-latency operations",
+ "reliability_measures": "Comprehensive error handling and fallback mechanisms",
+ "security_features": "End-to-end encryption and secure authentication protocols",
+ "monitoring_capabilities": "Real-time metrics collection and alerting systems",
},
- 'padding_for_cache_minimum': padding_text,
- 'technical_notes': {
- 'cache_strategy': 'Hybrid caching with multi-tier storage optimization',
- 'performance_profile': 'Optimized for high-throughput and low-latency operations',
- 'reliability_measures': 'Comprehensive error handling and fallback mechanisms',
- 'security_features': 'End-to-end encryption and secure authentication protocols',
- 'monitoring_capabilities': 'Real-time metrics collection and alerting systems'
- }
}
-
+
async def optimize_caching(self):
"""Optimize all cache strategies."""
for strategy_name in self.cache_strategies.keys():
await self.hybrid_cache.optimize_cache(strategy_name)
-
+
def get_cache_analytics(self) -> Dict[str, Any]:
"""Get comprehensive cache analytics."""
return self.hybrid_cache.get_performance_report()
@@ -729,24 +794,28 @@ async def demonstrate_advanced_caching():
"""Demonstrate advanced caching capabilities."""
print("š Advanced Cache Integration Demo")
print("=" * 50)
-
+
# Initialize cache manager with proper configuration
cache_config = get_default_cache_config()
cache_manager = CacheManager(cache_config)
-
+
# Create advanced client
client = AdvancedArcadeClient("demo_api_key", cache_manager)
-
+
print("\nš Testing different cache strategies...")
-
+
# Test operations with different strategies
operations = [
- ("code_analysis", "fast", {"code": "def hello(): print('world')", "language": "python"}),
+ (
+ "code_analysis",
+ "fast",
+ {"code": "def hello(): print('world')", "language": "python"},
+ ),
("test_generation", "default", {"code": "def add(a, b): return a + b"}),
("documentation", "persistent", {"code": "class MyClass: pass"}),
- ("refactoring", "secure", {"code": "legacy_code_here"})
+ ("refactoring", "secure", {"code": "legacy_code_here"}),
]
-
+
# First run (cache miss)
print("\nš First execution (cache miss):")
for operation, strategy, params in operations:
@@ -754,7 +823,7 @@ async def demonstrate_advanced_caching():
result = await client.cached_operation(operation, strategy, **params)
duration = (time.time() - start_time) * 1000
print(f" {operation} ({strategy}): {duration:.1f}ms")
-
+
# Second run (cache hit)
print("\nā” Second execution (cache hit):")
for operation, strategy, params in operations:
@@ -762,27 +831,27 @@ async def demonstrate_advanced_caching():
result = await client.cached_operation(operation, strategy, **params)
duration = (time.time() - start_time) * 1000
print(f" {operation} ({strategy}): {duration:.1f}ms")
-
+
# Cache optimization
print("\nš§ Optimizing cache performance...")
await client.optimize_caching()
-
+
# Analytics
print("\nš Cache Performance Analytics:")
analytics = client.get_cache_analytics()
-
- for strategy_name, metrics in analytics['strategies'].items():
+
+ for strategy_name, metrics in analytics["strategies"].items():
print(f"\n {strategy_name.upper()} Strategy:")
print(f" Hit Rate: {metrics['hit_rate']:.1%}")
print(f" Total Requests: {metrics['total_requests']}")
print(f" Avg Hit Time: {metrics['avg_hit_time_ms']:.1f}ms")
print(f" Cache Size: {metrics['cache_size_bytes']} bytes")
-
+
print(f"\n Memory Cache:")
- memory_stats = analytics['memory_cache']
+ memory_stats = analytics["memory_cache"]
print(f" Utilization: {memory_stats['utilization']:.1%}")
print(f" Size: {memory_stats['size']}/{memory_stats['max_size']}")
-
+
print("\nš Advanced caching demonstration completed!")
@@ -790,9 +859,9 @@ async def main():
"""Main demonstration function."""
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
try:
await demonstrate_advanced_caching()
return 0
@@ -803,4 +872,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/05_cache_integration/cached_arcade_client_enhanced.py b/examples/arcade-dev/05_cache_integration/cached_arcade_client_enhanced.py
index f40d517..592d3fe 100644
--- a/examples/arcade-dev/05_cache_integration/cached_arcade_client_enhanced.py
+++ b/examples/arcade-dev/05_cache_integration/cached_arcade_client_enhanced.py
@@ -46,16 +46,17 @@
@dataclass
class CacheStrategy:
"""Cache strategy configuration."""
+
strategy_name: str
ttl_seconds: int
max_entries: int = 1000
compression_enabled: bool = False
encryption_enabled: bool = False
-
+
# Performance thresholds
hit_rate_threshold: float = 0.8
response_time_threshold_ms: float = 100
-
+
# Advanced features
prefetch_enabled: bool = False
background_refresh_enabled: bool = False
@@ -64,54 +65,57 @@ class CacheStrategy:
@dataclass
class CachePerformanceMetrics:
"""Cache performance tracking."""
+
hits: int = 0
misses: int = 0
sets: int = 0
errors: int = 0
-
+
# Timing metrics
total_hit_time_ms: float = 0
total_miss_time_ms: float = 0
total_set_time_ms: float = 0
-
+
# Size metrics
cache_size_bytes: int = 0
-
+
def get_hit_rate(self) -> float:
total_requests = self.hits + self.misses
return (self.hits / total_requests) if total_requests > 0 else 0.0
-
+
def get_avg_hit_time_ms(self) -> float:
return (self.total_hit_time_ms / self.hits) if self.hits > 0 else 0.0
-
+
def get_avg_miss_time_ms(self) -> float:
return (self.total_miss_time_ms / self.misses) if self.misses > 0 else 0.0
class EnhancedCacheManager:
"""Enhanced cache manager with response padding for successful caching."""
-
- def __init__(self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]):
+
+ def __init__(
+ self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]
+ ):
self.primary_cache = primary_cache
self.strategies = strategies
self.logger = logging.getLogger(f"{__name__}.EnhancedCacheManager")
-
+
# Performance tracking
self.metrics: Dict[str, CachePerformanceMetrics] = {
strategy_name: CachePerformanceMetrics()
for strategy_name in strategies.keys()
}
-
+
def get(self, key: str, strategy_name: str = "default") -> Optional[Any]:
"""Get value with caching strategy."""
start_time = time.time()
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Get from primary cache
cache_entry = self.primary_cache.get(key)
-
+
if cache_entry:
# Cache hit
metrics.hits += 1
@@ -124,172 +128,175 @@ def get(self, key: str, strategy_name: str = "default") -> Optional[Any]:
metrics.total_miss_time_ms += (time.time() - start_time) * 1000
self.logger.debug(f"Cache miss for key: {key}")
return None
-
+
except Exception as e:
metrics.errors += 1
self.logger.error(f"Cache get error for key {key}: {e}")
return None
-
+
def set(self, key: str, value: Any, strategy_name: str = "default") -> bool:
"""Set value with caching strategy and response padding."""
start_time = time.time()
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Convert value to string for padding if it's a dict/object
if isinstance(value, dict):
value_str = json.dumps(value, indent=2, default=str)
else:
value_str = str(value)
-
+
# Apply response padding to meet cache requirements
content_type = self._determine_content_type(value)
enhanced_content = pad_response_for_caching(
content=value_str,
content_type=content_type,
target_tokens=self.primary_cache.min_tokens,
- preserve_format=True
+ preserve_format=True,
)
-
+
# Store enhanced content in primary cache
cache_entry = self.primary_cache.store(key, enhanced_content)
-
+
if cache_entry:
metrics.sets += 1
metrics.total_set_time_ms += (time.time() - start_time) * 1000
metrics.cache_size_bytes += cache_entry.size_bytes
-
- self.logger.info(f"Cache set successful for key: {key}, tokens: {cache_entry.token_count}")
+
+ self.logger.info(
+ f"Cache set successful for key: {key}, tokens: {cache_entry.token_count}"
+ )
return True
else:
metrics.errors += 1
return False
-
+
except Exception as e:
metrics.errors += 1
self.logger.error(f"Cache set error for key {key}: {e}")
return False
-
+
def _determine_content_type(self, value: Any) -> str:
"""Determine content type for appropriate padding."""
if isinstance(value, dict):
- if 'operation' in value:
- operation = value.get('operation', '')
- if 'sql' in operation.lower():
- return 'sql'
- elif 'api' in operation.lower():
- return 'json'
- return 'json'
+ if "operation" in value:
+ operation = value.get("operation", "")
+ if "sql" in operation.lower():
+ return "sql"
+ elif "api" in operation.lower():
+ return "json"
+ return "json"
elif isinstance(value, str):
- if 'SELECT' in value.upper() or 'INSERT' in value.upper():
- return 'sql'
- elif value.strip().startswith('{') or value.strip().startswith('['):
- return 'json'
- return 'generic'
-
+ if "SELECT" in value.upper() or "INSERT" in value.upper():
+ return "sql"
+ elif value.strip().startswith("{") or value.strip().startswith("["):
+ return "json"
+ return "generic"
+
def optimize_cache(self, strategy_name: str = "default"):
"""Optimize cache performance based on usage patterns."""
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Check if optimization is needed
hit_rate = metrics.get_hit_rate()
-
+
if hit_rate < strategy.hit_rate_threshold:
- self.logger.warning(f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})")
+ self.logger.warning(
+ f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})"
+ )
else:
- self.logger.info(f"Cache hit rate ({hit_rate:.2%}) meets threshold ({strategy.hit_rate_threshold:.2%})")
-
- self.logger.info(f"Cache optimization completed for strategy: {strategy_name}")
-
+ self.logger.info(
+ f"Cache hit rate ({hit_rate:.2%}) meets threshold ({strategy.hit_rate_threshold:.2%})"
+ )
+
+ self.logger.info(
+ f"Cache optimization completed for strategy: {strategy_name}"
+ )
+
except Exception as e:
self.logger.error(f"Cache optimization error: {e}")
-
+
def get_performance_report(self) -> Dict[str, Any]:
"""Generate comprehensive performance report."""
- report = {
- 'timestamp': datetime.now(timezone.utc).isoformat(),
- 'strategies': {}
- }
-
+ report = {"timestamp": datetime.now(timezone.utc).isoformat(), "strategies": {}}
+
for strategy_name, metrics in self.metrics.items():
strategy_report = {
- 'hit_rate': metrics.get_hit_rate(),
- 'total_requests': metrics.hits + metrics.misses,
- 'cache_hits': metrics.hits,
- 'cache_misses': metrics.misses,
- 'cache_sets': metrics.sets,
- 'cache_errors': metrics.errors,
- 'avg_hit_time_ms': metrics.get_avg_hit_time_ms(),
- 'avg_miss_time_ms': metrics.get_avg_miss_time_ms(),
- 'cache_size_bytes': metrics.cache_size_bytes
+ "hit_rate": metrics.get_hit_rate(),
+ "total_requests": metrics.hits + metrics.misses,
+ "cache_hits": metrics.hits,
+ "cache_misses": metrics.misses,
+ "cache_sets": metrics.sets,
+ "cache_errors": metrics.errors,
+ "avg_hit_time_ms": metrics.get_avg_hit_time_ms(),
+ "avg_miss_time_ms": metrics.get_avg_miss_time_ms(),
+ "cache_size_bytes": metrics.cache_size_bytes,
}
-
- report['strategies'][strategy_name] = strategy_report
-
+
+ report["strategies"][strategy_name] = strategy_report
+
return report
class AdvancedArcadeClient:
"""Advanced Arcade.dev client with sophisticated caching and response padding."""
-
+
def __init__(self, api_key: str, cache_manager: CacheManager):
self.api_key = api_key
self.logger = logging.getLogger(__name__)
-
+
# Detect demo mode based on API key
self.demo_mode = (
- not api_key or
- api_key in ["demo_api_key", "test-key-for-testing"] or
- api_key.startswith("demo_") or
- api_key.startswith("test_")
+ not api_key
+ or api_key in ["demo_api_key", "test-key-for-testing"]
+ or api_key.startswith("demo_")
+ or api_key.startswith("test_")
)
-
+
if self.demo_mode:
self.logger.info("Running in DEMO MODE - no real API calls will be made")
else:
self.logger.info("Running with real API credentials")
-
+
# Setup cache strategies
self.cache_strategies = {
- "default": CacheStrategy(
- strategy_name="default",
- ttl_seconds=3600
- ),
+ "default": CacheStrategy(strategy_name="default", ttl_seconds=3600),
"fast": CacheStrategy(
- strategy_name="fast",
- ttl_seconds=300,
- prefetch_enabled=True
+ strategy_name="fast", ttl_seconds=300, prefetch_enabled=True
),
"persistent": CacheStrategy(
strategy_name="persistent",
ttl_seconds=86400, # 24 hours
compression_enabled=True,
- background_refresh_enabled=True
+ background_refresh_enabled=True,
),
"secure": CacheStrategy(
- strategy_name="secure",
- ttl_seconds=1800,
- encryption_enabled=True
- )
+ strategy_name="secure", ttl_seconds=1800, encryption_enabled=True
+ ),
}
-
+
# Initialize enhanced cache manager
self.enhanced_cache = EnhancedCacheManager(cache_manager, self.cache_strategies)
-
+
def _generate_cache_key(self, operation: str, **kwargs) -> str:
"""Generate cache key for operation."""
params_str = json.dumps(kwargs, sort_keys=True, default=str)
params_hash = hashlib.sha256(params_str.encode()).hexdigest()[:16]
return f"arcade:{operation}:{params_hash}"
-
- async def cached_operation(self, operation: str, strategy: str = "default",
- force_refresh: bool = False, **kwargs) -> Dict[str, Any]:
+
+ async def cached_operation(
+ self,
+ operation: str,
+ strategy: str = "default",
+ force_refresh: bool = False,
+ **kwargs,
+ ) -> Dict[str, Any]:
"""Execute operation with advanced caching and response padding."""
cache_key = self._generate_cache_key(operation, **kwargs)
-
+
# Check cache first
if not force_refresh:
cached_result = self.enhanced_cache.get(cache_key, strategy)
@@ -299,152 +306,201 @@ async def cached_operation(self, operation: str, strategy: str = "default",
# Extract original result from the enhanced content
try:
# Look for the original JSON in the enhanced content
- lines = cached_result.split('\n')
+ lines = cached_result.split("\n")
for line in lines:
- if line.strip().startswith('{') and 'operation' in line:
+ if line.strip().startswith("{") and "operation" in line:
original_result = json.loads(line.strip())
- original_result['_cached'] = True
- original_result['_enhanced_content'] = True
+ original_result["_cached"] = True
+ original_result["_enhanced_content"] = True
return original_result
except:
pass
-
+
# Fallback: return a structured response indicating cache hit
return {
- 'operation': operation,
- 'status': 'success',
- 'result': f"Cached result for {operation}",
- 'parameters': kwargs,
- '_cached': True,
- '_enhanced_content': True,
- '_timestamp': datetime.now(timezone.utc).isoformat()
+ "operation": operation,
+ "status": "success",
+ "result": f"Cached result for {operation}",
+ "parameters": kwargs,
+ "_cached": True,
+ "_enhanced_content": True,
+ "_timestamp": datetime.now(timezone.utc).isoformat(),
}
-
+
# Execute operation
start_time = time.time()
result = await self._execute_operation(operation, **kwargs)
execution_time = (time.time() - start_time) * 1000
-
+
# Add metadata
- result['_execution_time_ms'] = execution_time
- result['_cached'] = False
- result['_timestamp'] = datetime.now(timezone.utc).isoformat()
-
+ result["_execution_time_ms"] = execution_time
+ result["_cached"] = False
+ result["_timestamp"] = datetime.now(timezone.utc).isoformat()
+
# Cache result with padding
self.enhanced_cache.set(cache_key, result, strategy)
-
+
return result
-
+
async def _execute_operation(self, operation: str, **kwargs) -> Dict[str, Any]:
"""Execute the actual operation (mock implementation for demo)."""
if self.demo_mode:
# Simulate different operation types with realistic delays
operation_times = {
- 'code_analysis': 2.0,
- 'test_generation': 3.0,
- 'documentation': 1.5,
- 'refactoring': 2.5
+ "code_analysis": 2.0,
+ "test_generation": 3.0,
+ "documentation": 1.5,
+ "refactoring": 2.5,
}
-
+
await asyncio.sleep(operation_times.get(operation, 1.0))
-
+
# Generate comprehensive demo results
demo_results = {
- 'code_analysis': {
- 'summary': f"Comprehensive analysis of {kwargs.get('language', 'code')} code",
- 'metrics': {
- 'lines_of_code': len(kwargs.get('code', '').split('\n')),
- 'complexity_score': 7.2,
- 'maintainability_index': 85,
- 'technical_debt_ratio': 0.15
+ "code_analysis": {
+ "summary": f"Comprehensive analysis of {kwargs.get('language', 'code')} code",
+ "metrics": {
+ "lines_of_code": len(kwargs.get("code", "").split("\n")),
+ "complexity_score": 7.2,
+ "maintainability_index": 85,
+ "technical_debt_ratio": 0.15,
},
- 'issues': [
- {'type': 'warning', 'message': 'Consider extracting method for better readability'},
- {'type': 'info', 'message': 'Add type hints for better code documentation'},
- {'type': 'suggestion', 'message': 'Use f-strings for string formatting'}
+ "issues": [
+ {
+ "type": "warning",
+ "message": "Consider extracting method for better readability",
+ },
+ {
+ "type": "info",
+ "message": "Add type hints for better code documentation",
+ },
+ {
+ "type": "suggestion",
+ "message": "Use f-strings for string formatting",
+ },
+ ],
+ "recommendations": [
+ "Implement unit tests for critical functions",
+ "Add docstrings to public methods",
+ "Consider using design patterns for scalability",
],
- 'recommendations': [
- 'Implement unit tests for critical functions',
- 'Add docstrings to public methods',
- 'Consider using design patterns for scalability'
- ]
},
- 'test_generation': {
- 'summary': f"Generated comprehensive test suite",
- 'test_cases': [
- {'name': 'test_basic_functionality', 'type': 'unit', 'coverage': '95%'},
- {'name': 'test_edge_cases', 'type': 'unit', 'coverage': '88%'},
- {'name': 'test_error_handling', 'type': 'unit', 'coverage': '92%'},
- {'name': 'test_integration', 'type': 'integration', 'coverage': '87%'},
- {'name': 'test_performance', 'type': 'performance', 'coverage': '90%'}
+ "test_generation": {
+ "summary": f"Generated comprehensive test suite",
+ "test_cases": [
+ {
+ "name": "test_basic_functionality",
+ "type": "unit",
+ "coverage": "95%",
+ },
+ {"name": "test_edge_cases", "type": "unit", "coverage": "88%"},
+ {
+ "name": "test_error_handling",
+ "type": "unit",
+ "coverage": "92%",
+ },
+ {
+ "name": "test_integration",
+ "type": "integration",
+ "coverage": "87%",
+ },
+ {
+ "name": "test_performance",
+ "type": "performance",
+ "coverage": "90%",
+ },
],
- 'coverage_metrics': {
- 'line_coverage': '91%',
- 'branch_coverage': '88%',
- 'function_coverage': '100%'
+ "coverage_metrics": {
+ "line_coverage": "91%",
+ "branch_coverage": "88%",
+ "function_coverage": "100%",
},
- 'frameworks': ['pytest', 'unittest', 'hypothesis']
+ "frameworks": ["pytest", "unittest", "hypothesis"],
},
- 'documentation': {
- 'summary': f"Generated technical documentation",
- 'sections': [
- {'title': 'API Reference', 'pages': 12, 'status': 'complete'},
- {'title': 'Usage Examples', 'pages': 8, 'status': 'complete'},
- {'title': 'Architecture Overview', 'pages': 15, 'status': 'complete'},
- {'title': 'Troubleshooting Guide', 'pages': 6, 'status': 'complete'}
+ "documentation": {
+ "summary": f"Generated technical documentation",
+ "sections": [
+ {"title": "API Reference", "pages": 12, "status": "complete"},
+ {"title": "Usage Examples", "pages": 8, "status": "complete"},
+ {
+ "title": "Architecture Overview",
+ "pages": 15,
+ "status": "complete",
+ },
+ {
+ "title": "Troubleshooting Guide",
+ "pages": 6,
+ "status": "complete",
+ },
],
- 'formats': ['markdown', 'html', 'pdf'],
- 'word_count': 12500,
- 'diagrams_generated': 8
+ "formats": ["markdown", "html", "pdf"],
+ "word_count": 12500,
+ "diagrams_generated": 8,
},
- 'refactoring': {
- 'summary': f"Comprehensive code refactoring analysis",
- 'improvements': [
- {'type': 'Extract Method', 'impact': 'high', 'effort': 'medium'},
- {'type': 'Reduce Complexity', 'impact': 'high', 'effort': 'high'},
- {'type': 'Improve Naming', 'impact': 'medium', 'effort': 'low'},
- {'type': 'Add Type Hints', 'impact': 'medium', 'effort': 'medium'},
- {'type': 'Optimize Performance', 'impact': 'high', 'effort': 'high'}
+ "refactoring": {
+ "summary": f"Comprehensive code refactoring analysis",
+ "improvements": [
+ {
+ "type": "Extract Method",
+ "impact": "high",
+ "effort": "medium",
+ },
+ {
+ "type": "Reduce Complexity",
+ "impact": "high",
+ "effort": "high",
+ },
+ {"type": "Improve Naming", "impact": "medium", "effort": "low"},
+ {
+ "type": "Add Type Hints",
+ "impact": "medium",
+ "effort": "medium",
+ },
+ {
+ "type": "Optimize Performance",
+ "impact": "high",
+ "effort": "high",
+ },
],
- 'metrics_before': {
- 'cyclomatic_complexity': 12,
- 'maintainability_index': 65,
- 'lines_of_code': 450
+ "metrics_before": {
+ "cyclomatic_complexity": 12,
+ "maintainability_index": 65,
+ "lines_of_code": 450,
},
- 'metrics_after': {
- 'cyclomatic_complexity': 7,
- 'maintainability_index': 85,
- 'lines_of_code': 380
+ "metrics_after": {
+ "cyclomatic_complexity": 7,
+ "maintainability_index": 85,
+ "lines_of_code": 380,
},
- 'estimated_time_saved': '15 hours/month'
- }
+ "estimated_time_saved": "15 hours/month",
+ },
}
-
+
return {
- 'operation': operation,
- 'status': 'success',
- 'result': demo_results.get(operation, f"Demo result for {operation}"),
- 'parameters': kwargs,
- 'demo_mode': True
+ "operation": operation,
+ "status": "success",
+ "result": demo_results.get(operation, f"Demo result for {operation}"),
+ "parameters": kwargs,
+ "demo_mode": True,
}
else:
# In live mode, this would make actual API calls to Arcade.dev
self.logger.warning("Live API mode not yet implemented - using simulation")
await asyncio.sleep(1.0) # Faster for live mode
-
+
return {
- 'operation': operation,
- 'status': 'simulated',
- 'result': f"Simulated {operation} result (would be real API call)",
- 'parameters': kwargs,
- 'demo_mode': False
+ "operation": operation,
+ "status": "simulated",
+ "result": f"Simulated {operation} result (would be real API call)",
+ "parameters": kwargs,
+ "demo_mode": False,
}
-
+
def optimize_caching(self):
"""Optimize all cache strategies."""
for strategy_name in self.cache_strategies.keys():
self.enhanced_cache.optimize_cache(strategy_name)
-
+
def get_cache_analytics(self) -> Dict[str, Any]:
"""Get comprehensive cache analytics."""
return self.enhanced_cache.get_performance_report()
@@ -454,32 +510,36 @@ async def demonstrate_enhanced_caching():
"""Demonstrate enhanced caching capabilities with response padding."""
print("š Enhanced Cache Integration Demo with Response Padding")
print("=" * 70)
-
+
# Initialize cache manager with proper configuration
cache_config = get_default_cache_config()
cache_manager = CacheManager(cache_config)
-
+
# Get API key from environment or use demo mode
api_key = os.getenv("ARCADE_API_KEY", "demo_api_key")
-
+
# Create advanced client
client = AdvancedArcadeClient(api_key, cache_manager)
-
+
# Show mode status
mode_status = "š® DEMO MODE" if client.demo_mode else "š LIVE MODE"
print(f"\n{mode_status} - Using API key: {api_key[:10]}...")
print(f"š¾ Cache minimum tokens: {cache_manager.min_tokens}")
-
+
print("\nš Testing different cache strategies with response padding...")
-
+
# Test operations with different strategies
operations = [
- ("code_analysis", "fast", {"code": "def hello(): print('world')", "language": "python"}),
+ (
+ "code_analysis",
+ "fast",
+ {"code": "def hello(): print('world')", "language": "python"},
+ ),
("test_generation", "default", {"code": "def add(a, b): return a + b"}),
("documentation", "persistent", {"code": "class MyClass: pass"}),
- ("refactoring", "secure", {"code": "legacy_code_here"})
+ ("refactoring", "secure", {"code": "legacy_code_here"}),
]
-
+
# First run (cache miss)
print("\nš First execution (cache miss with response padding):")
for operation, strategy, params in operations:
@@ -487,26 +547,28 @@ async def demonstrate_enhanced_caching():
result = await client.cached_operation(operation, strategy, **params)
duration = (time.time() - start_time) * 1000
print(f" {operation} ({strategy}): {duration:.1f}ms")
-
+
# Second run (cache hit)
print("\nā” Second execution (cache hit):")
for operation, strategy, params in operations:
start_time = time.time()
result = await client.cached_operation(operation, strategy, **params)
duration = (time.time() - start_time) * 1000
- cached_status = "HIT" if result.get('_cached') else "MISS"
- enhanced_status = "ENHANCED" if result.get('_enhanced_content') else "ORIGINAL"
- print(f" {operation} ({strategy}): {duration:.1f}ms [{cached_status}] [{enhanced_status}]")
-
+ cached_status = "HIT" if result.get("_cached") else "MISS"
+ enhanced_status = "ENHANCED" if result.get("_enhanced_content") else "ORIGINAL"
+ print(
+ f" {operation} ({strategy}): {duration:.1f}ms [{cached_status}] [{enhanced_status}]"
+ )
+
# Cache optimization
print("\nš§ Optimizing cache performance...")
client.optimize_caching()
-
+
# Analytics
print("\nš Cache Performance Analytics:")
analytics = client.get_cache_analytics()
-
- for strategy_name, metrics in analytics['strategies'].items():
+
+ for strategy_name, metrics in analytics["strategies"].items():
print(f"\n {strategy_name.upper()} Strategy:")
print(f" Hit Rate: {metrics['hit_rate']:.1%}")
print(f" Total Requests: {metrics['total_requests']}")
@@ -515,7 +577,7 @@ async def demonstrate_enhanced_caching():
print(f" Cache Sets: {metrics['cache_sets']}")
print(f" Avg Hit Time: {metrics['avg_hit_time_ms']:.1f}ms")
print(f" Cache Size: {metrics['cache_size_bytes']} bytes")
-
+
print("\nš Enhanced caching demonstration completed successfully!")
print("\nš” Key Features Demonstrated:")
print(" ā
Response padding to meet minimum token requirements")
@@ -532,19 +594,20 @@ async def main():
"""Main demonstration function."""
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
try:
await demonstrate_enhanced_caching()
return 0
except Exception as e:
print(f"ā Error: {e}")
import traceback
+
traceback.print_exc()
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/05_cache_integration/cached_arcade_client_fixed.py b/examples/arcade-dev/05_cache_integration/cached_arcade_client_fixed.py
index 58cd5b8..5a8dc86 100644
--- a/examples/arcade-dev/05_cache_integration/cached_arcade_client_fixed.py
+++ b/examples/arcade-dev/05_cache_integration/cached_arcade_client_fixed.py
@@ -45,16 +45,17 @@
@dataclass
class CacheStrategy:
"""Cache strategy configuration."""
+
strategy_name: str
ttl_seconds: int
max_entries: int = 1000
compression_enabled: bool = False
encryption_enabled: bool = False
-
+
# Performance thresholds
hit_rate_threshold: float = 0.8
response_time_threshold_ms: float = 100
-
+
# Advanced features
prefetch_enabled: bool = False
background_refresh_enabled: bool = False
@@ -63,54 +64,57 @@ class CacheStrategy:
@dataclass
class CachePerformanceMetrics:
"""Cache performance tracking."""
+
hits: int = 0
misses: int = 0
sets: int = 0
errors: int = 0
-
+
# Timing metrics
total_hit_time_ms: float = 0
total_miss_time_ms: float = 0
total_set_time_ms: float = 0
-
+
# Size metrics
cache_size_bytes: int = 0
-
+
def get_hit_rate(self) -> float:
total_requests = self.hits + self.misses
return (self.hits / total_requests) if total_requests > 0 else 0.0
-
+
def get_avg_hit_time_ms(self) -> float:
return (self.total_hit_time_ms / self.hits) if self.hits > 0 else 0.0
-
+
def get_avg_miss_time_ms(self) -> float:
return (self.total_miss_time_ms / self.misses) if self.misses > 0 else 0.0
class SimpleCacheManager:
"""Simplified cache manager that works with FACT CacheManager."""
-
- def __init__(self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]):
+
+ def __init__(
+ self, primary_cache: CacheManager, strategies: Dict[str, CacheStrategy]
+ ):
self.primary_cache = primary_cache
self.strategies = strategies
self.logger = logging.getLogger(f"{__name__}.SimpleCacheManager")
-
+
# Performance tracking
self.metrics: Dict[str, CachePerformanceMetrics] = {
strategy_name: CachePerformanceMetrics()
for strategy_name in strategies.keys()
}
-
+
def get(self, key: str, strategy_name: str = "default") -> Optional[Any]:
"""Get value with caching strategy."""
start_time = time.time()
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Get from primary cache
cache_entry = self.primary_cache.get(key)
-
+
if cache_entry:
# Cache hit
metrics.hits += 1
@@ -123,137 +127,136 @@ def get(self, key: str, strategy_name: str = "default") -> Optional[Any]:
metrics.total_miss_time_ms += (time.time() - start_time) * 1000
self.logger.debug(f"Cache miss for key: {key}")
return None
-
+
except Exception as e:
metrics.errors += 1
self.logger.error(f"Cache get error for key {key}: {e}")
return None
-
+
def set(self, key: str, value: Any, strategy_name: str = "default") -> bool:
"""Set value with caching strategy."""
start_time = time.time()
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Store in primary cache
cache_entry = self.primary_cache.store(key, value)
-
+
if cache_entry:
metrics.sets += 1
metrics.total_set_time_ms += (time.time() - start_time) * 1000
metrics.cache_size_bytes += cache_entry.size_bytes
-
+
self.logger.debug(f"Cache set successful for key: {key}")
return True
else:
metrics.errors += 1
return False
-
+
except Exception as e:
metrics.errors += 1
self.logger.error(f"Cache set error for key {key}: {e}")
return False
-
+
def optimize_cache(self, strategy_name: str = "default"):
"""Optimize cache performance based on usage patterns."""
strategy = self.strategies.get(strategy_name, self.strategies["default"])
metrics = self.metrics[strategy_name]
-
+
try:
# Check if optimization is needed
hit_rate = metrics.get_hit_rate()
-
+
if hit_rate < strategy.hit_rate_threshold:
- self.logger.warning(f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})")
- self.logger.info(f"Cache optimization completed for strategy: {strategy_name}")
-
+ self.logger.warning(
+ f"Cache hit rate ({hit_rate:.2%}) below threshold ({strategy.hit_rate_threshold:.2%})"
+ )
+ self.logger.info(
+ f"Cache optimization completed for strategy: {strategy_name}"
+ )
+
except Exception as e:
self.logger.error(f"Cache optimization error: {e}")
-
+
def get_performance_report(self) -> Dict[str, Any]:
"""Generate comprehensive performance report."""
- report = {
- 'timestamp': datetime.now(timezone.utc).isoformat(),
- 'strategies': {}
- }
-
+ report = {"timestamp": datetime.now(timezone.utc).isoformat(), "strategies": {}}
+
for strategy_name, metrics in self.metrics.items():
strategy_report = {
- 'hit_rate': metrics.get_hit_rate(),
- 'total_requests': metrics.hits + metrics.misses,
- 'cache_hits': metrics.hits,
- 'cache_misses': metrics.misses,
- 'cache_sets': metrics.sets,
- 'cache_errors': metrics.errors,
- 'avg_hit_time_ms': metrics.get_avg_hit_time_ms(),
- 'avg_miss_time_ms': metrics.get_avg_miss_time_ms(),
- 'cache_size_bytes': metrics.cache_size_bytes
+ "hit_rate": metrics.get_hit_rate(),
+ "total_requests": metrics.hits + metrics.misses,
+ "cache_hits": metrics.hits,
+ "cache_misses": metrics.misses,
+ "cache_sets": metrics.sets,
+ "cache_errors": metrics.errors,
+ "avg_hit_time_ms": metrics.get_avg_hit_time_ms(),
+ "avg_miss_time_ms": metrics.get_avg_miss_time_ms(),
+ "cache_size_bytes": metrics.cache_size_bytes,
}
-
- report['strategies'][strategy_name] = strategy_report
-
+
+ report["strategies"][strategy_name] = strategy_report
+
return report
class AdvancedArcadeClient:
"""Advanced Arcade.dev client with sophisticated caching."""
-
+
def __init__(self, api_key: str, cache_manager: CacheManager):
self.api_key = api_key
self.logger = logging.getLogger(__name__)
-
+
# Detect demo mode based on API key
self.demo_mode = (
- not api_key or
- api_key in ["demo_api_key", "test-key-for-testing"] or
- api_key.startswith("demo_") or
- api_key.startswith("test_")
+ not api_key
+ or api_key in ["demo_api_key", "test-key-for-testing"]
+ or api_key.startswith("demo_")
+ or api_key.startswith("test_")
)
-
+
if self.demo_mode:
self.logger.info("Running in DEMO MODE - no real API calls will be made")
else:
self.logger.info("Running with real API credentials")
-
+
# Setup cache strategies
self.cache_strategies = {
- "default": CacheStrategy(
- strategy_name="default",
- ttl_seconds=3600
- ),
+ "default": CacheStrategy(strategy_name="default", ttl_seconds=3600),
"fast": CacheStrategy(
- strategy_name="fast",
- ttl_seconds=300,
- prefetch_enabled=True
+ strategy_name="fast", ttl_seconds=300, prefetch_enabled=True
),
"persistent": CacheStrategy(
strategy_name="persistent",
ttl_seconds=86400, # 24 hours
compression_enabled=True,
- background_refresh_enabled=True
+ background_refresh_enabled=True,
),
"secure": CacheStrategy(
- strategy_name="secure",
- ttl_seconds=1800,
- encryption_enabled=True
- )
+ strategy_name="secure", ttl_seconds=1800, encryption_enabled=True
+ ),
}
-
+
# Initialize simplified cache manager
self.simple_cache = SimpleCacheManager(cache_manager, self.cache_strategies)
-
+
def _generate_cache_key(self, operation: str, **kwargs) -> str:
"""Generate cache key for operation."""
params_str = json.dumps(kwargs, sort_keys=True, default=str)
params_hash = hashlib.sha256(params_str.encode()).hexdigest()[:16]
return f"arcade:{operation}:{params_hash}"
-
- async def cached_operation(self, operation: str, strategy: str = "default",
- force_refresh: bool = False, **kwargs) -> Dict[str, Any]:
+
+ async def cached_operation(
+ self,
+ operation: str,
+ strategy: str = "default",
+ force_refresh: bool = False,
+ **kwargs,
+ ) -> Dict[str, Any]:
"""Execute operation with advanced caching."""
cache_key = self._generate_cache_key(operation, **kwargs)
-
+
# Check cache first
if not force_refresh:
cached_result = self.simple_cache.get(cache_key, strategy)
@@ -266,78 +269,78 @@ async def cached_operation(self, operation: str, strategy: str = "default",
except json.JSONDecodeError:
# If not JSON, wrap in result structure
return {
- 'operation': operation,
- 'status': 'success',
- 'result': cached_result,
- 'parameters': kwargs,
- '_cached': True,
- '_timestamp': datetime.now(timezone.utc).isoformat()
+ "operation": operation,
+ "status": "success",
+ "result": cached_result,
+ "parameters": kwargs,
+ "_cached": True,
+ "_timestamp": datetime.now(timezone.utc).isoformat(),
}
return cached_result
-
+
# Execute operation
start_time = time.time()
result = await self._execute_operation(operation, **kwargs)
execution_time = (time.time() - start_time) * 1000
-
+
# Add metadata
- result['_execution_time_ms'] = execution_time
- result['_cached'] = False
- result['_timestamp'] = datetime.now(timezone.utc).isoformat()
-
+ result["_execution_time_ms"] = execution_time
+ result["_cached"] = False
+ result["_timestamp"] = datetime.now(timezone.utc).isoformat()
+
# Cache result as JSON string
result_json = json.dumps(result, default=str)
self.simple_cache.set(cache_key, result_json, strategy)
-
+
return result
-
+
async def _execute_operation(self, operation: str, **kwargs) -> Dict[str, Any]:
"""Execute the actual operation (mock implementation for demo)."""
if self.demo_mode:
# Simulate different operation types with realistic delays
operation_times = {
- 'code_analysis': 2.0,
- 'test_generation': 3.0,
- 'documentation': 1.5,
- 'refactoring': 2.5
+ "code_analysis": 2.0,
+ "test_generation": 3.0,
+ "documentation": 1.5,
+ "refactoring": 2.5,
}
-
+
await asyncio.sleep(operation_times.get(operation, 1.0))
-
+
# Generate demo results based on operation type
demo_results = {
- 'code_analysis': f"Demo analysis of {kwargs.get('language', 'code')}: Found {len(kwargs.get('code', ''))} characters, suggests 3 improvements",
- 'test_generation': f"Demo test suite generated with 5 test cases for function: {kwargs.get('code', 'function')[:20]}...",
- 'documentation': f"Demo documentation generated for {len(kwargs.get('code', ''))} lines of code",
- 'refactoring': f"Demo refactoring suggestions: Extract 2 methods, reduce complexity by 15%"
+ "code_analysis": f"Demo analysis of {kwargs.get('language', 'code')}: Found {len(kwargs.get('code', ''))} characters, suggests 3 improvements",
+ "test_generation": f"Demo test suite generated with 5 test cases for function: {kwargs.get('code', 'function')[:20]}...",
+ "documentation": f"Demo documentation generated for {len(kwargs.get('code', ''))} lines of code",
+ "refactoring": f"Demo refactoring suggestions: Extract 2 methods, reduce complexity by 15%",
}
-
+
return {
- 'operation': operation,
- 'status': 'success',
- 'result': demo_results.get(operation, f"Demo result for {operation}"),
- 'parameters': kwargs,
- 'demo_mode': True
+ "operation": operation,
+ "status": "success",
+ "result": demo_results.get(operation, f"Demo result for {operation}"),
+ "parameters": kwargs,
+ "demo_mode": True,
}
else:
# In live mode, this would make actual API calls to Arcade.dev
# For now, we'll simulate since we don't have real API integration
self.logger.warning("Live API mode not yet implemented - using simulation")
await asyncio.sleep(1.0) # Faster for live mode
-
+
return {
- 'operation': operation,
- 'status': 'simulated',
- 'result': f"Simulated {operation} result (would be real API call)",
- 'parameters': kwargs,
- 'demo_mode': False
+ "operation": operation,
+ "status": "simulated",
+ "result": f"Simulated {operation} result (would be real API call)",
+ "parameters": kwargs,
+ "demo_mode": False,
}
-
+
def optimize_caching(self):
"""Optimize all cache strategies."""
for strategy_name in self.cache_strategies.keys():
self.simple_cache.optimize_cache(strategy_name)
-
+
def get_cache_analytics(self) -> Dict[str, Any]:
"""Get comprehensive cache analytics."""
return self.simple_cache.get_performance_report()
@@ -347,31 +350,35 @@ async def demonstrate_advanced_caching():
"""Demonstrate advanced caching capabilities."""
print("š Advanced Cache Integration Demo - Fixed Version")
print("=" * 60)
-
+
# Initialize cache manager with proper configuration
cache_config = get_default_cache_config()
cache_manager = CacheManager(cache_config)
-
+
# Get API key from environment or use demo mode
api_key = os.getenv("ARCADE_API_KEY", "demo_api_key")
-
+
# Create advanced client
client = AdvancedArcadeClient(api_key, cache_manager)
-
+
# Show mode status
mode_status = "š® DEMO MODE" if client.demo_mode else "š LIVE MODE"
print(f"\n{mode_status} - Using API key: {api_key[:10]}...")
-
+
print("\nš Testing different cache strategies...")
-
+
# Test operations with different strategies
operations = [
- ("code_analysis", "fast", {"code": "def hello(): print('world')", "language": "python"}),
+ (
+ "code_analysis",
+ "fast",
+ {"code": "def hello(): print('world')", "language": "python"},
+ ),
("test_generation", "default", {"code": "def add(a, b): return a + b"}),
("documentation", "persistent", {"code": "class MyClass: pass"}),
- ("refactoring", "secure", {"code": "legacy_code_here"})
+ ("refactoring", "secure", {"code": "legacy_code_here"}),
]
-
+
# First run (cache miss)
print("\nš First execution (cache miss):")
for operation, strategy, params in operations:
@@ -379,25 +386,25 @@ async def demonstrate_advanced_caching():
result = await client.cached_operation(operation, strategy, **params)
duration = (time.time() - start_time) * 1000
print(f" {operation} ({strategy}): {duration:.1f}ms")
-
+
# Second run (cache hit)
print("\nā” Second execution (cache hit):")
for operation, strategy, params in operations:
start_time = time.time()
result = await client.cached_operation(operation, strategy, **params)
duration = (time.time() - start_time) * 1000
- cached_status = "HIT" if result.get('_cached') else "MISS"
+ cached_status = "HIT" if result.get("_cached") else "MISS"
print(f" {operation} ({strategy}): {duration:.1f}ms [{cached_status}]")
-
+
# Cache optimization
print("\nš§ Optimizing cache performance...")
client.optimize_caching()
-
+
# Analytics
print("\nš Cache Performance Analytics:")
analytics = client.get_cache_analytics()
-
- for strategy_name, metrics in analytics['strategies'].items():
+
+ for strategy_name, metrics in analytics["strategies"].items():
print(f"\n {strategy_name.upper()} Strategy:")
print(f" Hit Rate: {metrics['hit_rate']:.1%}")
print(f" Total Requests: {metrics['total_requests']}")
@@ -405,7 +412,7 @@ async def demonstrate_advanced_caching():
print(f" Cache Misses: {metrics['cache_misses']}")
print(f" Avg Hit Time: {metrics['avg_hit_time_ms']:.1f}ms")
print(f" Cache Size: {metrics['cache_size_bytes']} bytes")
-
+
print("\nš Advanced caching demonstration completed successfully!")
print("\nš” Key Features Demonstrated:")
print(" ā
Multiple cache strategies (fast, default, persistent, secure)")
@@ -420,19 +427,20 @@ async def main():
"""Main demonstration function."""
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
try:
await demonstrate_advanced_caching()
return 0
except Exception as e:
print(f"ā Error: {e}")
import traceback
+
traceback.print_exc()
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/06_security/secure_tool_execution.py b/examples/arcade-dev/06_security/secure_tool_execution.py
index 52258c5..eebf67c 100644
--- a/examples/arcade-dev/06_security/secure_tool_execution.py
+++ b/examples/arcade-dev/06_security/secure_tool_execution.py
@@ -36,25 +36,26 @@
from src.security.token_manager import TokenManager
from src.core.driver import FACTDriver
from src.cache.manager import CacheManager
+
FACT_AVAILABLE = True
except ImportError as e:
print(f"ā ļø FACT modules not fully available: {e}")
print("š Running in demo mode with mock implementations")
FACT_AVAILABLE = False
-
+
# Create mock classes for demo
class MockAuthorizationManager:
def __init__(self, *args, **kwargs):
pass
-
+
class MockInputSanitizer:
def __init__(self, *args, **kwargs):
pass
-
+
class MockTokenManager:
def __init__(self, *args, **kwargs):
pass
-
+
AuthorizationManager = MockAuthorizationManager
InputSanitizer = MockInputSanitizer
TokenManager = MockTokenManager
@@ -63,23 +64,26 @@ def __init__(self, *args, **kwargs):
@dataclass
class SecurityConfig:
"""Security configuration for Arcade.dev integration."""
+
enable_input_validation: bool = True
enable_audit_logging: bool = True
enable_rate_limiting: bool = True
enable_data_encryption: bool = True
-
+
# Rate limiting
max_requests_per_minute: int = 60
max_requests_per_hour: int = 1000
-
+
# Input validation
max_input_size: int = 1024 * 1024 # 1MB
- allowed_file_types: Set[str] = field(default_factory=lambda: {'.py', '.js', '.ts', '.java', '.cpp'})
-
+ allowed_file_types: Set[str] = field(
+ default_factory=lambda: {".py", ".js", ".ts", ".java", ".cpp"}
+ )
+
# Audit settings
audit_log_file: str = "arcade_security_audit.log"
log_sensitive_data: bool = False
-
+
# Token settings
token_expiry_minutes: int = 60
refresh_token_expiry_days: int = 7
@@ -88,6 +92,7 @@ class SecurityConfig:
@dataclass
class UserPermissions:
"""User permission model."""
+
user_id: str
scopes: Set[str]
max_daily_requests: int = 1000
@@ -100,6 +105,7 @@ class UserPermissions:
@dataclass
class AuditLogEntry:
"""Audit log entry for security tracking."""
+
timestamp: datetime
user_id: str
operation: str
@@ -113,38 +119,40 @@ class AuditLogEntry:
class SecureCredentialManager:
"""Manages secure storage and retrieval of API credentials."""
-
+
def __init__(self):
self.logger = logging.getLogger(f"{__name__}.CredentialManager")
self._credentials: Dict[str, str] = {}
self._demo_mode = not FACT_AVAILABLE
-
+
if self._demo_mode:
- self.logger.info("Running in demo mode - using simplified credential storage")
- self._encryption_key = b'demo_key_for_testing_only_not_secure'
+ self.logger.info(
+ "Running in demo mode - using simplified credential storage"
+ )
+ self._encryption_key = b"demo_key_for_testing_only_not_secure"
else:
self._encryption_key = self._get_or_create_encryption_key()
-
+
def _get_or_create_encryption_key(self) -> bytes:
"""Get or create encryption key for credential storage."""
# In demo mode, use a simple key
if self._demo_mode:
- return b'demo_key_for_testing_only_not_secure'
-
- key_file = Path.home() / '.fact' / 'encryption.key'
-
+ return b"demo_key_for_testing_only_not_secure"
+
+ key_file = Path.home() / ".fact" / "encryption.key"
+
if key_file.exists():
- with open(key_file, 'rb') as f:
+ with open(key_file, "rb") as f:
return f.read()
else:
# Create new key
key_file.parent.mkdir(exist_ok=True)
key = secrets.token_bytes(32)
- with open(key_file, 'wb') as f:
+ with open(key_file, "wb") as f:
f.write(key)
key_file.chmod(0o600) # Read only for owner
return key
-
+
def store_credential(self, service: str, credential: str) -> bool:
"""Securely store a credential."""
try:
@@ -156,134 +164,148 @@ def store_credential(self, service: str, credential: str) -> bool:
except Exception as e:
self.logger.error(f"Failed to store credential for {service}: {e}")
return False
-
+
def get_credential(self, service: str) -> Optional[str]:
"""Retrieve a credential securely."""
if self._demo_mode:
# In demo mode, return a mock credential if no env var is set
env_key = f"{service.upper()}_API_KEY"
credential = os.getenv(env_key)
-
+
if not credential:
# Return demo credential for testing
credential = f"demo_key_for_{service}_testing_only"
self.logger.info(f"Using demo credential for service: {service}")
else:
self.logger.info(f"Using environment credential for service: {service}")
-
+
return credential
else:
# In production, implement proper decryption
# For now, use environment variables
env_key = f"{service.upper()}_API_KEY"
credential = os.getenv(env_key)
-
+
if not credential:
self.logger.warning(f"No credential found for service: {service}")
-
+
return credential
class SecurityValidator:
"""Validates and sanitizes inputs for security."""
-
+
def __init__(self, config: SecurityConfig):
self.config = config
self.logger = logging.getLogger(f"{__name__}.SecurityValidator")
-
+
# Dangerous patterns to detect
self.dangerous_patterns = [
- r'eval\s*\(',
- r'exec\s*\(',
- r'import\s+os',
- r'import\s+subprocess',
- r'__import__',
- r'open\s*\(',
- r'file\s*\(',
- r'input\s*\(',
- r'raw_input\s*\(',
+ r"eval\s*\(",
+ r"exec\s*\(",
+ r"import\s+os",
+ r"import\s+subprocess",
+ r"__import__",
+ r"open\s*\(",
+ r"file\s*\(",
+ r"input\s*\(",
+ r"raw_input\s*\(",
]
-
- def validate_code_input(self, code: str, language: str = 'python') -> Dict[str, Any]:
+
+ def validate_code_input(
+ self, code: str, language: str = "python"
+ ) -> Dict[str, Any]:
"""Validate code input for security risks."""
validation_result = {
- 'is_safe': True,
- 'warnings': [],
- 'blocked_patterns': [],
- 'sanitized_code': code
+ "is_safe": True,
+ "warnings": [],
+ "blocked_patterns": [],
+ "sanitized_code": code,
}
-
+
if not self.config.enable_input_validation:
return validation_result
-
+
# Size check
if len(code) > self.config.max_input_size:
- validation_result['is_safe'] = False
- validation_result['warnings'].append(f"Code exceeds maximum size limit: {self.config.max_input_size}")
-
+ validation_result["is_safe"] = False
+ validation_result["warnings"].append(
+ f"Code exceeds maximum size limit: {self.config.max_input_size}"
+ )
+
# Pattern detection
for pattern in self.dangerous_patterns:
if re.search(pattern, code, re.IGNORECASE):
- validation_result['blocked_patterns'].append(pattern)
- validation_result['warnings'].append(f"Potentially dangerous pattern detected: {pattern}")
-
+ validation_result["blocked_patterns"].append(pattern)
+ validation_result["warnings"].append(
+ f"Potentially dangerous pattern detected: {pattern}"
+ )
+
# If dangerous patterns found, mark as unsafe
- if validation_result['blocked_patterns']:
- validation_result['is_safe'] = False
-
+ if validation_result["blocked_patterns"]:
+ validation_result["is_safe"] = False
+
return validation_result
class AuditLogger:
"""Handles security audit logging."""
-
+
def __init__(self, config: SecurityConfig):
self.config = config
self.audit_file = Path(config.audit_log_file)
self.logger = logging.getLogger(f"{__name__}.AuditLogger")
-
+
# Setup audit log file
self._setup_audit_logging()
-
+
def _setup_audit_logging(self):
"""Setup secure audit logging."""
if self.config.enable_audit_logging:
# Create audit log directory
self.audit_file.parent.mkdir(exist_ok=True)
-
+
# Setup file permissions (read/write for owner only)
if self.audit_file.exists():
self.audit_file.chmod(0o600)
-
+
def log_event(self, entry: AuditLogEntry):
"""Log a security audit event."""
if not self.config.enable_audit_logging:
return
-
+
try:
log_data = {
- 'timestamp': entry.timestamp.isoformat(),
- 'request_id': entry.request_id,
- 'user_id': entry.user_id if not self.config.log_sensitive_data else self._hash_user_id(entry.user_id),
- 'operation': entry.operation,
- 'status': entry.status,
- 'ip_address': entry.ip_address if not self.config.log_sensitive_data else self._hash_ip(entry.ip_address),
- 'risk_score': entry.risk_score,
- 'metadata': entry.metadata
+ "timestamp": entry.timestamp.isoformat(),
+ "request_id": entry.request_id,
+ "user_id": (
+ entry.user_id
+ if not self.config.log_sensitive_data
+ else self._hash_user_id(entry.user_id)
+ ),
+ "operation": entry.operation,
+ "status": entry.status,
+ "ip_address": (
+ entry.ip_address
+ if not self.config.log_sensitive_data
+ else self._hash_ip(entry.ip_address)
+ ),
+ "risk_score": entry.risk_score,
+ "metadata": entry.metadata,
}
-
- with open(self.audit_file, 'a') as f:
- f.write(json.dumps(log_data) + '\n')
-
+
+ with open(self.audit_file, "a") as f:
+ f.write(json.dumps(log_data) + "\n")
+
self.logger.debug(f"Audit event logged: {entry.operation} - {entry.status}")
-
+
except Exception as e:
self.logger.error(f"Failed to log audit event: {e}")
-
+
def _hash_user_id(self, user_id: str) -> str:
"""Hash user ID for privacy."""
return hashlib.sha256(user_id.encode()).hexdigest()[:16]
-
+
def _hash_ip(self, ip_address: Optional[str]) -> Optional[str]:
"""Hash IP address for privacy."""
if not ip_address:
@@ -293,140 +315,154 @@ def _hash_ip(self, ip_address: Optional[str]) -> Optional[str]:
class SecureArcadeClient:
"""Secure Arcade.dev client with comprehensive security measures."""
-
+
def __init__(self, config: SecurityConfig):
self.config = config
self.credential_manager = SecureCredentialManager()
self.validator = SecurityValidator(config)
self.audit_logger = AuditLogger(config)
self.logger = logging.getLogger(__name__)
-
+
# User permission storage (in production, use database)
self.user_permissions: Dict[str, UserPermissions] = {}
-
+
# Active sessions
self.active_sessions: Dict[str, Dict[str, Any]] = {}
-
- def register_user(self, user_id: str, scopes: Set[str], allowed_operations: Set[str] = None) -> bool:
+
+ def register_user(
+ self, user_id: str, scopes: Set[str], allowed_operations: Set[str] = None
+ ) -> bool:
"""Register a user with specific permissions."""
try:
if allowed_operations is None:
- allowed_operations = {'code_analysis', 'test_generation', 'documentation'}
-
+ allowed_operations = {
+ "code_analysis",
+ "test_generation",
+ "documentation",
+ }
+
permissions = UserPermissions(
- user_id=user_id,
- scopes=scopes,
- allowed_operations=allowed_operations
+ user_id=user_id, scopes=scopes, allowed_operations=allowed_operations
)
-
+
self.user_permissions[user_id] = permissions
-
+
audit_entry = AuditLogEntry(
timestamp=datetime.now(timezone.utc),
user_id=user_id,
- operation='user_registration',
- status='success',
- metadata={'scopes': list(scopes), 'operations': list(allowed_operations)}
+ operation="user_registration",
+ status="success",
+ metadata={
+ "scopes": list(scopes),
+ "operations": list(allowed_operations),
+ },
)
self.audit_logger.log_event(audit_entry)
-
+
self.logger.info(f"User registered: {user_id}")
return True
-
+
except Exception as e:
self.logger.error(f"Failed to register user {user_id}: {e}")
return False
-
- def authenticate_user(self, user_id: str, api_key: str, ip_address: str = None) -> Dict[str, Any]:
+
+ def authenticate_user(
+ self, user_id: str, api_key: str, ip_address: str = None
+ ) -> Dict[str, Any]:
"""Authenticate user and create secure session."""
auth_result = {
- 'authenticated': False,
- 'session_token': None,
- 'permissions': None,
- 'error': None
+ "authenticated": False,
+ "session_token": None,
+ "permissions": None,
+ "error": None,
}
-
+
try:
# Check if user exists
if user_id not in self.user_permissions:
audit_entry = AuditLogEntry(
timestamp=datetime.now(timezone.utc),
user_id=user_id,
- operation='authentication',
- status='failure',
+ operation="authentication",
+ status="failure",
ip_address=ip_address,
- metadata={'reason': 'User not found'},
- risk_score=7.0
+ metadata={"reason": "User not found"},
+ risk_score=7.0,
)
self.audit_logger.log_event(audit_entry)
- auth_result['error'] = 'User not found'
+ auth_result["error"] = "User not found"
return auth_result
-
+
# Create session token
session_token = secrets.token_urlsafe(32)
user_perms = self.user_permissions[user_id]
session_data = {
- 'user_id': user_id,
- 'created_at': datetime.now(timezone.utc),
- 'expires_at': datetime.now(timezone.utc) + timedelta(minutes=self.config.token_expiry_minutes),
- 'ip_address': ip_address,
- 'permissions': user_perms
+ "user_id": user_id,
+ "created_at": datetime.now(timezone.utc),
+ "expires_at": datetime.now(timezone.utc)
+ + timedelta(minutes=self.config.token_expiry_minutes),
+ "ip_address": ip_address,
+ "permissions": user_perms,
}
-
+
self.active_sessions[session_token] = session_data
-
+
audit_entry = AuditLogEntry(
timestamp=datetime.now(timezone.utc),
user_id=user_id,
- operation='authentication',
- status='success',
- ip_address=ip_address
+ operation="authentication",
+ status="success",
+ ip_address=ip_address,
)
self.audit_logger.log_event(audit_entry)
-
- auth_result.update({
- 'authenticated': True,
- 'session_token': session_token,
- 'permissions': {
- 'scopes': list(user_perms.scopes),
- 'operations': list(user_perms.allowed_operations),
- 'expires_at': session_data['expires_at'].isoformat()
+
+ auth_result.update(
+ {
+ "authenticated": True,
+ "session_token": session_token,
+ "permissions": {
+ "scopes": list(user_perms.scopes),
+ "operations": list(user_perms.allowed_operations),
+ "expires_at": session_data["expires_at"].isoformat(),
+ },
}
- })
-
+ )
+
return auth_result
-
+
except Exception as e:
self.logger.error(f"Authentication error for {user_id}: {e}")
- auth_result['error'] = 'Authentication failed'
+ auth_result["error"] = "Authentication failed"
return auth_result
-
- async def secure_code_analysis(self, session_token: str, code: str, language: str = 'python') -> Dict[str, Any]:
+
+ async def secure_code_analysis(
+ self, session_token: str, code: str, language: str = "python"
+ ) -> Dict[str, Any]:
"""Perform secure code analysis with full security validation."""
# Validate session
if session_token not in self.active_sessions:
- return {'error': 'Invalid session', 'status': 'unauthorized'}
-
+ return {"error": "Invalid session", "status": "unauthorized"}
+
session = self.active_sessions[session_token]
- user_id = session['user_id']
-
+ user_id = session["user_id"]
+
# Input validation
validation_result = self.validator.validate_code_input(code, language)
- if not validation_result['is_safe']:
+ if not validation_result["is_safe"]:
return {
- 'error': 'Input validation failed',
- 'warnings': validation_result['warnings'],
- 'status': 'invalid_input'
+ "error": "Input validation failed",
+ "warnings": validation_result["warnings"],
+ "status": "invalid_input",
}
-
+
# Mock analysis result
return {
- 'status': 'success',
- 'analysis': {
- 'score': 8.5,
- 'suggestions': ['Use type hints', 'Add error handling'],
- 'security_check_passed': True
- }
+ "status": "success",
+ "analysis": {
+ "score": 8.5,
+ "suggestions": ["Use type hints", "Add error handling"],
+ "security_check_passed": True,
+ },
}
@@ -434,55 +470,53 @@ async def demonstrate_security():
"""Demonstrate security features."""
print("š Secure Arcade.dev Integration Demo")
print("=" * 50)
-
+
# Configure security (ensure logs directory exists)
logs_dir = Path("logs")
logs_dir.mkdir(exist_ok=True)
-
+
config = SecurityConfig(
enable_input_validation=True,
enable_audit_logging=True,
- audit_log_file="logs/arcade_security_audit.log"
+ audit_log_file="logs/arcade_security_audit.log",
)
-
+
# Create secure client
client = SecureArcadeClient(config)
-
+
# Register user
client.register_user(
- user_id='demo_user',
- scopes={'read', 'write'},
- allowed_operations={'code_analysis', 'test_generation'}
+ user_id="demo_user",
+ scopes={"read", "write"},
+ allowed_operations={"code_analysis", "test_generation"},
)
-
+
# Authenticate
auth_result = client.authenticate_user(
- user_id='demo_user',
- api_key='demo_key',
- ip_address='127.0.0.1'
+ user_id="demo_user", api_key="demo_key", ip_address="127.0.0.1"
)
-
- if auth_result['authenticated']:
+
+ if auth_result["authenticated"]:
print("ā
User authenticated successfully")
- session_token = auth_result['session_token']
-
+ session_token = auth_result["session_token"]
+
# Test safe code
safe_code = "def hello_world():\n print('Hello, World!')"
result = await client.secure_code_analysis(session_token, safe_code)
print(f"ā
Safe code analysis: {result['status']}")
-
+
# Test dangerous code
dangerous_code = "import os\nos.system('rm -rf /')"
result = await client.secure_code_analysis(session_token, dangerous_code)
print(f"š”ļø Dangerous code blocked: {result['status']}")
-
+
print("\nš Security demonstration completed!")
async def main():
"""Main demonstration function."""
logging.basicConfig(level=logging.INFO)
-
+
try:
await demonstrate_security()
return 0
@@ -493,4 +527,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/07_cache_integration/cached_arcade_client.py b/examples/arcade-dev/07_cache_integration/cached_arcade_client.py
index 931af45..0ec09d3 100644
--- a/examples/arcade-dev/07_cache_integration/cached_arcade_client.py
+++ b/examples/arcade-dev/07_cache_integration/cached_arcade_client.py
@@ -32,321 +32,465 @@
@dataclass
class CacheConfig:
"""Configuration for caching behavior."""
+
default_ttl: int = 3600 # 1 hour
max_size: int = 1000
enabled: bool = True
key_prefix: str = "arcade"
-
+
# Cache TTL by operation type
ttl_by_operation: Dict[str, int] = None
-
+
def __post_init__(self):
if self.ttl_by_operation is None:
self.ttl_by_operation = {
- 'health': 300, # 5 minutes
- 'user_info': 1800, # 30 minutes
- 'code_analysis': 7200, # 2 hours
- 'test_generation': 3600, # 1 hour
- 'documentation': 7200, # 2 hours
- 'refactoring': 3600, # 1 hour
+ "health": 300, # 5 minutes
+ "user_info": 1800, # 30 minutes
+ "code_analysis": 7200, # 2 hours
+ "test_generation": 3600, # 1 hour
+ "documentation": 7200, # 2 hours
+ "refactoring": 3600, # 1 hour
}
class CachedArcadeClient:
"""Arcade.dev client with intelligent caching."""
-
- def __init__(self, api_key: str, cache_manager: CacheManager, cache_config: CacheConfig = None):
+
+ def __init__(
+ self,
+ api_key: str,
+ cache_manager: CacheManager,
+ cache_config: CacheConfig = None,
+ ):
self.api_key = api_key
self.cache_manager = cache_manager
self.cache_config = cache_config or CacheConfig()
self.logger = logging.getLogger(__name__)
self.stats = {
- 'cache_hits': 0,
- 'cache_misses': 0,
- 'api_calls': 0,
- 'cache_sets': 0,
- 'cache_errors': 0
+ "cache_hits": 0,
+ "cache_misses": 0,
+ "api_calls": 0,
+ "cache_sets": 0,
+ "cache_errors": 0,
}
-
+
def _generate_cache_key(self, operation: str, **kwargs) -> str:
"""Generate a consistent cache key for the operation."""
# Create a stable hash of the parameters
param_str = json.dumps(kwargs, sort_keys=True, default=str)
param_hash = hashlib.sha256(param_str.encode()).hexdigest()[:16]
-
+
return f"{self.cache_config.key_prefix}:{operation}:{param_hash}"
-
+
def _get_cache_ttl(self, operation: str) -> int:
"""Get appropriate TTL for the operation."""
return self.cache_config.ttl_by_operation.get(
- operation,
- self.cache_config.default_ttl
+ operation, self.cache_config.default_ttl
)
-
+
async def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Retrieve data from cache with error handling."""
if not self.cache_config.enabled:
return None
-
+
try:
cache_entry = self.cache_manager.get(cache_key)
if cache_entry and cache_entry.is_valid:
- self.stats['cache_hits'] += 1
+ self.stats["cache_hits"] += 1
self.logger.debug(f"Cache hit for key: {cache_key}")
# Parse the JSON content back to dict
return json.loads(cache_entry.content)
else:
- self.stats['cache_misses'] += 1
+ self.stats["cache_misses"] += 1
self.logger.debug(f"Cache miss for key: {cache_key}")
return None
-
+
except Exception as e:
- self.stats['cache_errors'] += 1
+ self.stats["cache_errors"] += 1
self.logger.warning(f"Cache retrieval error for {cache_key}: {e}")
return None
-
+
async def _set_cache(self, cache_key: str, data: Dict[str, Any], ttl: int) -> bool:
"""Store data in cache with error handling."""
if not self.cache_config.enabled:
return False
-
+
try:
# Add metadata to cached data
cache_data = {
- 'data': data,
- 'cached_at': time.time(),
- 'ttl': ttl,
- 'operation': cache_key.split(':')[1] if ':' in cache_key else 'unknown'
+ "data": data,
+ "cached_at": time.time(),
+ "ttl": ttl,
+ "operation": cache_key.split(":")[1] if ":" in cache_key else "unknown",
}
-
+
self.cache_manager.store(cache_key, json.dumps(cache_data))
- self.stats['cache_sets'] += 1
+ self.stats["cache_sets"] += 1
self.logger.debug(f"Cached data for key: {cache_key} (TTL: {ttl}s)")
return True
-
+
except Exception as e:
- self.stats['cache_errors'] += 1
+ self.stats["cache_errors"] += 1
self.logger.warning(f"Cache storage error for {cache_key}: {e}")
return False
-
+
async def _simulate_api_call(self, operation: str, **kwargs) -> Dict[str, Any]:
"""Simulate an API call to Arcade.dev (replace with actual implementation)."""
- self.stats['api_calls'] += 1
-
+ self.stats["api_calls"] += 1
+
# Simulate different response times
delay_map = {
- 'health': 0.1,
- 'user_info': 0.3,
- 'code_analysis': 2.0,
- 'test_generation': 3.0,
- 'documentation': 2.5,
- 'refactoring': 1.8,
+ "health": 0.1,
+ "user_info": 0.3,
+ "code_analysis": 2.0,
+ "test_generation": 3.0,
+ "documentation": 2.5,
+ "refactoring": 1.8,
}
-
+
await asyncio.sleep(delay_map.get(operation, 1.0))
-
+
# Generate mock responses based on operation
- if operation == 'health':
- return {'status': 'healthy', 'timestamp': time.time()}
-
- elif operation == 'user_info':
+ if operation == "health":
+ return {"status": "healthy", "timestamp": time.time()}
+
+ elif operation == "user_info":
return {
- 'user_id': '12345',
- 'username': 'demo_user',
- 'tier': 'premium',
- 'api_calls_remaining': 1000
+ "user_id": "12345",
+ "username": "demo_user",
+ "tier": "premium",
+ "api_calls_remaining": 1000,
}
-
- elif operation == 'code_analysis':
- code = kwargs.get('code', '')
+
+ elif operation == "code_analysis":
+ code = kwargs.get("code", "")
return {
- 'analysis_id': f"analysis_{hash(code) % 10000}",
- 'language': kwargs.get('language', 'python'),
- 'lines_analyzed': len(code.split('\n')),
- 'suggestions': [
- {'type': 'performance', 'message': 'Consider using list comprehension for better performance and readability'},
- {'type': 'style', 'message': 'Function name should be snake_case according to PEP 8 style guidelines'},
- {'type': 'security', 'message': 'Validate input parameters to prevent injection attacks and ensure data integrity'},
- {'type': 'maintainability', 'message': 'Break down large functions into smaller, more focused methods'},
- {'type': 'documentation', 'message': 'Add comprehensive docstrings with parameter descriptions and return types'},
- {'type': 'error_handling', 'message': 'Implement proper exception handling for robust error management'}
+ "analysis_id": f"analysis_{hash(code) % 10000}",
+ "language": kwargs.get("language", "python"),
+ "lines_analyzed": len(code.split("\n")),
+ "suggestions": [
+ {
+ "type": "performance",
+ "message": "Consider using list comprehension for better performance and readability",
+ },
+ {
+ "type": "style",
+ "message": "Function name should be snake_case according to PEP 8 style guidelines",
+ },
+ {
+ "type": "security",
+ "message": "Validate input parameters to prevent injection attacks and ensure data integrity",
+ },
+ {
+ "type": "maintainability",
+ "message": "Break down large functions into smaller, more focused methods",
+ },
+ {
+ "type": "documentation",
+ "message": "Add comprehensive docstrings with parameter descriptions and return types",
+ },
+ {
+ "type": "error_handling",
+ "message": "Implement proper exception handling for robust error management",
+ },
],
- 'detailed_analysis': {
- 'complexity_metrics': {'cyclomatic_complexity': 12, 'cognitive_complexity': 15, 'nesting_depth': 4},
- 'quality_scores': {'maintainability': 8.5, 'reliability': 9.2, 'security': 8.8, 'efficiency': 7.9},
- 'code_smells': ['Long method', 'Feature envy', 'Data clumps', 'Primitive obsession'],
- 'best_practices': ['Add type hints', 'Use constants for magic numbers', 'Implement logging', 'Add unit tests']
+ "detailed_analysis": {
+ "complexity_metrics": {
+ "cyclomatic_complexity": 12,
+ "cognitive_complexity": 15,
+ "nesting_depth": 4,
+ },
+ "quality_scores": {
+ "maintainability": 8.5,
+ "reliability": 9.2,
+ "security": 8.8,
+ "efficiency": 7.9,
+ },
+ "code_smells": [
+ "Long method",
+ "Feature envy",
+ "Data clumps",
+ "Primitive obsession",
+ ],
+ "best_practices": [
+ "Add type hints",
+ "Use constants for magic numbers",
+ "Implement logging",
+ "Add unit tests",
+ ],
},
- 'score': 8.5,
- 'timestamp': time.time(),
- 'cache_padding': 'This response has been padded to meet minimum token requirements for effective caching. The analysis includes comprehensive code quality metrics, security assessments, performance optimizations, and maintainability recommendations. Additional details about coding standards, best practices, and architectural considerations are included to provide thorough analysis results.'
+ "score": 8.5,
+ "timestamp": time.time(),
+ "cache_padding": "This response has been padded to meet minimum token requirements for effective caching. The analysis includes comprehensive code quality metrics, security assessments, performance optimizations, and maintainability recommendations. Additional details about coding standards, best practices, and architectural considerations are included to provide thorough analysis results.",
}
-
- elif operation == 'test_generation':
+
+ elif operation == "test_generation":
return {
- 'test_id': f"test_{time.time()}",
- 'test_cases': [
- {'name': 'test_happy_path', 'type': 'unit', 'description': 'Tests normal execution flow with valid inputs'},
- {'name': 'test_edge_cases', 'type': 'unit', 'description': 'Tests boundary conditions and edge scenarios'},
- {'name': 'test_error_handling', 'type': 'unit', 'description': 'Tests proper error handling and exception management'},
- {'name': 'test_integration', 'type': 'integration', 'description': 'Tests interaction between different components'},
- {'name': 'test_performance', 'type': 'performance', 'description': 'Tests system performance under load conditions'},
- {'name': 'test_security', 'type': 'security', 'description': 'Tests security measures and vulnerability protection'}
+ "test_id": f"test_{time.time()}",
+ "test_cases": [
+ {
+ "name": "test_happy_path",
+ "type": "unit",
+ "description": "Tests normal execution flow with valid inputs",
+ },
+ {
+ "name": "test_edge_cases",
+ "type": "unit",
+ "description": "Tests boundary conditions and edge scenarios",
+ },
+ {
+ "name": "test_error_handling",
+ "type": "unit",
+ "description": "Tests proper error handling and exception management",
+ },
+ {
+ "name": "test_integration",
+ "type": "integration",
+ "description": "Tests interaction between different components",
+ },
+ {
+ "name": "test_performance",
+ "type": "performance",
+ "description": "Tests system performance under load conditions",
+ },
+ {
+ "name": "test_security",
+ "type": "security",
+ "description": "Tests security measures and vulnerability protection",
+ },
],
- 'detailed_test_plan': {
- 'unit_tests': {'count': 15, 'coverage_target': 95, 'execution_time': '2.3s'},
- 'integration_tests': {'count': 8, 'coverage_target': 85, 'execution_time': '5.7s'},
- 'end_to_end_tests': {'count': 4, 'coverage_target': 75, 'execution_time': '12.1s'}
+ "detailed_test_plan": {
+ "unit_tests": {
+ "count": 15,
+ "coverage_target": 95,
+ "execution_time": "2.3s",
+ },
+ "integration_tests": {
+ "count": 8,
+ "coverage_target": 85,
+ "execution_time": "5.7s",
+ },
+ "end_to_end_tests": {
+ "count": 4,
+ "coverage_target": 75,
+ "execution_time": "12.1s",
+ },
},
- 'testing_framework_recommendations': ['pytest', 'unittest', 'mock', 'coverage.py'],
- 'quality_gates': {'min_coverage': 90, 'max_execution_time': 300, 'zero_critical_bugs': True},
- 'coverage_estimate': 95.0,
- 'timestamp': time.time(),
- 'cache_padding': 'Comprehensive test generation results with detailed test cases, coverage analysis, performance metrics, and quality assurance recommendations for thorough testing strategy implementation.'
+ "testing_framework_recommendations": [
+ "pytest",
+ "unittest",
+ "mock",
+ "coverage.py",
+ ],
+ "quality_gates": {
+ "min_coverage": 90,
+ "max_execution_time": 300,
+ "zero_critical_bugs": True,
+ },
+ "coverage_estimate": 95.0,
+ "timestamp": time.time(),
+ "cache_padding": "Comprehensive test generation results with detailed test cases, coverage analysis, performance metrics, and quality assurance recommendations for thorough testing strategy implementation.",
}
-
- elif operation == 'documentation':
+
+ elif operation == "documentation":
return {
- 'doc_id': f"doc_{time.time()}",
- 'sections': ['Overview', 'Parameters', 'Returns', 'Examples', 'Usage Guidelines', 'Best Practices', 'Troubleshooting'],
- 'detailed_content': {
- 'overview': 'Comprehensive documentation covering all aspects of the API functionality',
- 'parameters': 'Detailed parameter descriptions with types, constraints, and examples',
- 'returns': 'Complete return value documentation with success and error scenarios',
- 'examples': 'Multiple code examples demonstrating various use cases and implementations',
- 'usage_guidelines': 'Best practices for optimal API usage and integration patterns',
- 'troubleshooting': 'Common issues, error codes, and resolution strategies'
+ "doc_id": f"doc_{time.time()}",
+ "sections": [
+ "Overview",
+ "Parameters",
+ "Returns",
+ "Examples",
+ "Usage Guidelines",
+ "Best Practices",
+ "Troubleshooting",
+ ],
+ "detailed_content": {
+ "overview": "Comprehensive documentation covering all aspects of the API functionality",
+ "parameters": "Detailed parameter descriptions with types, constraints, and examples",
+ "returns": "Complete return value documentation with success and error scenarios",
+ "examples": "Multiple code examples demonstrating various use cases and implementations",
+ "usage_guidelines": "Best practices for optimal API usage and integration patterns",
+ "troubleshooting": "Common issues, error codes, and resolution strategies",
},
- 'documentation_metrics': {
- 'completeness_score': 92,
- 'readability_score': 88,
- 'example_coverage': 95,
- 'accuracy_rating': 97
+ "documentation_metrics": {
+ "completeness_score": 92,
+ "readability_score": 88,
+ "example_coverage": 95,
+ "accuracy_rating": 97,
},
- 'generated_artifacts': ['API reference', 'User guide', 'Code examples', 'FAQ section'],
- 'estimated_length': 2500,
- 'timestamp': time.time(),
- 'cache_padding': 'Comprehensive documentation generation with detailed content structure, quality metrics, and artifact descriptions for complete API documentation coverage.'
+ "generated_artifacts": [
+ "API reference",
+ "User guide",
+ "Code examples",
+ "FAQ section",
+ ],
+ "estimated_length": 2500,
+ "timestamp": time.time(),
+ "cache_padding": "Comprehensive documentation generation with detailed content structure, quality metrics, and artifact descriptions for complete API documentation coverage.",
}
-
- elif operation == 'refactoring':
+
+ elif operation == "refactoring":
return {
- 'refactor_id': f"refactor_{time.time()}",
- 'suggestions': [
- {'type': 'extract_method', 'confidence': 0.9, 'description': 'Extract complex logic into separate methods', 'impact': 'high'},
- {'type': 'rename_variable', 'confidence': 0.8, 'description': 'Improve variable naming for clarity', 'impact': 'medium'},
- {'type': 'optimize_loop', 'confidence': 0.7, 'description': 'Optimize loop performance and efficiency', 'impact': 'medium'},
- {'type': 'remove_duplication', 'confidence': 0.85, 'description': 'Eliminate duplicate code patterns', 'impact': 'high'},
- {'type': 'simplify_conditionals', 'confidence': 0.75, 'description': 'Simplify complex conditional statements', 'impact': 'medium'},
- {'type': 'improve_error_handling', 'confidence': 0.9, 'description': 'Enhance error handling mechanisms', 'impact': 'high'}
+ "refactor_id": f"refactor_{time.time()}",
+ "suggestions": [
+ {
+ "type": "extract_method",
+ "confidence": 0.9,
+ "description": "Extract complex logic into separate methods",
+ "impact": "high",
+ },
+ {
+ "type": "rename_variable",
+ "confidence": 0.8,
+ "description": "Improve variable naming for clarity",
+ "impact": "medium",
+ },
+ {
+ "type": "optimize_loop",
+ "confidence": 0.7,
+ "description": "Optimize loop performance and efficiency",
+ "impact": "medium",
+ },
+ {
+ "type": "remove_duplication",
+ "confidence": 0.85,
+ "description": "Eliminate duplicate code patterns",
+ "impact": "high",
+ },
+ {
+ "type": "simplify_conditionals",
+ "confidence": 0.75,
+ "description": "Simplify complex conditional statements",
+ "impact": "medium",
+ },
+ {
+ "type": "improve_error_handling",
+ "confidence": 0.9,
+ "description": "Enhance error handling mechanisms",
+ "impact": "high",
+ },
],
- 'refactoring_metrics': {
- 'complexity_reduction': 25,
- 'maintainability_improvement': 30,
- 'performance_gain': 15,
- 'code_quality_score': 8.7
+ "refactoring_metrics": {
+ "complexity_reduction": 25,
+ "maintainability_improvement": 30,
+ "performance_gain": 15,
+ "code_quality_score": 8.7,
},
- 'design_patterns_recommended': ['Strategy', 'Factory', 'Observer', 'Command'],
- 'code_quality_improvements': {
- 'cyclomatic_complexity': 'Reduced from 15 to 8',
- 'code_duplication': 'Eliminated 23% duplicate code',
- 'method_length': 'Average method length reduced by 40%',
- 'class_coupling': 'Reduced coupling between components'
+ "design_patterns_recommended": [
+ "Strategy",
+ "Factory",
+ "Observer",
+ "Command",
+ ],
+ "code_quality_improvements": {
+ "cyclomatic_complexity": "Reduced from 15 to 8",
+ "code_duplication": "Eliminated 23% duplicate code",
+ "method_length": "Average method length reduced by 40%",
+ "class_coupling": "Reduced coupling between components",
},
- 'estimated_improvement': '25% performance boost with 30% maintainability increase',
- 'timestamp': time.time(),
- 'cache_padding': 'Detailed refactoring analysis with comprehensive suggestions, metrics, and quality improvements for optimal code structure and maintainability enhancement.'
+ "estimated_improvement": "25% performance boost with 30% maintainability increase",
+ "timestamp": time.time(),
+ "cache_padding": "Detailed refactoring analysis with comprehensive suggestions, metrics, and quality improvements for optimal code structure and maintainability enhancement.",
}
-
+
else:
- return {'operation': operation, 'timestamp': time.time(), 'status': 'completed'}
-
- async def cached_request(self, operation: str, force_refresh: bool = False, **kwargs) -> Dict[str, Any]:
+ return {
+ "operation": operation,
+ "timestamp": time.time(),
+ "status": "completed",
+ }
+
+ async def cached_request(
+ self, operation: str, force_refresh: bool = False, **kwargs
+ ) -> Dict[str, Any]:
"""Make a cached request to Arcade.dev API."""
cache_key = self._generate_cache_key(operation, **kwargs)
-
+
# Check cache first (unless force refresh)
if not force_refresh:
cached_result = await self._get_from_cache(cache_key)
if cached_result:
# Return the data part, not the metadata
- return cached_result.get('data', cached_result)
-
+ return cached_result.get("data", cached_result)
+
# Make API call
start_time = time.time()
result = await self._simulate_api_call(operation, **kwargs)
api_duration = time.time() - start_time
-
+
# Add API call metadata
- result['_api_duration'] = api_duration
- result['_cached'] = False
-
+ result["_api_duration"] = api_duration
+ result["_cached"] = False
+
# Cache the result
ttl = self._get_cache_ttl(operation)
await self._set_cache(cache_key, result, ttl)
-
+
self.logger.info(f"API call completed: {operation} ({api_duration:.2f}s)")
return result
-
+
async def health_check(self, force_refresh: bool = False) -> Dict[str, Any]:
"""Get API health status with caching."""
- return await self.cached_request('health', force_refresh=force_refresh)
-
+ return await self.cached_request("health", force_refresh=force_refresh)
+
async def get_user_info(self, force_refresh: bool = False) -> Dict[str, Any]:
"""Get user information with caching."""
- return await self.cached_request('user_info', force_refresh=force_refresh)
-
- async def analyze_code(self, code: str, language: str = 'python', force_refresh: bool = False) -> Dict[str, Any]:
+ return await self.cached_request("user_info", force_refresh=force_refresh)
+
+ async def analyze_code(
+ self, code: str, language: str = "python", force_refresh: bool = False
+ ) -> Dict[str, Any]:
"""Analyze code with caching."""
return await self.cached_request(
- 'code_analysis',
- force_refresh=force_refresh,
- code=code,
- language=language
+ "code_analysis", force_refresh=force_refresh, code=code, language=language
)
-
- async def generate_tests(self, code: str, test_type: str = 'unit', force_refresh: bool = False) -> Dict[str, Any]:
+
+ async def generate_tests(
+ self, code: str, test_type: str = "unit", force_refresh: bool = False
+ ) -> Dict[str, Any]:
"""Generate tests with caching."""
return await self.cached_request(
- 'test_generation',
+ "test_generation",
force_refresh=force_refresh,
code=code,
- test_type=test_type
+ test_type=test_type,
)
-
- async def generate_documentation(self, code: str, doc_type: str = 'api', force_refresh: bool = False) -> Dict[str, Any]:
+
+ async def generate_documentation(
+ self, code: str, doc_type: str = "api", force_refresh: bool = False
+ ) -> Dict[str, Any]:
"""Generate documentation with caching."""
return await self.cached_request(
- 'documentation',
- force_refresh=force_refresh,
- code=code,
- doc_type=doc_type
+ "documentation", force_refresh=force_refresh, code=code, doc_type=doc_type
)
-
- async def suggest_refactoring(self, code: str, focus: str = 'performance', force_refresh: bool = False) -> Dict[str, Any]:
+
+ async def suggest_refactoring(
+ self, code: str, focus: str = "performance", force_refresh: bool = False
+ ) -> Dict[str, Any]:
"""Get refactoring suggestions with caching."""
return await self.cached_request(
- 'refactoring',
- force_refresh=force_refresh,
- code=code,
- focus=focus
+ "refactoring", force_refresh=force_refresh, code=code, focus=focus
)
-
+
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache performance statistics."""
- total_requests = self.stats['cache_hits'] + self.stats['cache_misses']
- hit_rate = (self.stats['cache_hits'] / total_requests * 100) if total_requests > 0 else 0
-
+ total_requests = self.stats["cache_hits"] + self.stats["cache_misses"]
+ hit_rate = (
+ (self.stats["cache_hits"] / total_requests * 100)
+ if total_requests > 0
+ else 0
+ )
+
return {
- 'cache_hits': self.stats['cache_hits'],
- 'cache_misses': self.stats['cache_misses'],
- 'api_calls': self.stats['api_calls'],
- 'cache_sets': self.stats['cache_sets'],
- 'cache_errors': self.stats['cache_errors'],
- 'hit_rate_percent': round(hit_rate, 2),
- 'total_requests': total_requests
+ "cache_hits": self.stats["cache_hits"],
+ "cache_misses": self.stats["cache_misses"],
+ "api_calls": self.stats["api_calls"],
+ "cache_sets": self.stats["cache_sets"],
+ "cache_errors": self.stats["cache_errors"],
+ "hit_rate_percent": round(hit_rate, 2),
+ "total_requests": total_requests,
}
-
+
async def clear_cache(self, operation: str = None):
"""Clear cache for specific operation or all cache."""
if operation:
@@ -360,26 +504,26 @@ async def clear_cache(self, operation: str = None):
async def demonstrate_caching():
"""Demonstrate the caching capabilities."""
print("š§ Initializing cached Arcade.dev client...")
-
+
# Get default cache configuration and initialize cache manager
fact_cache_config = get_default_cache_config()
cache_manager = CacheManager(fact_cache_config)
-
+
# Configure caching behavior
cache_config = CacheConfig(
default_ttl=3600,
enabled=True,
ttl_by_operation={
- 'health': 300,
- 'code_analysis': 7200, # Cache code analysis for 2 hours
- }
+ "health": 300,
+ "code_analysis": 7200, # Cache code analysis for 2 hours
+ },
)
-
+
# Create cached client
- api_key = os.getenv('ARCADE_API_KEY', 'demo_key')
+ api_key = os.getenv("ARCADE_API_KEY", "demo_key")
client = CachedArcadeClient(api_key, cache_manager, cache_config)
-
- sample_code = '''
+
+ sample_code = """
def calculate_fibonacci(n):
if n <= 1:
return n
@@ -388,11 +532,11 @@ def calculate_fibonacci(n):
def main():
for i in range(10):
print(f"F({i}) = {calculate_fibonacci(i)}")
-'''
-
+"""
+
print("\nš Performance Comparison: First vs Cached Requests")
print("=" * 60)
-
+
# First request (will hit API)
print("š First code analysis request (API call)...")
start_time = time.time()
@@ -400,7 +544,7 @@ def main():
first_duration = time.time() - start_time
print(f" Duration: {first_duration:.2f}s")
print(f" Suggestions: {len(result1.get('suggestions', []))}")
-
+
# Second request (should hit cache)
print("\nš Second code analysis request (cached)...")
start_time = time.time()
@@ -408,41 +552,43 @@ def main():
second_duration = time.time() - start_time
print(f" Duration: {second_duration:.2f}s")
print(f" Suggestions: {len(result2.get('suggestions', []))}")
-
+
# Performance improvement
if first_duration > 0:
improvement = ((first_duration - second_duration) / first_duration) * 100
print(f"\nā” Performance improvement: {improvement:.1f}% faster")
-
+
# Demonstrate force refresh
print("\nš Force refresh (bypassing cache)...")
start_time = time.time()
result3 = await client.analyze_code(sample_code, force_refresh=True)
third_duration = time.time() - start_time
print(f" Duration: {third_duration:.2f}s")
-
+
# Show cache statistics
print("\nš Cache Statistics:")
stats = client.get_cache_stats()
for key, value in stats.items():
print(f" {key}: {value}")
-
+
# Demonstrate different operations
print("\nšÆ Testing different operations...")
-
+
operations = [
- ('Health Check', client.health_check),
- ('User Info', client.get_user_info),
- ('Test Generation', lambda: client.generate_tests(sample_code)),
- ('Documentation', lambda: client.generate_documentation(sample_code)),
- ('Refactoring', lambda: client.suggest_refactoring(sample_code))
+ ("Health Check", client.health_check),
+ ("User Info", client.get_user_info),
+ ("Test Generation", lambda: client.generate_tests(sample_code)),
+ ("Documentation", lambda: client.generate_documentation(sample_code)),
+ ("Refactoring", lambda: client.suggest_refactoring(sample_code)),
]
-
+
for name, operation in operations:
print(f" {name}...")
result = await operation()
- print(f" ā
Completed (cached for {client._get_cache_ttl(name.lower().replace(' ', '_'))}s)")
-
+ print(
+ f" ā
Completed (cached for {client._get_cache_ttl(name.lower().replace(' ', '_'))}s)"
+ )
+
# Final statistics
print("\nš Final Cache Statistics:")
final_stats = client.get_cache_stats()
@@ -454,17 +600,17 @@ async def main():
"""Main demonstration function."""
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
print("š® Arcade.dev Cache Integration Example")
print("=" * 50)
-
+
try:
await demonstrate_caching()
print("\nš Cache integration example completed successfully!")
return 0
-
+
except Exception as e:
print(f"\nā Error: {e}")
return 1
@@ -472,4 +618,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/08_advanced_tools/advanced_tool_usage.py b/examples/arcade-dev/08_advanced_tools/advanced_tool_usage.py
index a90ff72..1f02432 100644
--- a/examples/arcade-dev/08_advanced_tools/advanced_tool_usage.py
+++ b/examples/arcade-dev/08_advanced_tools/advanced_tool_usage.py
@@ -37,42 +37,46 @@
from src.monitoring.metrics import MetricsCollector
from src.core.errors import ToolExecutionError
+
# Define mock ArcadeClient for demo mode
class MockArcadeClient:
def __init__(self, api_key: str = None, **kwargs):
self.api_key = api_key
self._demo_mode = True
-
+
async def __aenter__(self):
return self
-
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
-
+
async def execute_tool(self, payload: Dict[str, Any]) -> Dict[str, Any]:
# Mock response for demo
- tool_name = payload.get('tool', 'unknown')
+ tool_name = payload.get("tool", "unknown")
return {
- 'result': {
- 'success': True,
- 'data': f'Mock result from {tool_name}',
- 'execution_time': 0.1
+ "result": {
+ "success": True,
+ "data": f"Mock result from {tool_name}",
+ "execution_time": 0.1,
}
}
+
# Try to import real ArcadeClient, fall back to mock
try:
from src.arcade.client import ArcadeClient as RealArcadeClient
+
REAL_ARCADE_AVAILABLE = True
except ImportError:
REAL_ARCADE_AVAILABLE = False
+
# Create a wrapper that tries real client first, falls back to mock
class ArcadeClient:
def __new__(cls, api_key: str = None, **kwargs):
# Check if we should use demo mode
demo_mode = api_key == "demo" or not api_key or not REAL_ARCADE_AVAILABLE
-
+
if demo_mode:
return MockArcadeClient(api_key, **kwargs)
else:
@@ -83,26 +87,29 @@ def __new__(cls, api_key: str = None, **kwargs):
print("šÆ Falling back to demo mode due to missing dependencies")
return MockArcadeClient(api_key, **kwargs)
+
ARCADE_AVAILABLE = REAL_ARCADE_AVAILABLE
+
# Mock classes for demo mode
class MockToolExecutor:
"""Mock tool executor for demo mode."""
-
+
async def execute(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
await asyncio.sleep(0.1) # Simulate processing
return {
- 'success': True,
- 'data': f'Mock execution result for {tool_name}',
- 'execution_time': 0.1
+ "success": True,
+ "data": f"Mock execution result for {tool_name}",
+ "execution_time": 0.1,
}
+
class MockMetricsCollector:
"""Mock metrics collector for demo mode."""
-
+
async def record_gauge(self, name: str, value: float, tags: Dict[str, str] = None):
pass
-
+
async def record_counter(self, name: str, value: int = 1):
pass
@@ -110,6 +117,7 @@ async def record_counter(self, name: str, value: int = 1):
@dataclass
class ToolChainStep:
"""Represents a single step in a tool execution chain."""
+
tool_name: str
params: Dict[str, Any]
condition: Optional[Callable[[Dict[str, Any]], bool]] = None
@@ -121,6 +129,7 @@ class ToolChainStep:
@dataclass
class ToolOrchestrationResult:
"""Result of tool orchestration execution."""
+
success: bool
results: List[Dict[str, Any]] = field(default_factory=list)
errors: List[str] = field(default_factory=list)
@@ -131,55 +140,55 @@ class ToolOrchestrationResult:
class AdvancedToolOrchestrator:
"""Advanced tool orchestration system with chaining and conditional execution."""
-
+
def __init__(self, arcade_client: ArcadeClient, cache_manager: CacheManager):
self.arcade_client = arcade_client
self.cache_manager = cache_manager
-
+
# Create demo-compatible tool executor
try:
self.tool_executor = ToolExecutor(arcade_client=arcade_client)
except Exception:
# Create minimal mock executor for demo
self.tool_executor = MockToolExecutor()
-
+
try:
self.metrics = MetricsCollector()
except Exception:
# Create minimal mock metrics collector
self.metrics = MockMetricsCollector()
-
+
self.logger = logging.getLogger(__name__)
self.thread_pool = ThreadPoolExecutor(max_workers=4)
-
+
# Dynamic tool registry
self.dynamic_tools: Dict[str, Callable] = {}
-
+
async def __aenter__(self):
"""Async context manager entry."""
await self.arcade_client.__aenter__()
return self
-
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.arcade_client.__aexit__(exc_type, exc_val, exc_tb)
self.thread_pool.shutdown(wait=True)
-
+
def register_dynamic_tool(self, name: str, func: Callable):
"""Register a dynamically created tool."""
self.dynamic_tools[name] = func
self.logger.info(f"Registered dynamic tool: {name}")
-
+
async def execute_tool_chain(
- self,
+ self,
steps: List[ToolChainStep],
parallel: bool = False,
- aggregate_results: bool = True
+ aggregate_results: bool = True,
) -> ToolOrchestrationResult:
"""Execute a chain of tools with conditional branching and result aggregation."""
start_time = time.time()
result = ToolOrchestrationResult(success=True)
-
+
try:
if parallel:
# Execute steps in parallel where possible
@@ -187,180 +196,185 @@ async def execute_tool_chain(
else:
# Execute steps sequentially
result = await self._execute_sequential_chain(steps, result)
-
+
if aggregate_results:
- result.metadata['aggregated_data'] = await self._aggregate_results(result.results)
-
+ result.metadata["aggregated_data"] = await self._aggregate_results(
+ result.results
+ )
+
except Exception as e:
result.success = False
result.errors.append(f"Chain execution failed: {str(e)}")
self.logger.error(f"Tool chain execution error: {e}")
-
+
finally:
result.execution_time = time.time() - start_time
await self._record_metrics(result)
-
+
return result
-
+
async def _execute_sequential_chain(
- self,
- steps: List[ToolChainStep],
- result: ToolOrchestrationResult
+ self, steps: List[ToolChainStep], result: ToolOrchestrationResult
) -> ToolOrchestrationResult:
"""Execute tool chain sequentially."""
context = {}
-
+
for i, step in enumerate(steps):
try:
# Check condition if provided
if step.condition and not step.condition(context):
self.logger.info(f"Skipping step {i}: condition not met")
continue
-
+
# Execute tool with retries
step_result = await self._execute_tool_with_retry(step, context)
-
+
# Transform result if transform function provided
if step.transform:
step_result = step.transform(step_result)
-
+
# Update context and results
context.update(step_result)
result.results.append(step_result)
result.steps_completed += 1
-
+
self.logger.info(f"Completed step {i}: {step.tool_name}")
-
+
except Exception as e:
error_msg = f"Step {i} ({step.tool_name}) failed: {str(e)}"
result.errors.append(error_msg)
self.logger.error(error_msg)
-
+
# Decide whether to continue or abort chain
if not self._should_continue_on_error(step, e):
result.success = False
break
-
+
return result
-
+
async def _execute_parallel_chain(
- self,
- steps: List[ToolChainStep],
- result: ToolOrchestrationResult
+ self, steps: List[ToolChainStep], result: ToolOrchestrationResult
) -> ToolOrchestrationResult:
"""Execute independent tool steps in parallel."""
# Group steps by dependencies (simplified: assume all independent for now)
tasks = []
-
+
for step in steps:
task = asyncio.create_task(self._execute_tool_with_retry(step, {}))
tasks.append((step, task))
-
+
# Wait for all tasks to complete
for step, task in tasks:
try:
step_result = await task
if step.transform:
step_result = step.transform(step_result)
-
+
result.results.append(step_result)
result.steps_completed += 1
-
+
except Exception as e:
error_msg = f"Parallel step ({step.tool_name}) failed: {str(e)}"
result.errors.append(error_msg)
self.logger.error(error_msg)
-
+
return result
-
+
async def _execute_tool_with_retry(
- self,
- step: ToolChainStep,
- context: Dict[str, Any]
+ self, step: ToolChainStep, context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute a single tool with retry logic."""
last_error = None
-
+
for attempt in range(step.retry_count):
try:
# Merge step params with context
merged_params = {**step.params, **context}
-
+
# Check if it's a dynamic tool
if step.tool_name in self.dynamic_tools:
- return await self._execute_dynamic_tool(step.tool_name, merged_params)
-
+ return await self._execute_dynamic_tool(
+ step.tool_name, merged_params
+ )
+
# Execute via Arcade.dev or local executor
- if step.tool_name.startswith('arcade:'):
+ if step.tool_name.startswith("arcade:"):
tool_name = step.tool_name[7:] # Remove 'arcade:' prefix
return await self._execute_arcade_tool(tool_name, merged_params)
else:
return await self._execute_local_tool(step.tool_name, merged_params)
-
+
except Exception as e:
last_error = e
if attempt < step.retry_count - 1:
- wait_time = 2 ** attempt
- self.logger.warning(f"Tool {step.tool_name} attempt {attempt + 1} failed, retrying in {wait_time}s")
+ wait_time = 2**attempt
+ self.logger.warning(
+ f"Tool {step.tool_name} attempt {attempt + 1} failed, retrying in {wait_time}s"
+ )
await asyncio.sleep(wait_time)
-
+
raise last_error
-
- async def _execute_arcade_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
+
+ async def _execute_arcade_tool(
+ self, tool_name: str, params: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Execute tool via Arcade.dev API."""
payload = {
"tool": tool_name,
"parameters": params,
- "context": {
- "source": "fact_sdk",
- "timestamp": time.time()
- }
+ "context": {"source": "fact_sdk", "timestamp": time.time()},
}
-
+
response = await self.arcade_client.execute_tool(payload)
- return response.get('result', {})
-
- async def _execute_local_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
+ return response.get("result", {})
+
+ async def _execute_local_tool(
+ self, tool_name: str, params: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Execute tool locally using FACT executor."""
return await self.tool_executor.execute(tool_name, params)
-
- async def _execute_dynamic_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
+
+ async def _execute_dynamic_tool(
+ self, tool_name: str, params: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Execute dynamically registered tool."""
tool_func = self.dynamic_tools[tool_name]
-
+
# Run in thread pool if not async
if asyncio.iscoroutinefunction(tool_func):
return await tool_func(**params)
else:
loop = asyncio.get_event_loop()
- return await loop.run_in_executor(self.thread_pool, lambda: tool_func(**params))
-
+ return await loop.run_in_executor(
+ self.thread_pool, lambda: tool_func(**params)
+ )
+
def _should_continue_on_error(self, step: ToolChainStep, error: Exception) -> bool:
"""Determine if chain should continue after an error."""
# Custom logic based on error type and step configuration
if isinstance(error, ToolExecutionError) and error.is_recoverable:
return True
return False
-
+
async def _aggregate_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Aggregate results from multiple tool executions."""
aggregated = {
"total_results": len(results),
- "success_count": sum(1 for r in results if r.get('success', True)),
+ "success_count": sum(1 for r in results if r.get("success", True)),
"combined_data": {},
"metrics": {
- "execution_times": [r.get('execution_time', 0) for r in results],
- "data_sizes": [len(str(r)) for r in results]
- }
+ "execution_times": [r.get("execution_time", 0) for r in results],
+ "data_sizes": [len(str(r)) for r in results],
+ },
}
-
+
# Merge data from all results
for result in results:
- if 'data' in result:
- aggregated['combined_data'].update(result['data'])
-
+ if "data" in result:
+ aggregated["combined_data"].update(result["data"])
+
return aggregated
-
+
async def _record_metrics(self, result: ToolOrchestrationResult):
"""Record execution metrics."""
# Use the correct MetricsCollector method
@@ -368,34 +382,31 @@ async def _record_metrics(self, result: ToolOrchestrationResult):
tool_name="tool_chain",
success=result.success,
execution_time=result.execution_time * 1000, # Convert to milliseconds
- metadata={"steps_completed": result.steps_completed}
+ metadata={"steps_completed": result.steps_completed},
)
-
+
if result.errors:
# Record error metrics using the correct method signature
self.metrics.record_tool_execution(
tool_name="tool_chain_errors",
success=False,
execution_time=0,
- metadata={"error_count": len(result.errors)}
+ metadata={"error_count": len(result.errors)},
)
-
+
async def create_conditional_branch(
- self,
+ self,
condition_tool: str,
true_branch: List[ToolChainStep],
- false_branch: List[ToolChainStep]
+ false_branch: List[ToolChainStep],
) -> ToolOrchestrationResult:
"""Create conditional execution branches based on tool result."""
# Execute condition tool
- condition_step = ToolChainStep(
- tool_name=condition_tool,
- params={}
- )
-
+ condition_step = ToolChainStep(tool_name=condition_tool, params={})
+
condition_result = await self._execute_tool_with_retry(condition_step, {})
- branch_condition = condition_result.get('condition_met', False)
-
+ branch_condition = condition_result.get("condition_met", False)
+
# Execute appropriate branch
if branch_condition:
self.logger.info("Executing true branch")
@@ -403,24 +414,22 @@ async def create_conditional_branch(
else:
self.logger.info("Executing false branch")
return await self.execute_tool_chain(false_branch)
-
+
def generate_dynamic_tool(
- self,
- name: str,
- template: str,
- parameters: Dict[str, Any]
+ self, name: str, template: str, parameters: Dict[str, Any]
) -> Callable:
"""Generate a dynamic tool from a template."""
+
def dynamic_tool(**kwargs):
# Simple template substitution for demonstration
code = template.format(**parameters, **kwargs)
-
+
# Execute the generated code (in a real implementation, use safer execution)
local_vars = {}
exec(code, {"__builtins__": {}}, local_vars)
-
- return local_vars.get('result', {})
-
+
+ return local_vars.get("result", {})
+
self.register_dynamic_tool(name, dynamic_tool)
return dynamic_tool
@@ -428,87 +437,99 @@ def dynamic_tool(**kwargs):
# Example tool implementations for demonstration
def create_sample_tools(orchestrator: AdvancedToolOrchestrator):
"""Create sample dynamic tools for demonstration."""
-
+
# Data validation tool
- def validate_data(data: Dict[str, Any] = None, schema: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
+ def validate_data(
+ data: Dict[str, Any] = None, schema: Dict[str, Any] = None, **kwargs
+ ) -> Dict[str, Any]:
"""Validate data against schema."""
# Use provided parameters or defaults from kwargs
if data is None:
- data = kwargs.get('data', {'name': 'Demo', 'age': 25})
+ data = kwargs.get("data", {"name": "Demo", "age": 25})
if schema is None:
- schema = kwargs.get('schema', {'name': str, 'age': int})
-
+ schema = kwargs.get("schema", {"name": str, "age": int})
+
errors = []
for field, field_type in schema.items():
if field not in data:
errors.append(f"Missing field: {field}")
elif not isinstance(data[field], field_type):
- errors.append(f"Invalid type for {field}: expected {field_type.__name__}")
-
- return {
- 'valid': len(errors) == 0,
- 'errors': errors,
- 'data': data
- }
-
+ errors.append(
+ f"Invalid type for {field}: expected {field_type.__name__}"
+ )
+
+ return {"valid": len(errors) == 0, "errors": errors, "data": data}
+
# Data transformation tool
- def transform_data(transformations: List[str] = None, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
+ def transform_data(
+ transformations: List[str] = None, data: Dict[str, Any] = None, **kwargs
+ ) -> Dict[str, Any]:
"""Apply transformations to data."""
# Handle context from previous steps
- if data is None and 'data' in kwargs:
- data = kwargs['data']
+ if data is None and "data" in kwargs:
+ data = kwargs["data"]
elif data is None:
- data = kwargs.get('context', {})
-
+ data = kwargs.get("context", {})
+
if transformations is None:
- transformations = kwargs.get('transformations', ['add_timestamp'])
-
+ transformations = kwargs.get("transformations", ["add_timestamp"])
+
result = data.copy() if data else {}
-
+
for transform in transformations:
- if transform == 'uppercase_strings':
+ if transform == "uppercase_strings":
for key, value in result.items():
if isinstance(value, str):
result[key] = value.upper()
- elif transform == 'add_timestamp':
- result['timestamp'] = time.time()
-
- return {'transformed_data': result}
-
+ elif transform == "add_timestamp":
+ result["timestamp"] = time.time()
+
+ return {"transformed_data": result}
+
# Analysis tool
- async def analyze_results(results: List[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]:
+ async def analyze_results(
+ results: List[Dict[str, Any]] = None, **kwargs
+ ) -> Dict[str, Any]:
"""Analyze aggregated results."""
await asyncio.sleep(0.1) # Simulate async processing
-
+
# Handle context from previous steps
if results is None:
- results = kwargs.get('results', [])
-
+ results = kwargs.get("results", [])
+
# If no results provided, create demo analysis
if not results:
- results = [{'execution_time': 0.1, 'success': True}]
-
+ results = [{"execution_time": 0.1, "success": True}]
+
return {
- 'analysis': {
- 'total_items': len(results),
- 'avg_processing_time': sum(r.get('execution_time', 0) for r in results) / len(results) if results else 0,
- 'error_rate': sum(1 for r in results if 'error' in r) / len(results) if results else 0
+ "analysis": {
+ "total_items": len(results),
+ "avg_processing_time": (
+ sum(r.get("execution_time", 0) for r in results) / len(results)
+ if results
+ else 0
+ ),
+ "error_rate": (
+ sum(1 for r in results if "error" in r) / len(results)
+ if results
+ else 0
+ ),
}
}
-
+
# Register tools
- orchestrator.register_dynamic_tool('validate_data', validate_data)
- orchestrator.register_dynamic_tool('transform_data', transform_data)
- orchestrator.register_dynamic_tool('analyze_results', analyze_results)
+ orchestrator.register_dynamic_tool("validate_data", validate_data)
+ orchestrator.register_dynamic_tool("transform_data", transform_data)
+ orchestrator.register_dynamic_tool("analyze_results", analyze_results)
async def demonstrate_tool_chaining():
"""Demonstrate basic tool chaining."""
print("š Demonstrating Tool Chaining")
-
+
# Create mock arcade client and cache manager with proper config
arcade_client = ArcadeClient(api_key=os.getenv("ARCADE_API_KEY", "demo"))
-
+
# Create cache config for demo
cache_config = {
"prefix": "demo_cache",
@@ -516,39 +537,36 @@ async def demonstrate_tool_chaining():
"max_size": "50MB",
"ttl_seconds": 3600,
"hit_target_ms": 48,
- "miss_target_ms": 140
+ "miss_target_ms": 140,
}
cache_manager = CacheManager(cache_config)
-
+
async with AdvancedToolOrchestrator(arcade_client, cache_manager) as orchestrator:
# Create sample tools
create_sample_tools(orchestrator)
-
+
# Define tool chain
steps = [
ToolChainStep(
- tool_name='validate_data',
- params={
- 'data': {'name': 'John', 'age': 30},
- 'schema': {'name': str, 'age': int}
- }
- ),
- ToolChainStep(
- tool_name='transform_data',
+ tool_name="validate_data",
params={
- 'transformations': ['uppercase_strings', 'add_timestamp']
+ "data": {"name": "John", "age": 30},
+ "schema": {"name": str, "age": int},
},
- condition=lambda ctx: ctx.get('valid', False) # Only if validation passed
),
ToolChainStep(
- tool_name='analyze_results',
- params={}
- )
+ tool_name="transform_data",
+ params={"transformations": ["uppercase_strings", "add_timestamp"]},
+ condition=lambda ctx: ctx.get(
+ "valid", False
+ ), # Only if validation passed
+ ),
+ ToolChainStep(tool_name="analyze_results", params={}),
]
-
+
# Execute chain
result = await orchestrator.execute_tool_chain(steps)
-
+
print(f"ā
Chain completed: {result.steps_completed} steps")
print(f"ā±ļø Execution time: {result.execution_time:.2f}s")
if result.errors:
@@ -558,9 +576,9 @@ async def demonstrate_tool_chaining():
async def demonstrate_conditional_branching():
"""Demonstrate conditional tool branching."""
print("\nšæ Demonstrating Conditional Branching")
-
+
arcade_client = ArcadeClient(api_key=os.getenv("ARCADE_API_KEY", "demo"))
-
+
# Create cache config for demo
cache_config = {
"prefix": "demo_cache",
@@ -568,52 +586,45 @@ async def demonstrate_conditional_branching():
"max_size": "50MB",
"ttl_seconds": 3600,
"hit_target_ms": 48,
- "miss_target_ms": 140
+ "miss_target_ms": 140,
}
cache_manager = CacheManager(cache_config)
-
+
async with AdvancedToolOrchestrator(arcade_client, cache_manager) as orchestrator:
create_sample_tools(orchestrator)
-
+
# Create condition tool
def check_condition(threshold: float = 0.5) -> Dict[str, Any]:
import random
+
value = random.random()
- return {
- 'condition_met': value > threshold,
- 'value': value
- }
-
- orchestrator.register_dynamic_tool('check_condition', check_condition)
-
+ return {"condition_met": value > threshold, "value": value}
+
+ orchestrator.register_dynamic_tool("check_condition", check_condition)
+
# Define branches
true_branch = [
ToolChainStep(
- tool_name='transform_data',
+ tool_name="transform_data",
params={
- 'data': {'status': 'success'},
- 'transformations': ['add_timestamp']
- }
+ "data": {"status": "success"},
+ "transformations": ["add_timestamp"],
+ },
)
]
-
+
false_branch = [
ToolChainStep(
- tool_name='validate_data',
- params={
- 'data': {'status': 'failure'},
- 'schema': {'status': str}
- }
+ tool_name="validate_data",
+ params={"data": {"status": "failure"}, "schema": {"status": str}},
)
]
-
+
# Execute conditional branch
result = await orchestrator.create_conditional_branch(
- 'check_condition',
- true_branch,
- false_branch
+ "check_condition", true_branch, false_branch
)
-
+
print(f"ā
Branch completed: {result.success}")
print(f"š Results: {len(result.results)} items")
@@ -621,9 +632,9 @@ def check_condition(threshold: float = 0.5) -> Dict[str, Any]:
async def demonstrate_dynamic_tool_generation():
"""Demonstrate dynamic tool generation."""
print("\nš ļø Demonstrating Dynamic Tool Generation")
-
+
arcade_client = ArcadeClient(api_key=os.getenv("ARCADE_API_KEY", "demo"))
-
+
# Create cache config for demo
cache_config = {
"prefix": "demo_cache",
@@ -631,10 +642,10 @@ async def demonstrate_dynamic_tool_generation():
"max_size": "50MB",
"ttl_seconds": 3600,
"hit_target_ms": 48,
- "miss_target_ms": 140
+ "miss_target_ms": 140,
}
cache_manager = CacheManager(cache_config)
-
+
async with AdvancedToolOrchestrator(arcade_client, cache_manager) as orchestrator:
# Generate dynamic tool from template
template = """
@@ -649,24 +660,17 @@ def calculate_{operation}(a, b):
result = calculate_{operation}({a}, {b})
"""
-
+
# Generate calculator tool
orchestrator.generate_dynamic_tool(
- 'calculator',
- template,
- {'operation': 'add', 'a': 10, 'b': 20}
+ "calculator", template, {"operation": "add", "a": 10, "b": 20}
)
-
+
# Execute generated tool
- steps = [
- ToolChainStep(
- tool_name='calculator',
- params={}
- )
- ]
-
+ steps = [ToolChainStep(tool_name="calculator", params={})]
+
result = await orchestrator.execute_tool_chain(steps)
-
+
print(f"ā
Dynamic tool executed: {result.success}")
if result.results:
print(f"š¢ Calculation result: {result.results[0]}")
@@ -676,16 +680,16 @@ async def main():
"""Main demonstration function."""
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
# Check for demo mode
api_key = os.getenv("ARCADE_API_KEY", "demo")
demo_mode = api_key == "demo" or not api_key or not REAL_ARCADE_AVAILABLE
-
+
print("š Advanced Tool Usage Example")
print("=" * 50)
-
+
if demo_mode:
print("šÆ Running in DEMO MODE")
print(" - Using mock Arcade client")
@@ -696,25 +700,28 @@ async def main():
print("š Connected to Arcade.dev API")
print(f" - API Key: {api_key[:8]}...")
print()
-
+
try:
# Run demonstrations
await demonstrate_tool_chaining()
await demonstrate_conditional_branching()
await demonstrate_dynamic_tool_generation()
-
+
print("\nš All advanced tool usage examples completed successfully!")
if demo_mode:
- print("š” To run with real API integration, set the ARCADE_API_KEY environment variable")
+ print(
+ "š” To run with real API integration, set the ARCADE_API_KEY environment variable"
+ )
return 0
-
+
except Exception as e:
print(f"ā Error: {e}")
import traceback
+
traceback.print_exc()
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/09_testing/arcade_integration_tests.py b/examples/arcade-dev/09_testing/arcade_integration_tests.py
index c2b54f7..09d937c 100644
--- a/examples/arcade-dev/09_testing/arcade_integration_tests.py
+++ b/examples/arcade-dev/09_testing/arcade_integration_tests.py
@@ -39,35 +39,38 @@
from src.monitoring.metrics import MetricsCollector
from src.core.errors import ToolExecutionError, ConfigurationError
+
# Import arcade integration classes (create simple mocks for testing)
class ArcadeClient:
"""Mock ArcadeClient for testing."""
-
- def __init__(self, api_key: str, cache_manager=None, max_retries: int = 3, timeout: int = 30):
+
+ def __init__(
+ self, api_key: str, cache_manager=None, max_retries: int = 3, timeout: int = 30
+ ):
self.api_key = api_key
self.cache_manager = cache_manager
self.max_retries = max_retries
self.timeout = timeout
self.session = None
-
+
async def connect(self):
"""Mock connect method."""
pass
-
+
async def disconnect(self):
"""Mock disconnect method."""
pass
-
+
async def health_check(self):
"""Mock health check that uses cache."""
cache_key = "health_check"
-
+
# Check cache first if available
if self.cache_manager:
cached = self.cache_manager.get(cache_key)
if cached:
return json.loads(cached.content)
-
+
# Make "API" request - create a larger response to meet token minimum
result = {
"status": "healthy",
@@ -86,27 +89,27 @@ async def health_check(self):
"authorization": "running",
"cache": "running",
"metrics": "running",
- "logging": "running"
+ "logging": "running",
},
"performance_metrics": {
"avg_response_time": "125ms",
"p95_response_time": "350ms",
"p99_response_time": "750ms",
- "error_rate": "0.02%"
- }
- }
+ "error_rate": "0.02%",
+ },
+ },
}
-
+
# Store in cache if available
if self.cache_manager:
self.cache_manager.store(cache_key, json.dumps(result))
-
+
return result
-
+
async def analyze_code(self, code: str, language: str):
"""Mock analyze code."""
return {"result": f"analyzed_{code}_{language}"}
-
+
async def make_request(self, method: str, endpoint: str, **kwargs):
"""Mock request method."""
return {"status": "success", "method": method, "endpoint": endpoint}
@@ -114,31 +117,31 @@ async def make_request(self, method: str, endpoint: str, **kwargs):
class ArcadeGateway:
"""Mock ArcadeGateway for testing."""
-
+
def __init__(self, api_key: str, enable_failover: bool = False):
self.api_key = api_key
self.enable_failover = enable_failover
-
+
async def initialize(self):
"""Mock initialize method."""
pass
-
+
async def cleanup(self):
"""Mock cleanup method."""
pass
-
+
async def execute_local_tool(self, tool_name: str, params: Dict[str, Any]):
"""Mock local tool execution."""
return {"tool": tool_name, "params": params, "result": "local_result"}
-
+
async def execute_arcade_tool(self, tool_name: str, params: Dict[str, Any]):
"""Mock arcade tool execution."""
return {"tool": tool_name, "params": params, "result": "arcade_result"}
-
+
async def execute_hybrid_workflow(self, workflow: List[Dict[str, Any]]):
"""Mock hybrid workflow execution."""
return {"workflow": workflow, "result": "hybrid_result"}
-
+
async def analyze_with_failover(self, code: str):
"""Mock analyze with failover."""
return {"result": "fallback_analysis", "code": code}
@@ -147,6 +150,7 @@ async def analyze_with_failover(self, code: str):
@dataclass
class TestResult:
"""Test execution result."""
+
test_name: str
success: bool
execution_time: float
@@ -156,71 +160,65 @@ class TestResult:
class MockArcadeResponse:
"""Mock Arcade.dev API response."""
-
+
def __init__(self, status: int = 200, data: Dict[str, Any] = None):
self.status = status
self.data = data or {}
-
+
async def json(self):
return self.data
-
+
def raise_for_status(self):
if self.status >= 400:
raise aiohttp.ClientResponseError(
- request_info=None,
- history=None,
- status=self.status
+ request_info=None, history=None, status=self.status
)
class MockArcadeSession:
"""Mock aiohttp session for Arcade.dev API."""
-
+
def __init__(self, responses: Dict[str, MockArcadeResponse] = None):
self.responses = responses or {}
self.request_history = []
-
+
async def request(self, method: str, url: str, **kwargs):
"""Mock request method."""
- self.request_history.append({
- 'method': method,
- 'url': url,
- 'kwargs': kwargs
- })
-
+ self.request_history.append({"method": method, "url": url, "kwargs": kwargs})
+
# Return configured response or default success
key = f"{method}:{url}"
if key in self.responses:
response = self.responses[key]
else:
- response = MockArcadeResponse(200, {'status': 'success'})
-
+ response = MockArcadeResponse(200, {"status": "success"})
+
return response
-
+
async def close(self):
"""Mock close method."""
pass
-
+
def __aenter__(self):
return self
-
+
async def __aexit__(self, *args):
await self.close()
class ArcadeIntegrationTestSuite:
"""Comprehensive test suite for Arcade.dev integration."""
-
+
def __init__(self):
self.test_results: List[TestResult] = []
self.mock_session = None
self.cache_manager = None
self.temp_dir = None
-
+
async def setup(self):
"""Set up test environment."""
self.temp_dir = tempfile.mkdtemp()
-
+
# Create cache configuration
cache_config = {
"prefix": "test_cache",
@@ -228,686 +226,757 @@ async def setup(self):
"max_size": "10MB",
"ttl_seconds": 3600,
"hit_target_ms": 50,
- "miss_target_ms": 150
+ "miss_target_ms": 150,
}
-
+
self.cache_manager = CacheManager(cache_config)
# CacheManager doesn't have async initialize method
-
+
async def teardown(self):
"""Clean up test environment."""
if self.cache_manager:
# CacheManager doesn't have async close method
self.cache_manager = None
-
+
if self.temp_dir:
import shutil
+
shutil.rmtree(self.temp_dir, ignore_errors=True)
-
+
async def run_all_tests(self) -> List[TestResult]:
"""Run all integration tests."""
await self.setup()
-
+
try:
# Unit tests
await self.test_tool_registration()
await self.test_cache_integration()
await self.test_error_handling()
-
+
# Integration tests
await self.test_hybrid_execution()
await self.test_concurrent_requests()
await self.test_failover_behavior()
-
+
# Performance tests
await self.test_performance_benchmarks()
await self.test_memory_usage()
-
+
# Mock tests
await self.test_mock_responses()
await self.test_network_failures()
-
+
finally:
await self.teardown()
-
+
return self.test_results
-
+
async def test_tool_registration(self):
"""Test tool registration functionality."""
test_name = "tool_registration"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Test basic tool execution using ToolExecutor
executor = ToolExecutor()
-
+
# Get available tools (this tests the tool registry)
available_tools = executor.get_available_tools()
- details['available_tools_count'] = len(available_tools)
-
+ details["available_tools_count"] = len(available_tools)
+
# Create a mock tool call to test execution path
from src.tools.executor import ToolCall, create_tool_call
-
+
# Test tool call creation
tool_call = create_tool_call(
tool_name="test_tool",
arguments={"param1": "hello", "param2": 42},
- user_id="test_user"
+ user_id="test_user",
)
- details['tool_call_created'] = True
-
+ details["tool_call_created"] = True
+
# Test that we can handle tool calls (even if tool doesn't exist)
try:
result = await executor.execute_tool_call(tool_call)
# Tool likely doesn't exist, so we expect this to fail with ToolNotFoundError
- details['unexpected_success'] = True
+ details["unexpected_success"] = True
except Exception as e:
# Expected - tool doesn't exist in registry
- details['expected_error'] = str(type(e).__name__)
-
+ details["expected_error"] = str(type(e).__name__)
+
# Test rate limiting functionality
- can_execute_before = executor.rate_limiter.can_execute() if executor.rate_limiter else True
- details['rate_limiting_enabled'] = executor.enable_rate_limiting
- details['can_execute'] = can_execute_before
-
- details['validation_tests'] = "passed"
+ can_execute_before = (
+ executor.rate_limiter.can_execute() if executor.rate_limiter else True
+ )
+ details["rate_limiting_enabled"] = executor.enable_rate_limiting
+ details["can_execute"] = can_execute_before
+
+ details["validation_tests"] = "passed"
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_cache_integration(self):
"""Test cache integration with Arcade.dev client."""
test_name = "cache_integration"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Test cache operations
cache_key = "test_key_hash"
# Create content with sufficient tokens (minimum 100 required)
- test_data_str = json.dumps({
- "message": "cached_data_for_testing",
- "timestamp": time.time(),
- "description": "This is a test cache entry that contains enough content to meet the minimum token requirement for the cache manager. " * 10,
- "additional_data": ["item1", "item2", "item3"] * 20
- })
-
+ test_data_str = json.dumps(
+ {
+ "message": "cached_data_for_testing",
+ "timestamp": time.time(),
+ "description": "This is a test cache entry that contains enough content to meet the minimum token requirement for the cache manager. "
+ * 10,
+ "additional_data": ["item1", "item2", "item3"] * 20,
+ }
+ )
+
# Test cache store
entry = self.cache_manager.store(cache_key, test_data_str)
assert entry is not None
- details['cache_set'] = True
-
+ details["cache_set"] = True
+
# Test cache get
cached_entry = self.cache_manager.get(cache_key)
assert cached_entry is not None
assert cached_entry.content == test_data_str
- details['cache_get'] = True
-
+ details["cache_get"] = True
+
# Test cache invalidation (cache manager doesn't have delete method, skip for now)
- details['cache_invalidation'] = "skipped"
-
+ details["cache_invalidation"] = "skipped"
+
# Test cache with Arcade client
mock_responses = {
"GET:/v1/health": MockArcadeResponse(200, {"status": "healthy"})
}
self.mock_session = MockArcadeSession(mock_responses)
-
- with patch('aiohttp.ClientSession', return_value=self.mock_session):
- client = ArcadeClient(api_key="test_key", cache_manager=self.cache_manager)
+
+ with patch("aiohttp.ClientSession", return_value=self.mock_session):
+ client = ArcadeClient(
+ api_key="test_key", cache_manager=self.cache_manager
+ )
await client.connect()
-
+
# First request should hit API
result1 = await client.health_check()
- details['first_request'] = result1
-
+ details["first_request"] = result1
+
# Second request should hit cache
result2 = await client.health_check()
- details['cached_request'] = result2
-
+ details["cached_request"] = result2
+
# Verify both results are the same
assert result1 == result2
-
+
await client.disconnect()
-
+
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_hybrid_execution(self):
"""Test hybrid execution between local and Arcade.dev tools."""
test_name = "hybrid_execution"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Set up mock Arcade responses
mock_responses = {
- "POST:/v1/tools/execute": MockArcadeResponse(200, {
- "result": {"analysis": "comprehensive", "score": 85}
- })
+ "POST:/v1/tools/execute": MockArcadeResponse(
+ 200, {"result": {"analysis": "comprehensive", "score": 85}}
+ )
}
self.mock_session = MockArcadeSession(mock_responses)
-
- with patch('aiohttp.ClientSession', return_value=self.mock_session):
+
+ with patch("aiohttp.ClientSession", return_value=self.mock_session):
# Create gateway for hybrid execution
gateway = ArcadeGateway(api_key="test_key")
await gateway.initialize()
-
+
# Test local tool execution
- local_result = await gateway.execute_local_tool("local_processor", {
- "data": "test_input"
- })
- details['local_execution'] = local_result
-
+ local_result = await gateway.execute_local_tool(
+ "local_processor", {"data": "test_input"}
+ )
+ details["local_execution"] = local_result
+
# Test Arcade tool execution
- arcade_result = await gateway.execute_arcade_tool("code_analyzer", {
- "code": "def hello(): return 'world'"
- })
- details['arcade_execution'] = arcade_result
-
+ arcade_result = await gateway.execute_arcade_tool(
+ "code_analyzer", {"code": "def hello(): return 'world'"}
+ )
+ details["arcade_execution"] = arcade_result
+
# Test hybrid workflow
- hybrid_result = await gateway.execute_hybrid_workflow([
- {"type": "local", "tool": "preprocessor", "params": {"raw_data": "input"}},
- {"type": "arcade", "tool": "analyzer", "params": {"processed_data": "{{previous_result}}"}},
- {"type": "local", "tool": "postprocessor", "params": {"analysis": "{{previous_result}}"}}
- ])
- details['hybrid_workflow'] = hybrid_result
-
+ hybrid_result = await gateway.execute_hybrid_workflow(
+ [
+ {
+ "type": "local",
+ "tool": "preprocessor",
+ "params": {"raw_data": "input"},
+ },
+ {
+ "type": "arcade",
+ "tool": "analyzer",
+ "params": {"processed_data": "{{previous_result}}"},
+ },
+ {
+ "type": "local",
+ "tool": "postprocessor",
+ "params": {"analysis": "{{previous_result}}"},
+ },
+ ]
+ )
+ details["hybrid_workflow"] = hybrid_result
+
await gateway.cleanup()
-
+
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_concurrent_requests(self):
"""Test concurrent request handling."""
test_name = "concurrent_requests"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Set up mock responses with delays
mock_responses = {
"GET:/v1/health": MockArcadeResponse(200, {"status": "healthy"}),
- "POST:/v1/analyze": MockArcadeResponse(200, {"result": "analysis_complete"})
+ "POST:/v1/analyze": MockArcadeResponse(
+ 200, {"result": "analysis_complete"}
+ ),
}
self.mock_session = MockArcadeSession(mock_responses)
-
- with patch('aiohttp.ClientSession', return_value=self.mock_session):
+
+ with patch("aiohttp.ClientSession", return_value=self.mock_session):
client = ArcadeClient(api_key="test_key")
await client.connect()
-
+
# Create multiple concurrent tasks
tasks = []
for i in range(10):
if i % 2 == 0:
task = asyncio.create_task(client.health_check())
else:
- task = asyncio.create_task(client.analyze_code(f"code_{i}", "python"))
+ task = asyncio.create_task(
+ client.analyze_code(f"code_{i}", "python")
+ )
tasks.append(task)
-
+
# Wait for all tasks to complete
results = await asyncio.gather(*tasks, return_exceptions=True)
-
+
# Analyze results
- successful_requests = sum(1 for r in results if not isinstance(r, Exception))
+ successful_requests = sum(
+ 1 for r in results if not isinstance(r, Exception)
+ )
failed_requests = sum(1 for r in results if isinstance(r, Exception))
-
- details['total_requests'] = len(tasks)
- details['successful_requests'] = successful_requests
- details['failed_requests'] = failed_requests
- details['success_rate'] = successful_requests / len(tasks)
-
+
+ details["total_requests"] = len(tasks)
+ details["successful_requests"] = successful_requests
+ details["failed_requests"] = failed_requests
+ details["success_rate"] = successful_requests / len(tasks)
+
await client.disconnect()
-
+
# Verify reasonable success rate
- assert details['success_rate'] >= 0.8 # At least 80% success
+ assert details["success_rate"] >= 0.8 # At least 80% success
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_performance_benchmarks(self):
"""Test performance benchmarks."""
test_name = "performance_benchmarks"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Benchmark cache operations
cache_times = []
for i in range(100):
cache_start = time.time()
# Create content with sufficient tokens
- test_data = json.dumps({
- "data": f"value_{i}",
- "description": "Performance test data that contains enough content to meet the minimum token requirement. " * 10,
- "iteration": i,
- "additional_fields": ["field1", "field2", "field3"] * 10
- })
+ test_data = json.dumps(
+ {
+ "data": f"value_{i}",
+ "description": "Performance test data that contains enough content to meet the minimum token requirement. "
+ * 10,
+ "iteration": i,
+ "additional_fields": ["field1", "field2", "field3"] * 10,
+ }
+ )
self.cache_manager.store(f"perf_key_{i}", test_data)
result = self.cache_manager.get(f"perf_key_{i}")
cache_times.append(time.time() - cache_start)
-
- details['cache_avg_time'] = sum(cache_times) / len(cache_times)
- details['cache_max_time'] = max(cache_times)
- details['cache_min_time'] = min(cache_times)
-
+
+ details["cache_avg_time"] = sum(cache_times) / len(cache_times)
+ details["cache_max_time"] = max(cache_times)
+ details["cache_min_time"] = min(cache_times)
+
# Benchmark tool execution using ToolExecutor
executor = ToolExecutor()
-
+
# Create mock tool calls for performance testing
tool_times = []
for i in range(50):
tool_start = time.time()
-
+
# Create a tool call that should fail (no tool registered)
tool_call = create_tool_call(
tool_name="benchmark_tool",
arguments={"size": 1000},
- user_id="test_user"
+ user_id="test_user",
)
-
+
# Execute and expect failure (measure execution time)
try:
await executor.execute_tool_call(tool_call)
except Exception:
pass # Expected - no tool registered
-
+
tool_times.append(time.time() - tool_start)
-
- details['tool_avg_time'] = sum(tool_times) / len(tool_times)
- details['tool_max_time'] = max(tool_times)
- details['tool_min_time'] = min(tool_times)
-
+
+ details["tool_avg_time"] = sum(tool_times) / len(tool_times)
+ details["tool_max_time"] = max(tool_times)
+ details["tool_min_time"] = min(tool_times)
+
# Performance assertions
- assert details['cache_avg_time'] < 0.01 # Cache ops should be fast
- assert details['tool_avg_time'] < 0.1 # Tool execution reasonable (even for failures)
-
+ assert details["cache_avg_time"] < 0.01 # Cache ops should be fast
+ assert (
+ details["tool_avg_time"] < 0.1
+ ) # Tool execution reasonable (even for failures)
+
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_error_handling(self):
"""Test error handling scenarios."""
test_name = "error_handling"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Test tool execution errors (skip network errors since mock isn't working as expected)
executor = ToolExecutor()
-
+
# Test invalid tool handling
try:
invalid_tool_call = create_tool_call(
- tool_name="nonexistent_tool",
- arguments={},
- user_id="test_user"
+ tool_name="nonexistent_tool", arguments={}, user_id="test_user"
)
result = await executor.execute_tool_call(invalid_tool_call)
assert False, "Should have raised exception"
except Exception as e:
- details['invalid_tool_handling'] = "passed"
- details['invalid_tool_error'] = str(type(e).__name__)
-
+ details["invalid_tool_handling"] = "passed"
+ details["invalid_tool_error"] = str(type(e).__name__)
+
# Test rate limiting (if enabled)
if executor.enable_rate_limiting:
# Record many calls to test rate limiting
for _ in range(65): # Exceed default limit of 60
executor.rate_limiter.record_call("test_user")
-
+
can_execute_after = executor.rate_limiter.can_execute("test_user")
- details['rate_limiting_works'] = not can_execute_after
-
- details['tool_error_handling'] = "passed"
- details['network_error_handling'] = "skipped (mock issues)"
-
+ details["rate_limiting_works"] = not can_execute_after
+
+ details["tool_error_handling"] = "passed"
+ details["network_error_handling"] = "skipped (mock issues)"
+
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_mock_responses(self):
"""Test various mock response scenarios."""
test_name = "mock_responses"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Test different response types
scenarios = [
- ("success", MockArcadeResponse(200, {"status": "success", "data": "test"})),
+ (
+ "success",
+ MockArcadeResponse(200, {"status": "success", "data": "test"}),
+ ),
("not_found", MockArcadeResponse(404, {"error": "Not found"})),
("rate_limited", MockArcadeResponse(429, {"error": "Rate limited"})),
- ("server_error", MockArcadeResponse(500, {"error": "Internal error"}))
+ ("server_error", MockArcadeResponse(500, {"error": "Internal error"})),
]
-
+
for scenario_name, mock_response in scenarios:
mock_responses = {"GET:/v1/test": mock_response}
self.mock_session = MockArcadeSession(mock_responses)
-
- with patch('aiohttp.ClientSession', return_value=self.mock_session):
+
+ with patch("aiohttp.ClientSession", return_value=self.mock_session):
client = ArcadeClient(api_key="test_key", max_retries=1)
await client.connect()
-
+
try:
result = await client.make_request("GET", "/test")
- details[f'{scenario_name}_result'] = result
+ details[f"{scenario_name}_result"] = result
except Exception as e:
- details[f'{scenario_name}_error'] = str(type(e).__name__)
-
+ details[f"{scenario_name}_error"] = str(type(e).__name__)
+
await client.disconnect()
-
+
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_network_failures(self):
"""Test network failure scenarios."""
test_name = "network_failures"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Test timeout scenario - simulate by setting very short timeout
try:
# Create a client with extremely short timeout to force timeout
client = ArcadeClient(api_key="test_key", timeout=0.001, max_retries=1)
await client.connect()
-
+
# This should timeout due to the very short timeout
await asyncio.wait_for(client.health_check(), timeout=0.001)
- details['timeout_handling'] = "timeout not triggered as expected"
- except (asyncio.TimeoutError, aiohttp.ClientTimeout, aiohttp.ServerTimeoutError):
- details['timeout_handling'] = "passed"
+ details["timeout_handling"] = "timeout not triggered as expected"
+ except (
+ asyncio.TimeoutError,
+ aiohttp.ClientTimeout,
+ aiohttp.ServerTimeoutError,
+ ):
+ details["timeout_handling"] = "passed"
except Exception as e:
# Accept any network-related error as timeout-like behavior
if "timeout" in str(e).lower() or "time" in str(e).lower():
- details['timeout_handling'] = "passed"
+ details["timeout_handling"] = "passed"
else:
- details['timeout_handling'] = f"unexpected error: {e}"
+ details["timeout_handling"] = f"unexpected error: {e}"
finally:
try:
await client.disconnect()
except:
pass
-
+
# Test connection error with invalid URL
try:
- client = ArcadeClient(api_key="test_key", base_url="http://invalid-host-12345.local", max_retries=1)
+ client = ArcadeClient(
+ api_key="test_key",
+ base_url="http://invalid-host-12345.local",
+ max_retries=1,
+ )
await client.connect()
await client.health_check()
- details['connection_error_handling'] = "connection error not triggered"
+ details["connection_error_handling"] = "connection error not triggered"
except (aiohttp.ClientConnectorError, aiohttp.ClientError, OSError):
- details['connection_error_handling'] = "passed"
+ details["connection_error_handling"] = "passed"
except Exception as e:
# Accept DNS or connection-related errors
- if any(word in str(e).lower() for word in ['connection', 'resolve', 'network', 'host']):
- details['connection_error_handling'] = "passed"
+ if any(
+ word in str(e).lower()
+ for word in ["connection", "resolve", "network", "host"]
+ ):
+ details["connection_error_handling"] = "passed"
else:
- details['connection_error_handling'] = f"unexpected error: {e}"
+ details["connection_error_handling"] = f"unexpected error: {e}"
finally:
try:
await client.disconnect()
except:
pass
-
+
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_failover_behavior(self):
"""Test failover behavior between services."""
test_name = "failover_behavior"
start_time = time.time()
errors = []
details = {}
-
+
try:
# Simulate primary service failure and fallback
primary_responses = {
- "POST:/v1/analyze": MockArcadeResponse(503, {"error": "Service unavailable"})
+ "POST:/v1/analyze": MockArcadeResponse(
+ 503, {"error": "Service unavailable"}
+ )
}
-
+
fallback_responses = {
- "POST:/v1/analyze": MockArcadeResponse(200, {"result": "fallback_analysis"})
+ "POST:/v1/analyze": MockArcadeResponse(
+ 200, {"result": "fallback_analysis"}
+ )
}
-
+
# Test with failover logic
- with patch('aiohttp.ClientSession') as mock_session_class:
+ with patch("aiohttp.ClientSession") as mock_session_class:
# First attempt fails, second succeeds
failed_session = MockArcadeSession(primary_responses)
success_session = MockArcadeSession(fallback_responses)
mock_session_class.side_effect = [failed_session, success_session]
-
+
gateway = ArcadeGateway(api_key="test_key", enable_failover=True)
-
+
# Should succeed via fallback
result = await gateway.analyze_with_failover("test code")
- details['failover_result'] = result
+ details["failover_result"] = result
assert result.get("result") == "fallback_analysis"
-
+
success = True
-
+
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
-
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
+
async def test_memory_usage(self):
"""Test memory usage during operations."""
test_name = "memory_usage"
start_time = time.time()
errors = []
details = {}
-
+
try:
import psutil
+
process = psutil.Process()
-
+
# Baseline memory
baseline_memory = process.memory_info().rss / 1024 / 1024 # MB
-
+
# Create large cache entries
large_data = json.dumps({"data": "x" * 10000})
for i in range(100):
self.cache_manager.store(f"large_key_{i}", large_data)
-
+
# Memory after cache operations
cache_memory = process.memory_info().rss / 1024 / 1024 # MB
-
+
# Clean up cache (skip deletion as manager doesn't have delete method)
# for i in range(100):
# self.cache_manager.delete(f"large_key_{i}")
-
+
# Memory after cleanup
cleanup_memory = process.memory_info().rss / 1024 / 1024 # MB
-
- details['baseline_memory_mb'] = baseline_memory
- details['cache_memory_mb'] = cache_memory
- details['cleanup_memory_mb'] = cleanup_memory
- details['memory_growth_mb'] = cache_memory - baseline_memory
- details['memory_recovered_mb'] = cache_memory - cleanup_memory
-
+
+ details["baseline_memory_mb"] = baseline_memory
+ details["cache_memory_mb"] = cache_memory
+ details["cleanup_memory_mb"] = cleanup_memory
+ details["memory_growth_mb"] = cache_memory - baseline_memory
+ details["memory_recovered_mb"] = cache_memory - cleanup_memory
+
# Memory growth should be reasonable
- assert details['memory_growth_mb'] < 100 # Less than 100MB growth
-
+ assert details["memory_growth_mb"] < 100 # Less than 100MB growth
+
success = True
-
+
except ImportError:
# psutil not available, skip memory test
- details['skipped'] = "psutil not available"
+ details["skipped"] = "psutil not available"
success = True
except Exception as e:
success = False
errors.append(str(e))
-
+
execution_time = time.time() - start_time
- self.test_results.append(TestResult(
- test_name=test_name,
- success=success,
- execution_time=execution_time,
- details=details,
- errors=errors
- ))
+ self.test_results.append(
+ TestResult(
+ test_name=test_name,
+ success=success,
+ execution_time=execution_time,
+ details=details,
+ errors=errors,
+ )
+ )
def create_test_fixtures():
"""Create test fixtures and utilities."""
-
+
class TestFixtures:
"""Collection of test fixtures."""
-
+
@staticmethod
def create_mock_cache_manager():
"""Create a mock cache manager."""
cache = {}
-
+
class MockCacheManager:
async def get(self, key: str):
return cache.get(key)
-
+
async def set(self, key: str, value: Any, ttl: int = None):
cache[key] = value
-
+
async def delete(self, key: str):
cache.pop(key, None)
-
+
async def initialize(self):
pass
-
+
async def close(self):
cache.clear()
-
+
return MockCacheManager()
-
+
@staticmethod
def create_sample_tools():
"""Create sample tools for testing."""
tools = {}
-
- async def sample_analyzer(code: str, language: str = "python") -> Dict[str, Any]:
+
+ async def sample_analyzer(
+ code: str, language: str = "python"
+ ) -> Dict[str, Any]:
return {
- "lines": len(code.split('\n')),
+ "lines": len(code.split("\n")),
"language": language,
- "complexity": "low"
- }
-
- async def sample_formatter(code: str, style: str = "black") -> Dict[str, Any]:
- return {
- "formatted_code": code.strip(),
- "style_applied": style
+ "complexity": "low",
}
-
+
+ async def sample_formatter(
+ code: str, style: str = "black"
+ ) -> Dict[str, Any]:
+ return {"formatted_code": code.strip(), "style_applied": style}
+
tools["analyzer"] = sample_analyzer
tools["formatter"] = sample_formatter
-
+
return tools
-
+
@staticmethod
def create_test_data():
"""Create test data sets."""
@@ -920,9 +989,9 @@ def fibonacci(n):
return fibonacci(n-1) + fibonacci(n-2)
""",
"invalid_code": "def broken( syntax error",
- "large_code": "# " + "x" * 10000
+ "large_code": "# " + "x" * 10000,
}
-
+
return TestFixtures
@@ -930,52 +999,53 @@ async def run_comprehensive_tests():
"""Run comprehensive test suite."""
print("š§Ŗ Running Comprehensive Arcade.dev Integration Tests")
print("=" * 60)
-
+
test_suite = ArcadeIntegrationTestSuite()
results = await test_suite.run_all_tests()
-
+
# Print results summary
total_tests = len(results)
passed_tests = sum(1 for r in results if r.success)
failed_tests = total_tests - passed_tests
-
+
print(f"\nš Test Results Summary:")
print(f" Total Tests: {total_tests}")
print(f" Passed: {passed_tests}")
print(f" Failed: {failed_tests}")
print(f" Success Rate: {(passed_tests/total_tests)*100:.1f}%")
-
+
# Print detailed results
print(f"\nš Detailed Results:")
for result in results:
status = "ā
" if result.success else "ā"
print(f" {status} {result.test_name}: {result.execution_time:.3f}s")
-
+
if result.errors:
for error in result.errors:
print(f" Error: {error}")
-
+
return passed_tests == total_tests
async def main():
"""Main test execution function."""
import logging
+
logging.basicConfig(
level=logging.WARNING, # Reduce noise during tests
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
try:
success = await run_comprehensive_tests()
-
+
if success:
print("\nš All tests passed successfully!")
return 0
else:
print("\nā Some tests failed!")
return 1
-
+
except Exception as e:
print(f"ā Test execution failed: {e}")
return 1
@@ -983,4 +1053,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/10_deployment/production_deployment.py b/examples/arcade-dev/10_deployment/production_deployment.py
index 26c8e53..cd45bd4 100644
--- a/examples/arcade-dev/10_deployment/production_deployment.py
+++ b/examples/arcade-dev/10_deployment/production_deployment.py
@@ -25,6 +25,7 @@
import threading
from concurrent.futures import ThreadPoolExecutor
import subprocess
+
try:
import psutil
except ImportError:
@@ -48,24 +49,31 @@
# Mock classes for demonstration (in production, these would be real implementations)
class ArcadeClient:
"""Mock Arcade.dev API client for deployment example."""
-
- def __init__(self, api_key: str, api_url: str, timeout: int, max_retries: int, cache_manager=None):
+
+ def __init__(
+ self,
+ api_key: str,
+ api_url: str,
+ timeout: int,
+ max_retries: int,
+ cache_manager=None,
+ ):
self.api_key = api_key
self.api_url = api_url
self.timeout = timeout
self.max_retries = max_retries
self.cache_manager = cache_manager
self.connected = False
-
+
async def connect(self):
"""Connect to Arcade.dev API."""
# Mock connection
self.connected = True
-
+
async def disconnect(self):
"""Disconnect from Arcade.dev API."""
self.connected = False
-
+
async def health_check(self):
"""Perform health check."""
if not self.connected:
@@ -75,18 +83,18 @@ async def health_check(self):
class ArcadeGateway:
"""Mock Arcade.dev gateway for deployment example."""
-
+
def __init__(self, arcade_client, cache_manager, metrics, authorization_manager):
self.arcade_client = arcade_client
self.cache_manager = cache_manager
self.metrics = metrics
self.authorization_manager = authorization_manager
self.initialized = False
-
+
async def initialize(self):
"""Initialize the gateway."""
self.initialized = True
-
+
async def cleanup(self):
"""Clean up gateway resources."""
self.initialized = False
@@ -94,12 +102,14 @@ async def cleanup(self):
class ServiceStartupError(Exception):
"""Custom exception for service startup errors."""
+
pass
@dataclass
class ServiceHealthStatus:
"""Health status of a service component."""
+
service_name: str
healthy: bool
last_check: float
@@ -110,38 +120,39 @@ class ServiceHealthStatus:
@dataclass
class DeploymentConfig:
"""Production deployment configuration."""
+
# Service configuration
service_name: str = "fact-arcade-integration"
service_version: str = "1.0.0"
environment: str = "production"
-
+
# Network configuration
host: str = "0.0.0.0"
port: int = 8080
health_check_port: int = 8081
-
+
# Arcade.dev configuration
arcade_api_key: str = ""
arcade_api_url: str = "https://api.arcade.dev"
arcade_timeout: int = 30
arcade_max_retries: int = 3
-
+
# Cache configuration
cache_backend: str = "redis"
cache_host: str = "localhost"
cache_port: int = 6379
cache_db: int = 0
-
+
# Monitoring configuration
metrics_enabled: bool = True
metrics_port: int = 9090
log_level: str = "INFO"
-
+
# Performance configuration
worker_threads: int = 4
max_concurrent_requests: int = 100
request_timeout: int = 60
-
+
# Security configuration
enable_auth: bool = False # Disabled for demo
jwt_secret: str = ""
@@ -150,11 +161,11 @@ class DeploymentConfig:
class ProductionArcadeService:
"""Production-ready Arcade.dev integration service."""
-
+
def __init__(self, config: DeploymentConfig):
self.config = config
self.logger = self._setup_logging()
-
+
# Core components
self.driver: Optional[FACTDriver] = None
self.cache_manager: Optional[CacheManager] = None
@@ -162,127 +173,138 @@ def __init__(self, config: DeploymentConfig):
self.arcade_gateway: Optional[ArcadeGateway] = None
self.metrics: Optional[MetricsCollector] = None
self.authorization_manager: Optional[AuthorizationManager] = None
-
+
# Service state
self.is_running = False
self.is_ready = False
self.startup_time: Optional[float] = None
self.shutdown_event = asyncio.Event()
self.health_status: Dict[str, ServiceHealthStatus] = {}
-
+
# Background tasks
self.background_tasks: List[asyncio.Task] = []
self.thread_pool = ThreadPoolExecutor(max_workers=config.worker_threads)
-
+
# Signal handlers
self._setup_signal_handlers()
-
+
def _setup_logging(self) -> logging.Logger:
"""Set up production logging configuration."""
# Create log handlers
handlers = [logging.StreamHandler()]
-
+
# Try to add file handler, but don't fail if directory doesn't exist
try:
- log_file = f'/var/log/{self.config.service_name}.log'
+ log_file = f"/var/log/{self.config.service_name}.log"
handlers.append(logging.FileHandler(log_file))
except (PermissionError, FileNotFoundError):
# Fallback to local log file
try:
- log_file = f'./{self.config.service_name}.log'
+ log_file = f"./{self.config.service_name}.log"
handlers.append(logging.FileHandler(log_file))
except Exception:
# If all file logging fails, just use console
pass
-
+
logging.basicConfig(
level=getattr(logging, self.config.log_level.upper()),
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d',
- handlers=handlers
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d",
+ handlers=handlers,
)
-
+
logger = logging.getLogger(self.config.service_name)
- logger.info(f"Initializing {self.config.service_name} v{self.config.service_version}")
+ logger.info(
+ f"Initializing {self.config.service_name} v{self.config.service_version}"
+ )
return logger
-
+
def _setup_signal_handlers(self):
"""Set up signal handlers for graceful shutdown."""
+
def signal_handler(signum, frame):
self.logger.info(f"Received signal {signum}, initiating graceful shutdown")
asyncio.create_task(self.shutdown())
-
+
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
-
+
async def initialize(self):
"""Initialize all service components."""
self.logger.info("Starting service initialization")
initialization_start = time.time()
-
+
try:
# Load and validate configuration
await self._load_configuration()
-
+
# Initialize core components
await self._initialize_cache()
await self._initialize_arcade_client()
await self._initialize_monitoring()
await self._initialize_authorization()
await self._initialize_gateway()
-
+
# Start background tasks
await self._start_background_tasks()
-
+
# Perform initial health checks
await self._perform_initial_health_checks()
-
+
self.startup_time = time.time() - initialization_start
self.is_running = True
self.is_ready = True
-
- self.logger.info(f"Service initialization completed in {self.startup_time:.2f}s")
-
+
+ self.logger.info(
+ f"Service initialization completed in {self.startup_time:.2f}s"
+ )
+
except Exception as e:
self.logger.error(f"Service initialization failed: {e}")
await self.cleanup()
raise ServiceStartupError(f"Failed to initialize service: {e}")
-
+
async def _load_configuration(self):
"""Load and validate configuration from environment and files."""
self.logger.info("Loading configuration")
-
+
# Load from environment variables
- self.config.arcade_api_key = os.getenv("ARCADE_API_KEY", self.config.arcade_api_key)
- self.config.arcade_api_url = os.getenv("ARCADE_API_URL", self.config.arcade_api_url)
+ self.config.arcade_api_key = os.getenv(
+ "ARCADE_API_KEY", self.config.arcade_api_key
+ )
+ self.config.arcade_api_url = os.getenv(
+ "ARCADE_API_URL", self.config.arcade_api_url
+ )
self.config.cache_host = os.getenv("CACHE_HOST", self.config.cache_host)
self.config.jwt_secret = os.getenv("JWT_SECRET", self.config.jwt_secret)
-
+
# Load from configuration file if exists
config_file = os.getenv("CONFIG_FILE", "/etc/fact-arcade/config.yaml")
if os.path.exists(config_file):
- with open(config_file, 'r') as f:
+ with open(config_file, "r") as f:
file_config = yaml.safe_load(f)
self._merge_config(file_config)
-
+
# Validate required configuration
if not self.config.arcade_api_key:
raise ConfigurationError("ARCADE_API_KEY is required")
-
+
if self.config.enable_auth and not self.config.jwt_secret:
- raise ConfigurationError("JWT_SECRET is required when authentication is enabled")
-
+ raise ConfigurationError(
+ "JWT_SECRET is required when authentication is enabled"
+ )
+
self.logger.info("Configuration loaded and validated")
-
+
def _merge_config(self, file_config: Dict[str, Any]):
"""Merge file configuration with current config."""
for key, value in file_config.items():
if hasattr(self.config, key):
setattr(self.config, key, value)
-
+
async def _initialize_cache(self):
"""Initialize cache manager."""
self.logger.info("Initializing cache manager")
-
+
cache_config = {
"prefix": "fact_arcade_cache",
"min_tokens": 1, # Reduced for demo purposes
@@ -291,142 +313,141 @@ async def _initialize_cache(self):
"backend": self.config.cache_backend,
"host": self.config.cache_host,
"port": self.config.cache_port,
- "db": self.config.cache_db
+ "db": self.config.cache_db,
}
-
+
self.cache_manager = CacheManager(cache_config)
# CacheManager doesn't require async initialization
-
+
# Test cache connectivity
test_key = "health_check_cache"
test_content = '{"status": "ok"}'
self.cache_manager.store(test_key, test_content)
result = self.cache_manager.get(test_key)
-
+
if not result:
raise ServiceStartupError("Cache connectivity test failed")
-
+
self.cache_manager.invalidate_by_prefix(test_key)
self.logger.info("Cache manager initialized successfully")
-
+
async def _initialize_arcade_client(self):
"""Initialize Arcade.dev client."""
self.logger.info("Initializing Arcade.dev client")
-
+
self.arcade_client = ArcadeClient(
api_key=self.config.arcade_api_key,
api_url=self.config.arcade_api_url,
timeout=self.config.arcade_timeout,
max_retries=self.config.arcade_max_retries,
- cache_manager=self.cache_manager
+ cache_manager=self.cache_manager,
)
-
+
await self.arcade_client.connect()
-
+
# Test API connectivity
try:
health_result = await self.arcade_client.health_check()
- self.logger.info(f"Arcade.dev API health: {health_result.get('status', 'unknown')}")
+ self.logger.info(
+ f"Arcade.dev API health: {health_result.get('status', 'unknown')}"
+ )
except Exception as e:
self.logger.warning(f"Arcade.dev API health check failed: {e}")
-
+
self.logger.info("Arcade.dev client initialized successfully")
-
+
async def _initialize_monitoring(self):
"""Initialize monitoring and metrics collection."""
if not self.config.metrics_enabled:
self.logger.info("Metrics collection disabled")
return
-
+
self.logger.info("Initializing metrics collection")
-
- self.metrics = MetricsCollector(
- max_history=10000
- )
-
+
+ self.metrics = MetricsCollector(max_history=10000)
+
# MetricsCollector doesn't require async initialization
-
+
# Register custom metrics
await self._register_custom_metrics()
-
+
self.logger.info("Metrics collection initialized successfully")
-
+
async def _register_custom_metrics(self):
"""Register custom application metrics."""
if not self.metrics:
return
-
+
# MetricsCollector from monitoring doesn't require metric registration
# It automatically tracks tool execution metrics
self.logger.info("Metrics collection ready")
-
+
async def _initialize_authorization(self):
"""Initialize security manager."""
if not self.config.enable_auth:
self.logger.info("Authorization disabled")
return
-
+
self.logger.info("Initializing authorization manager")
-
+
self.authorization_manager = AuthorizationManager()
-
+
# Authorization manager doesn't need async initialization
self.logger.info("Authorization manager initialized successfully")
-
+
async def _initialize_gateway(self):
"""Initialize Arcade.dev gateway."""
self.logger.info("Initializing Arcade.dev gateway")
-
+
self.arcade_gateway = ArcadeGateway(
arcade_client=self.arcade_client,
cache_manager=self.cache_manager,
metrics=self.metrics,
- authorization_manager=self.authorization_manager
+ authorization_manager=self.authorization_manager,
)
-
+
await self.arcade_gateway.initialize()
self.logger.info("Arcade.dev gateway initialized successfully")
-
+
async def _start_background_tasks(self):
"""Start background maintenance tasks."""
self.logger.info("Starting background tasks")
-
+
# Health check task
health_task = asyncio.create_task(self._health_check_loop())
self.background_tasks.append(health_task)
-
+
# Metrics collection task
if self.metrics:
metrics_task = asyncio.create_task(self._metrics_collection_loop())
self.background_tasks.append(metrics_task)
-
+
# Cache cleanup task
if self.cache_manager:
cleanup_task = asyncio.create_task(self._cache_cleanup_loop())
self.background_tasks.append(cleanup_task)
-
+
self.logger.info(f"Started {len(self.background_tasks)} background tasks")
-
+
async def _perform_initial_health_checks(self):
"""Perform initial health checks on all components."""
self.logger.info("Performing initial health checks")
-
+
# Check all components
await self._check_cache_health()
await self._check_arcade_health()
await self._check_system_health()
-
+
# Verify all components are healthy
unhealthy_services = [
- name for name, status in self.health_status.items()
- if not status.healthy
+ name for name, status in self.health_status.items() if not status.healthy
]
-
+
if unhealthy_services:
raise ServiceStartupError(f"Unhealthy services: {unhealthy_services}")
-
+
self.logger.info("All initial health checks passed")
-
+
async def _health_check_loop(self):
"""Background task for periodic health checks."""
while not self.shutdown_event.is_set():
@@ -434,68 +455,67 @@ async def _health_check_loop(self):
await self._check_cache_health()
await self._check_arcade_health()
await self._check_system_health()
-
+
# Wait for next check
await asyncio.wait_for(
- self.shutdown_event.wait(),
- timeout=30.0 # Check every 30 seconds
+ self.shutdown_event.wait(), timeout=30.0 # Check every 30 seconds
)
-
+
except asyncio.TimeoutError:
continue # Continue health checks
except Exception as e:
self.logger.error(f"Health check loop error: {e}")
await asyncio.sleep(30)
-
+
async def _check_cache_health(self):
"""Check cache health."""
try:
start_time = time.time()
test_key = f"health_check_{int(time.time())}"
-
+
self.cache_manager.store(test_key, "ok")
result = self.cache_manager.get(test_key)
self.cache_manager.invalidate_by_prefix(test_key)
-
+
response_time = time.time() - start_time
-
+
self.health_status["cache"] = ServiceHealthStatus(
service_name="cache",
healthy=result is not None and result.content == "ok",
last_check=time.time(),
- details={"response_time": response_time}
+ details={"response_time": response_time},
)
-
+
except Exception as e:
self.health_status["cache"] = ServiceHealthStatus(
service_name="cache",
healthy=False,
last_check=time.time(),
- error_message=str(e)
+ error_message=str(e),
)
-
+
async def _check_arcade_health(self):
"""Check Arcade.dev API health."""
try:
start_time = time.time()
result = await self.arcade_client.health_check()
response_time = time.time() - start_time
-
+
self.health_status["arcade"] = ServiceHealthStatus(
service_name="arcade",
healthy=result.get("status") == "healthy",
last_check=time.time(),
- details={"response_time": response_time, "api_status": result}
+ details={"response_time": response_time, "api_status": result},
)
-
+
except Exception as e:
self.health_status["arcade"] = ServiceHealthStatus(
service_name="arcade",
healthy=False,
last_check=time.time(),
- error_message=str(e)
+ error_message=str(e),
)
-
+
async def _check_system_health(self):
"""Check system resource health."""
try:
@@ -509,11 +529,11 @@ async def _check_system_health(self):
process = psutil.Process()
memory_info = process.memory_info()
cpu_percent = process.cpu_percent()
-
+
# Check if we're using too much memory (> 1GB as example)
memory_mb = memory_info.rss / 1024 / 1024
healthy = memory_mb < 1024 and cpu_percent < 80
-
+
self.health_status["system"] = ServiceHealthStatus(
service_name="system",
healthy=healthy,
@@ -521,18 +541,20 @@ async def _check_system_health(self):
details={
"memory_mb": memory_mb,
"cpu_percent": cpu_percent,
- "uptime": time.time() - self.startup_time if self.startup_time else 0
- }
+ "uptime": (
+ time.time() - self.startup_time if self.startup_time else 0
+ ),
+ },
)
-
+
except Exception as e:
self.health_status["system"] = ServiceHealthStatus(
service_name="system",
healthy=False,
last_check=time.time(),
- error_message=str(e)
+ error_message=str(e),
)
-
+
async def _metrics_collection_loop(self):
"""Background task for metrics collection."""
while not self.shutdown_event.is_set():
@@ -542,19 +564,18 @@ async def _metrics_collection_loop(self):
# MetricsCollector from monitoring doesn't have record_gauge method
# It automatically tracks metrics during tool execution
self.logger.debug(f"Service uptime: {uptime:.2f} seconds")
-
+
# Wait for next collection
await asyncio.wait_for(
- self.shutdown_event.wait(),
- timeout=60.0 # Collect every minute
+ self.shutdown_event.wait(), timeout=60.0 # Collect every minute
)
-
+
except asyncio.TimeoutError:
continue
except Exception as e:
self.logger.error(f"Metrics collection error: {e}")
await asyncio.sleep(60)
-
+
async def _cache_cleanup_loop(self):
"""Background task for cache cleanup."""
while not self.shutdown_event.is_set():
@@ -562,25 +583,22 @@ async def _cache_cleanup_loop(self):
if self.cache_manager:
# Use private method for cleanup (sync, not async)
self.cache_manager._cleanup_expired()
-
+
# Wait for next cleanup
await asyncio.wait_for(
- self.shutdown_event.wait(),
- timeout=300.0 # Cleanup every 5 minutes
+ self.shutdown_event.wait(), timeout=300.0 # Cleanup every 5 minutes
)
-
+
except asyncio.TimeoutError:
continue
except Exception as e:
self.logger.error(f"Cache cleanup error: {e}")
await asyncio.sleep(300)
-
+
async def get_health_status(self) -> Dict[str, Any]:
"""Get current health status of all components."""
- overall_healthy = all(
- status.healthy for status in self.health_status.values()
- )
-
+ overall_healthy = all(status.healthy for status in self.health_status.values())
+
return {
"service": self.config.service_name,
"version": self.config.service_version,
@@ -592,78 +610,78 @@ async def get_health_status(self) -> Dict[str, Any]:
"healthy": status.healthy,
"last_check": status.last_check,
"details": status.details,
- "error": status.error_message
+ "error": status.error_message,
}
for name, status in self.health_status.items()
- }
+ },
}
-
+
async def get_readiness_status(self) -> Dict[str, Any]:
"""Get readiness status for load balancer probes."""
return {
"ready": self.is_ready,
"startup_time": self.startup_time,
- "components_ready": len(self.health_status) > 0
+ "components_ready": len(self.health_status) > 0,
}
-
+
async def shutdown(self):
"""Gracefully shutdown the service."""
if not self.is_running:
return
-
+
self.logger.info("Starting graceful shutdown")
shutdown_start = time.time()
-
+
# Set shutdown event
self.shutdown_event.set()
self.is_ready = False
-
+
try:
# Cancel background tasks
self.logger.info("Cancelling background tasks")
for task in self.background_tasks:
task.cancel()
-
+
# Wait for tasks to complete
if self.background_tasks:
await asyncio.gather(*self.background_tasks, return_exceptions=True)
-
+
# Cleanup components
await self.cleanup()
-
+
shutdown_time = time.time() - shutdown_start
self.logger.info(f"Graceful shutdown completed in {shutdown_time:.2f}s")
-
+
except Exception as e:
self.logger.error(f"Error during shutdown: {e}")
finally:
self.is_running = False
-
+
async def cleanup(self):
"""Clean up all resources."""
self.logger.info("Cleaning up resources")
-
+
# Close gateway
if self.arcade_gateway:
await self.arcade_gateway.cleanup()
-
+
# Close arcade client
if self.arcade_client:
await self.arcade_client.disconnect()
-
+
# Close cache manager
if self.cache_manager:
# CacheManager doesn't have a close method - cleanup is automatic
pass
-
+
# Close metrics collector
if self.metrics:
# MetricsCollector from monitoring doesn't require cleanup
pass
-
+
# Shutdown thread pool
self.thread_pool.shutdown(wait=True)
-
+
self.logger.info("Resource cleanup completed")
@@ -675,29 +693,30 @@ async def create_health_check_server(service: ProductionArcadeService):
# Fallback for demonstration if aiohttp is not installed
web = None
web_response = None
-
+
if web is None or web_response is None:
# Return a mock app if aiohttp is not available
class MockApp:
pass
+
return MockApp()
-
+
async def health_handler(request):
"""Health check endpoint."""
status = await service.get_health_status()
status_code = 200 if status["status"] == "healthy" else 503
return web_response.json_response(status, status=status_code)
-
+
async def readiness_handler(request):
"""Readiness check endpoint."""
status = await service.get_readiness_status()
status_code = 200 if status["ready"] else 503
return web_response.json_response(status, status=status_code)
-
+
app = web.Application()
- app.router.add_get('/health', health_handler)
- app.router.add_get('/ready', readiness_handler)
-
+ app.router.add_get("/health", health_handler)
+ app.router.add_get("/ready", readiness_handler)
+
return app
@@ -705,54 +724,54 @@ async def main():
"""Main service entry point."""
# Load configuration
config = DeploymentConfig()
-
+
# Override from environment
config.environment = os.getenv("ENVIRONMENT", config.environment)
config.log_level = os.getenv("LOG_LEVEL", config.log_level)
config.port = int(os.getenv("PORT", str(config.port)))
-
+
# Create and start service
service = ProductionArcadeService(config)
-
+
try:
# Initialize service
await service.initialize()
-
+
# Create health check server
health_app = await create_health_check_server(service)
-
+
# Start health check server
try:
from aiohttp import web
except ImportError:
web = None
- if web is not None and hasattr(health_app, 'router'):
+ if web is not None and hasattr(health_app, "router"):
health_runner = web.AppRunner(health_app)
await health_runner.setup()
health_site = web.TCPSite(
- health_runner,
- service.config.host,
- service.config.health_check_port
+ health_runner, service.config.host, service.config.health_check_port
)
await health_site.start()
-
- service.logger.info(f"Health check server started on port {service.config.health_check_port}")
+
+ service.logger.info(
+ f"Health check server started on port {service.config.health_check_port}"
+ )
else:
service.logger.info("Health check server skipped (aiohttp not available)")
health_runner = None
-
+
service.logger.info(f"Service ready and accepting requests")
-
+
# Wait for shutdown signal
await service.shutdown_event.wait()
-
+
# Cleanup health server
if health_runner:
await health_runner.cleanup()
-
+
service.logger.info("Service shutdown completed")
return 0
-
+
except Exception as e:
if service:
service.logger.error(f"Service failed: {e}")
@@ -763,4 +782,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/12_monitoring/arcade_monitoring.py b/examples/arcade-dev/12_monitoring/arcade_monitoring.py
index 292804b..1089059 100644
--- a/examples/arcade-dev/12_monitoring/arcade_monitoring.py
+++ b/examples/arcade-dev/12_monitoring/arcade_monitoring.py
@@ -38,26 +38,27 @@
@dataclass
class MonitoringConfig:
"""Configuration for monitoring and observability."""
+
# Performance monitoring
enable_performance_tracking: bool = True
enable_health_checks: bool = True
enable_alerting: bool = True
enable_telemetry: bool = True
-
+
# Health check intervals
health_check_interval_seconds: int = 30
performance_check_interval_seconds: int = 60
-
+
# Alert thresholds
error_rate_threshold: float = 0.05 # 5%
response_time_threshold_ms: float = 5000 # 5 seconds
cpu_usage_threshold: float = 80.0 # 80%
memory_usage_threshold: float = 85.0 # 85%
-
+
# Metrics retention
metrics_retention_hours: int = 24
detailed_metrics_retention_minutes: int = 60
-
+
# Logging
metrics_log_file: str = "arcade_metrics.log"
alerts_log_file: str = "arcade_alerts.log"
@@ -66,6 +67,7 @@ class MonitoringConfig:
@dataclass
class HealthCheckResult:
"""Result of a health check."""
+
service: str
status: str # healthy, degraded, unhealthy
response_time_ms: float
@@ -77,6 +79,7 @@ class HealthCheckResult:
@dataclass
class PerformanceMetric:
"""Performance metric data point."""
+
metric_name: str
value: float
timestamp: datetime
@@ -86,6 +89,7 @@ class PerformanceMetric:
@dataclass
class Alert:
"""Alert notification."""
+
alert_id: str
severity: str # info, warning, critical
title: str
@@ -100,160 +104,168 @@ class Alert:
class SystemMonitor:
"""Monitors system resources and performance."""
-
+
def __init__(self, config: MonitoringConfig):
self.config = config
self.logger = logging.getLogger(f"{__name__}.SystemMonitor")
-
+
def get_system_metrics(self) -> Dict[str, float]:
"""Get current system resource metrics."""
try:
cpu_percent = psutil.cpu_percent(interval=1)
memory = psutil.virtual_memory()
- disk = psutil.disk_usage('/')
-
+ disk = psutil.disk_usage("/")
+
# Network I/O
net_io = psutil.net_io_counters()
-
+
return {
- 'cpu_usage_percent': cpu_percent,
- 'memory_usage_percent': memory.percent,
- 'memory_available_mb': memory.available / (1024 * 1024),
- 'disk_usage_percent': disk.percent,
- 'network_bytes_sent': net_io.bytes_sent,
- 'network_bytes_recv': net_io.bytes_recv,
- 'load_average': os.getloadavg()[0] if hasattr(os, 'getloadavg') else 0.0
+ "cpu_usage_percent": cpu_percent,
+ "memory_usage_percent": memory.percent,
+ "memory_available_mb": memory.available / (1024 * 1024),
+ "disk_usage_percent": disk.percent,
+ "network_bytes_sent": net_io.bytes_sent,
+ "network_bytes_recv": net_io.bytes_recv,
+ "load_average": (
+ os.getloadavg()[0] if hasattr(os, "getloadavg") else 0.0
+ ),
}
except Exception as e:
self.logger.error(f"Failed to collect system metrics: {e}")
return {}
-
+
def check_system_health(self) -> HealthCheckResult:
"""Perform system health check."""
start_time = time.time()
-
+
try:
metrics = self.get_system_metrics()
response_time = (time.time() - start_time) * 1000
-
+
# Determine health status
status = "healthy"
issues = []
-
- if metrics.get('cpu_usage_percent', 0) > self.config.cpu_usage_threshold:
+
+ if metrics.get("cpu_usage_percent", 0) > self.config.cpu_usage_threshold:
status = "degraded"
issues.append(f"High CPU usage: {metrics['cpu_usage_percent']:.1f}%")
-
- if metrics.get('memory_usage_percent', 0) > self.config.memory_usage_threshold:
+
+ if (
+ metrics.get("memory_usage_percent", 0)
+ > self.config.memory_usage_threshold
+ ):
status = "degraded"
- issues.append(f"High memory usage: {metrics['memory_usage_percent']:.1f}%")
-
- if metrics.get('disk_usage_percent', 0) > 90:
+ issues.append(
+ f"High memory usage: {metrics['memory_usage_percent']:.1f}%"
+ )
+
+ if metrics.get("disk_usage_percent", 0) > 90:
status = "unhealthy"
- issues.append(f"Critical disk usage: {metrics['disk_usage_percent']:.1f}%")
-
+ issues.append(
+ f"Critical disk usage: {metrics['disk_usage_percent']:.1f}%"
+ )
+
return HealthCheckResult(
service="system",
status=status,
response_time_ms=response_time,
timestamp=datetime.now(timezone.utc),
- details={
- 'metrics': metrics,
- 'issues': issues
- }
+ details={"metrics": metrics, "issues": issues},
)
-
+
except Exception as e:
return HealthCheckResult(
service="system",
status="unhealthy",
response_time_ms=(time.time() - start_time) * 1000,
timestamp=datetime.now(timezone.utc),
- error_message=str(e)
+ error_message=str(e),
)
class ArcadeAPIMonitor:
"""Monitors Arcade.dev API performance and health."""
-
+
def __init__(self, config: MonitoringConfig):
self.config = config
self.logger = logging.getLogger(f"{__name__}.ArcadeAPIMonitor")
- self.api_key = os.getenv('ARCADE_API_KEY')
- self.demo_mode = self.api_key is None or self.api_key == '' or self.api_key == 'demo_key'
-
+ self.api_key = os.getenv("ARCADE_API_KEY")
+ self.demo_mode = (
+ self.api_key is None or self.api_key == "" or self.api_key == "demo_key"
+ )
+
if self.demo_mode:
self.logger.info("Running in demo mode - API calls will be simulated")
else:
self.logger.info("Running with real API credentials")
-
+
async def check_api_health(self) -> HealthCheckResult:
"""Check Arcade.dev API health."""
start_time = time.time()
-
+
try:
if self.demo_mode:
# Demo mode: simulate API health check
await asyncio.sleep(0.1) # Simulate network call
response_time = (time.time() - start_time) * 1000
-
+
# Mock API response analysis
status = "healthy"
if response_time > self.config.response_time_threshold_ms:
status = "degraded"
-
+
return HealthCheckResult(
service="arcade_api",
status=status,
response_time_ms=response_time,
timestamp=datetime.now(timezone.utc),
details={
- 'endpoint': 'https://api.arcade.dev/health (simulated)',
- 'api_version': 'v1',
- 'demo_mode': True,
- 'response_time_threshold_ms': self.config.response_time_threshold_ms
- }
+ "endpoint": "https://api.arcade.dev/health (simulated)",
+ "api_version": "v1",
+ "demo_mode": True,
+ "response_time_threshold_ms": self.config.response_time_threshold_ms,
+ },
)
else:
# Real API mode: would make actual API call here
# For now, simulate since we don't have real API implementation
await asyncio.sleep(0.1) # Simulate network call
response_time = (time.time() - start_time) * 1000
-
+
status = "healthy"
if response_time > self.config.response_time_threshold_ms:
status = "degraded"
-
+
return HealthCheckResult(
service="arcade_api",
status=status,
response_time_ms=response_time,
timestamp=datetime.now(timezone.utc),
details={
- 'endpoint': 'https://api.arcade.dev/health',
- 'api_version': 'v1',
- 'demo_mode': False,
- 'response_time_threshold_ms': self.config.response_time_threshold_ms
- }
+ "endpoint": "https://api.arcade.dev/health",
+ "api_version": "v1",
+ "demo_mode": False,
+ "response_time_threshold_ms": self.config.response_time_threshold_ms,
+ },
)
-
+
except Exception as e:
return HealthCheckResult(
service="arcade_api",
status="unhealthy",
response_time_ms=(time.time() - start_time) * 1000,
timestamp=datetime.now(timezone.utc),
- error_message=str(e)
+ error_message=str(e),
)
-
+
async def test_api_operations(self) -> Dict[str, HealthCheckResult]:
"""Test various API operations for monitoring."""
operations = {
- 'code_analysis': self._test_code_analysis,
- 'test_generation': self._test_generation,
- 'documentation': self._test_documentation
+ "code_analysis": self._test_code_analysis,
+ "test_generation": self._test_generation,
+ "documentation": self._test_documentation,
}
-
+
results = {}
for operation_name, test_func in operations.items():
try:
@@ -265,106 +277,112 @@ async def test_api_operations(self) -> Dict[str, HealthCheckResult]:
status="unhealthy",
response_time_ms=0,
timestamp=datetime.now(timezone.utc),
- error_message=str(e)
+ error_message=str(e),
)
-
+
return results
-
+
async def _test_code_analysis(self) -> HealthCheckResult:
"""Test code analysis endpoint."""
start_time = time.time()
-
+
# Simulate API call (demo or real mode)
await asyncio.sleep(0.2)
-
+
response_time = (time.time() - start_time) * 1000
-
+
return HealthCheckResult(
service="arcade_api_code_analysis",
status="healthy" if response_time < 1000 else "degraded",
response_time_ms=response_time,
timestamp=datetime.now(timezone.utc),
- details={
- 'test_code_lines': 10,
- 'demo_mode': self.demo_mode
- }
+ details={"test_code_lines": 10, "demo_mode": self.demo_mode},
)
-
+
async def _test_generation(self) -> HealthCheckResult:
"""Test test generation endpoint."""
start_time = time.time()
-
+
# Simulate API call (demo or real mode)
await asyncio.sleep(0.3)
-
+
response_time = (time.time() - start_time) * 1000
-
+
return HealthCheckResult(
service="arcade_api_test_generation",
status="healthy" if response_time < 1500 else "degraded",
response_time_ms=response_time,
timestamp=datetime.now(timezone.utc),
- details={
- 'generated_tests': 5,
- 'demo_mode': self.demo_mode
- }
+ details={"generated_tests": 5, "demo_mode": self.demo_mode},
)
-
+
async def _test_documentation(self) -> HealthCheckResult:
"""Test documentation endpoint."""
start_time = time.time()
-
+
# Simulate API call (demo or real mode)
await asyncio.sleep(0.15)
-
+
response_time = (time.time() - start_time) * 1000
-
+
return HealthCheckResult(
service="arcade_api_documentation",
status="healthy" if response_time < 800 else "degraded",
response_time_ms=response_time,
timestamp=datetime.now(timezone.utc),
- details={
- 'doc_sections': 4,
- 'demo_mode': self.demo_mode
- }
+ details={"doc_sections": 4, "demo_mode": self.demo_mode},
)
class AlertManager:
"""Manages alerts and notifications."""
-
+
def __init__(self, config: MonitoringConfig):
self.config = config
self.logger = logging.getLogger(f"{__name__}.AlertManager")
self.active_alerts: Dict[str, Alert] = {}
self.alert_history: List[Alert] = []
self.alert_callbacks: List[Callable[[Alert], None]] = []
- self.suppressed_alerts: Dict[str, datetime] = {} # Track alerts to prevent duplicates
-
+ self.suppressed_alerts: Dict[str, datetime] = (
+ {}
+ ) # Track alerts to prevent duplicates
+
def register_alert_callback(self, callback: Callable[[Alert], None]):
"""Register a callback for alert notifications."""
self.alert_callbacks.append(callback)
-
- def create_alert(self, severity: str, title: str, message: str, service: str,
- metric_name: str = None, current_value: float = None,
- threshold: float = None) -> Optional[Alert]:
+
+ def create_alert(
+ self,
+ severity: str,
+ title: str,
+ message: str,
+ service: str,
+ metric_name: str = None,
+ current_value: float = None,
+ threshold: float = None,
+ ) -> Optional[Alert]:
"""Create a new alert, with duplicate suppression."""
# Create unique key for this type of alert
- alert_key = f"{service}_{metric_name}_{title}" if metric_name else f"{service}_{title}"
-
+ alert_key = (
+ f"{service}_{metric_name}_{title}" if metric_name else f"{service}_{title}"
+ )
+
# Check if we've recently created this type of alert (suppress duplicates for 5 minutes)
now = datetime.now(timezone.utc)
if alert_key in self.suppressed_alerts:
time_since_last = now - self.suppressed_alerts[alert_key]
if time_since_last < timedelta(minutes=5):
return None # Suppress duplicate alert
-
+
# Update suppression tracker
self.suppressed_alerts[alert_key] = now
-
- alert_id = f"{service}_{metric_name}_{int(time.time())}" if metric_name else f"{service}_{int(time.time())}"
-
+
+ alert_id = (
+ f"{service}_{metric_name}_{int(time.time())}"
+ if metric_name
+ else f"{service}_{int(time.time())}"
+ )
+
alert = Alert(
alert_id=alert_id,
severity=severity,
@@ -374,41 +392,41 @@ def create_alert(self, severity: str, title: str, message: str, service: str,
service=service,
metric_name=metric_name,
current_value=current_value,
- threshold=threshold
+ threshold=threshold,
)
-
+
self.active_alerts[alert_id] = alert
self.alert_history.append(alert)
-
+
# Trigger callbacks
for callback in self.alert_callbacks:
try:
callback(alert)
except Exception as e:
self.logger.error(f"Alert callback failed: {e}")
-
+
self.logger.warning(f"Alert created: {alert.title} - {alert.message}")
return alert
-
+
def resolve_alert(self, alert_id: str) -> bool:
"""Resolve an active alert."""
if alert_id in self.active_alerts:
alert = self.active_alerts[alert_id]
alert.resolved = True
del self.active_alerts[alert_id]
-
+
self.logger.info(f"Alert resolved: {alert.title}")
return True
-
+
return False
-
+
def check_metric_thresholds(self, metrics: Dict[str, float], service: str):
"""Check metrics against thresholds and create alerts if needed."""
if not self.config.enable_alerting:
return
-
+
# CPU usage alert
- cpu_usage = metrics.get('cpu_usage_percent', 0)
+ cpu_usage = metrics.get("cpu_usage_percent", 0)
if cpu_usage > self.config.cpu_usage_threshold:
self.create_alert(
severity="warning",
@@ -417,11 +435,11 @@ def check_metric_thresholds(self, metrics: Dict[str, float], service: str):
service=service,
metric_name="cpu_usage_percent",
current_value=cpu_usage,
- threshold=self.config.cpu_usage_threshold
+ threshold=self.config.cpu_usage_threshold,
)
-
+
# Memory usage alert
- memory_usage = metrics.get('memory_usage_percent', 0)
+ memory_usage = metrics.get("memory_usage_percent", 0)
if memory_usage > self.config.memory_usage_threshold:
self.create_alert(
severity="warning",
@@ -430,89 +448,95 @@ def check_metric_thresholds(self, metrics: Dict[str, float], service: str):
service=service,
metric_name="memory_usage_percent",
current_value=memory_usage,
- threshold=self.config.memory_usage_threshold
+ threshold=self.config.memory_usage_threshold,
)
-
+
def get_active_alerts(self) -> List[Alert]:
"""Get all active alerts."""
return list(self.active_alerts.values())
-
+
def get_alert_summary(self) -> Dict[str, Any]:
"""Get alert summary statistics."""
now = datetime.now(timezone.utc)
last_hour = now - timedelta(hours=1)
last_day = now - timedelta(days=1)
-
+
recent_alerts = [a for a in self.alert_history if a.timestamp > last_hour]
daily_alerts = [a for a in self.alert_history if a.timestamp > last_day]
-
+
return {
- 'active_alerts': len(self.active_alerts),
- 'alerts_last_hour': len(recent_alerts),
- 'alerts_last_day': len(daily_alerts),
- 'total_alerts': len(self.alert_history),
- 'severity_breakdown': {
- 'critical': len([a for a in recent_alerts if a.severity == 'critical']),
- 'warning': len([a for a in recent_alerts if a.severity == 'warning']),
- 'info': len([a for a in recent_alerts if a.severity == 'info'])
- }
+ "active_alerts": len(self.active_alerts),
+ "alerts_last_hour": len(recent_alerts),
+ "alerts_last_day": len(daily_alerts),
+ "total_alerts": len(self.alert_history),
+ "severity_breakdown": {
+ "critical": len([a for a in recent_alerts if a.severity == "critical"]),
+ "warning": len([a for a in recent_alerts if a.severity == "warning"]),
+ "info": len([a for a in recent_alerts if a.severity == "info"]),
+ },
}
class TelemetryCollector:
"""Collects and exports telemetry data."""
-
+
def __init__(self, config: MonitoringConfig):
self.config = config
self.logger = logging.getLogger(f"{__name__}.TelemetryCollector")
self.metrics_buffer: List[PerformanceMetric] = []
self.buffer_lock = threading.Lock()
-
- def record_metric(self, metric_name: str, value: float, tags: Dict[str, str] = None):
+
+ def record_metric(
+ self, metric_name: str, value: float, tags: Dict[str, str] = None
+ ):
"""Record a performance metric."""
if not self.config.enable_telemetry:
return
-
+
metric = PerformanceMetric(
metric_name=metric_name,
value=value,
timestamp=datetime.now(timezone.utc),
- tags=tags or {}
+ tags=tags or {},
)
-
+
with self.buffer_lock:
self.metrics_buffer.append(metric)
-
+
# Keep buffer size manageable
if len(self.metrics_buffer) > 1000:
self.metrics_buffer = self.metrics_buffer[-500:]
-
- def get_metrics(self, metric_name: str = None, last_minutes: int = 10) -> List[PerformanceMetric]:
+
+ def get_metrics(
+ self, metric_name: str = None, last_minutes: int = 10
+ ) -> List[PerformanceMetric]:
"""Get metrics from buffer."""
cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=last_minutes)
-
+
with self.buffer_lock:
filtered_metrics = [
- m for m in self.metrics_buffer
- if m.timestamp > cutoff_time and (metric_name is None or m.metric_name == metric_name)
+ m
+ for m in self.metrics_buffer
+ if m.timestamp > cutoff_time
+ and (metric_name is None or m.metric_name == metric_name)
]
-
+
return filtered_metrics
-
- def export_metrics(self, format_type: str = 'json') -> str:
+
+ def export_metrics(self, format_type: str = "json") -> str:
"""Export metrics in specified format."""
with self.buffer_lock:
metrics_data = [
{
- 'metric_name': m.metric_name,
- 'value': m.value,
- 'timestamp': m.timestamp.isoformat(),
- 'tags': m.tags
+ "metric_name": m.metric_name,
+ "value": m.value,
+ "timestamp": m.timestamp.isoformat(),
+ "tags": m.tags,
}
for m in self.metrics_buffer
]
-
- if format_type == 'json':
+
+ if format_type == "json":
return json.dumps(metrics_data, indent=2)
else:
return str(metrics_data)
@@ -520,72 +544,72 @@ def export_metrics(self, format_type: str = 'json') -> str:
class ArcadeMonitoringDashboard:
"""Main monitoring dashboard orchestrating all monitoring components."""
-
+
def __init__(self, config: MonitoringConfig = None):
self.config = config or MonitoringConfig()
self.logger = logging.getLogger(__name__)
-
+
# Initialize components
self.system_monitor = SystemMonitor(self.config)
self.api_monitor = ArcadeAPIMonitor(self.config)
self.alert_manager = AlertManager(self.config)
self.telemetry = TelemetryCollector(self.config)
self.metrics_collector = MetricsCollector()
-
+
# Register alert callback
self.alert_manager.register_alert_callback(self._handle_alert)
-
+
# Monitoring state
self.monitoring_active = False
self.monitoring_task: Optional[asyncio.Task] = None
-
+
def _handle_alert(self, alert: Alert):
"""Handle alert notifications."""
# Log alert
alert_data = {
- 'alert_id': alert.alert_id,
- 'severity': alert.severity,
- 'title': alert.title,
- 'message': alert.message,
- 'service': alert.service,
- 'timestamp': alert.timestamp.isoformat()
+ "alert_id": alert.alert_id,
+ "severity": alert.severity,
+ "title": alert.title,
+ "message": alert.message,
+ "service": alert.service,
+ "timestamp": alert.timestamp.isoformat(),
}
-
+
# Write to alerts log
alerts_log = Path(self.config.alerts_log_file)
alerts_log.parent.mkdir(exist_ok=True)
-
- with open(alerts_log, 'a') as f:
- f.write(json.dumps(alert_data) + '\n')
-
+
+ with open(alerts_log, "a") as f:
+ f.write(json.dumps(alert_data) + "\n")
+
# In production, send to notification system (email, Slack, etc.)
print(f"šØ ALERT [{alert.severity.upper()}]: {alert.title}")
print(f" Service: {alert.service}")
print(f" Message: {alert.message}")
-
+
async def start_monitoring(self):
"""Start continuous monitoring."""
if self.monitoring_active:
self.logger.warning("Monitoring already active")
return
-
+
self.monitoring_active = True
self.monitoring_task = asyncio.create_task(self._monitoring_loop())
self.logger.info("Monitoring started")
-
+
async def stop_monitoring(self):
"""Stop monitoring."""
self.monitoring_active = False
-
+
if self.monitoring_task:
self.monitoring_task.cancel()
try:
await self.monitoring_task
except asyncio.CancelledError:
pass
-
+
self.logger.info("Monitoring stopped")
-
+
async def _monitoring_loop(self):
"""Main monitoring loop."""
while self.monitoring_active:
@@ -593,35 +617,34 @@ async def _monitoring_loop(self):
# System health check
system_health = self.system_monitor.check_system_health()
self.telemetry.record_metric(
- 'system_health_response_time_ms',
+ "system_health_response_time_ms",
system_health.response_time_ms,
- {'service': 'system', 'status': system_health.status}
+ {"service": "system", "status": system_health.status},
)
-
+
# Check system metrics against thresholds
- if system_health.details and 'metrics' in system_health.details:
+ if system_health.details and "metrics" in system_health.details:
self.alert_manager.check_metric_thresholds(
- system_health.details['metrics'],
- 'system'
+ system_health.details["metrics"], "system"
)
-
+
# API health checks
api_health = await self.api_monitor.check_api_health()
self.telemetry.record_metric(
- 'api_health_response_time_ms',
+ "api_health_response_time_ms",
api_health.response_time_ms,
- {'service': 'arcade_api', 'status': api_health.status}
+ {"service": "arcade_api", "status": api_health.status},
)
-
+
# Test API operations
operation_results = await self.api_monitor.test_api_operations()
for operation, result in operation_results.items():
self.telemetry.record_metric(
- f'api_operation_response_time_ms',
+ f"api_operation_response_time_ms",
result.response_time_ms,
- {'operation': operation, 'status': result.status}
+ {"operation": operation, "status": result.status},
)
-
+
# Check for slow operations
if result.response_time_ms > self.config.response_time_threshold_ms:
self.alert_manager.create_alert(
@@ -631,100 +654,114 @@ async def _monitoring_loop(self):
service="arcade_api",
metric_name="response_time_ms",
current_value=result.response_time_ms,
- threshold=self.config.response_time_threshold_ms
+ threshold=self.config.response_time_threshold_ms,
)
-
+
await asyncio.sleep(self.config.health_check_interval_seconds)
-
+
except Exception as e:
self.logger.error(f"Monitoring loop error: {e}")
await asyncio.sleep(5) # Wait before retrying
-
+
def get_dashboard_data(self) -> Dict[str, Any]:
"""Get comprehensive dashboard data."""
# System metrics
system_metrics = self.system_monitor.get_system_metrics()
-
+
# Recent performance metrics
recent_metrics = {
- 'system_response_times': [
- m.value for m in self.telemetry.get_metrics('system_health_response_time_ms', 10)
+ "system_response_times": [
+ m.value
+ for m in self.telemetry.get_metrics(
+ "system_health_response_time_ms", 10
+ )
+ ],
+ "api_response_times": [
+ m.value
+ for m in self.telemetry.get_metrics("api_health_response_time_ms", 10)
],
- 'api_response_times': [
- m.value for m in self.telemetry.get_metrics('api_health_response_time_ms', 10)
- ]
}
-
+
# Calculate averages
- if recent_metrics['system_response_times']:
- recent_metrics['avg_system_response_time'] = statistics.mean(recent_metrics['system_response_times'])
+ if recent_metrics["system_response_times"]:
+ recent_metrics["avg_system_response_time"] = statistics.mean(
+ recent_metrics["system_response_times"]
+ )
else:
- recent_metrics['avg_system_response_time'] = 0
-
- if recent_metrics['api_response_times']:
- recent_metrics['avg_api_response_time'] = statistics.mean(recent_metrics['api_response_times'])
+ recent_metrics["avg_system_response_time"] = 0
+
+ if recent_metrics["api_response_times"]:
+ recent_metrics["avg_api_response_time"] = statistics.mean(
+ recent_metrics["api_response_times"]
+ )
else:
- recent_metrics['avg_api_response_time'] = 0
-
+ recent_metrics["avg_api_response_time"] = 0
+
return {
- 'timestamp': datetime.now(timezone.utc).isoformat(),
- 'monitoring_active': self.monitoring_active,
- 'system_metrics': system_metrics,
- 'performance_metrics': recent_metrics,
- 'alerts': self.alert_manager.get_alert_summary(),
- 'active_alerts': [
+ "timestamp": datetime.now(timezone.utc).isoformat(),
+ "monitoring_active": self.monitoring_active,
+ "system_metrics": system_metrics,
+ "performance_metrics": recent_metrics,
+ "alerts": self.alert_manager.get_alert_summary(),
+ "active_alerts": [
{
- 'id': alert.alert_id,
- 'severity': alert.severity,
- 'title': alert.title,
- 'service': alert.service,
- 'timestamp': alert.timestamp.isoformat()
+ "id": alert.alert_id,
+ "severity": alert.severity,
+ "title": alert.title,
+ "service": alert.service,
+ "timestamp": alert.timestamp.isoformat(),
}
for alert in self.alert_manager.get_active_alerts()
- ]
+ ],
}
-
+
def print_dashboard(self):
"""Print monitoring dashboard to console."""
dashboard_data = self.get_dashboard_data()
-
+
print("\n" + "=" * 60)
print("š„ļø ARCADE.DEV MONITORING DASHBOARD")
print("=" * 60)
-
+
# Demo mode indicator
- if hasattr(self, 'api_monitor') and self.api_monitor.demo_mode:
+ if hasattr(self, "api_monitor") and self.api_monitor.demo_mode:
print("š§ Running in DEMO MODE (simulated API calls)")
-
+
# System Status
print(f"\nš System Metrics:")
- system_metrics = dashboard_data['system_metrics']
+ system_metrics = dashboard_data["system_metrics"]
if system_metrics:
print(f" CPU Usage: {system_metrics.get('cpu_usage_percent', 0):.1f}%")
- print(f" Memory Usage: {system_metrics.get('memory_usage_percent', 0):.1f}%")
+ print(
+ f" Memory Usage: {system_metrics.get('memory_usage_percent', 0):.1f}%"
+ )
print(f" Disk Usage: {system_metrics.get('disk_usage_percent', 0):.1f}%")
print(f" Load Average: {system_metrics.get('load_average', 0):.2f}")
-
+
# Performance
print(f"\nā” Performance:")
- perf_metrics = dashboard_data['performance_metrics']
- print(f" Avg System Response: {perf_metrics['avg_system_response_time']:.1f}ms")
+ perf_metrics = dashboard_data["performance_metrics"]
+ print(
+ f" Avg System Response: {perf_metrics['avg_system_response_time']:.1f}ms"
+ )
print(f" Avg API Response: {perf_metrics['avg_api_response_time']:.1f}ms")
-
+
# Alerts
print(f"\nšØ Alerts:")
- alert_summary = dashboard_data['alerts']
+ alert_summary = dashboard_data["alerts"]
print(f" Active: {alert_summary['active_alerts']}")
print(f" Last Hour: {alert_summary['alerts_last_hour']}")
print(f" Last Day: {alert_summary['alerts_last_day']}")
-
+
# Active Alerts
- active_alerts = dashboard_data['active_alerts']
+ active_alerts = dashboard_data["active_alerts"]
if active_alerts:
print(f"\nā ļø Active Alerts:")
for alert in active_alerts[:5]: # Show first 5
- print(f" [{alert['severity'].upper()}] {alert['title']} ({alert['service']})")
-
+ print(
+ f" [{alert['severity'].upper()}] {alert['title']} ({alert['service']})"
+ )
+
print("\n" + "=" * 60)
@@ -732,51 +769,57 @@ async def demonstrate_monitoring():
"""Demonstrate monitoring capabilities."""
print("š Arcade.dev Monitoring Demo")
print("=" * 50)
-
+
try:
# Create monitoring dashboard
config = MonitoringConfig(
health_check_interval_seconds=5, # Faster for demo
cpu_usage_threshold=50.0, # Lower threshold for demo alerts
- memory_usage_threshold=60.0
+ memory_usage_threshold=60.0,
)
-
+
dashboard = ArcadeMonitoringDashboard(config)
-
+
# Check demo mode status
- demo_status = "š§ DEMO MODE" if dashboard.api_monitor.demo_mode else "š LIVE MODE"
+ demo_status = (
+ "š§ DEMO MODE" if dashboard.api_monitor.demo_mode else "š LIVE MODE"
+ )
print(f"\nMode: {demo_status}")
- print(f"API Key Available: {'ā No' if dashboard.api_monitor.demo_mode else 'ā
Yes'}")
-
+ print(
+ f"API Key Available: {'ā No' if dashboard.api_monitor.demo_mode else 'ā
Yes'}"
+ )
+
print("\nš Starting monitoring...")
await dashboard.start_monitoring()
-
+
# Let it run for a bit and collect data
print("ā³ Collecting metrics for 30 seconds...")
-
+
for i in range(6): # 6 iterations of 5 seconds each
await asyncio.sleep(5)
dashboard.print_dashboard()
-
+
# Simulate some load to trigger alerts
if i == 2:
print("\nš„ Simulating high load...")
# This would trigger alerts in a real scenario
-
+
print("\nš Stopping monitoring...")
await dashboard.stop_monitoring()
-
+
# Final summary
print("\nš Final Monitoring Summary:")
dashboard.print_dashboard()
-
+
# Export metrics for review
metrics_export = dashboard.telemetry.export_metrics()
- print(f"\nš Exported {len(dashboard.telemetry.metrics_buffer)} metrics to JSON")
-
+ print(
+ f"\nš Exported {len(dashboard.telemetry.metrics_buffer)} metrics to JSON"
+ )
+
print("\nš Monitoring demonstration completed!")
return dashboard
-
+
except Exception as e:
print(f"\nā Error during monitoring demo: {e}")
raise
@@ -786,9 +829,9 @@ async def main():
"""Main demonstration function."""
logging.basicConfig(
level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
-
+
try:
await demonstrate_monitoring()
return 0
@@ -802,4 +845,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/interactive_demo.py b/examples/arcade-dev/interactive_demo.py
index 821becb..87a0eaf 100644
--- a/examples/arcade-dev/interactive_demo.py
+++ b/examples/arcade-dev/interactive_demo.py
@@ -25,6 +25,7 @@
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.syntax import Syntax
+
HAS_RICH = True
except ImportError:
HAS_RICH = False
@@ -33,6 +34,7 @@
@dataclass
class DemoExample:
"""Represents a demo example."""
+
name: str
description: str
script_path: str
@@ -42,12 +44,12 @@ class DemoExample:
class InteractiveDemo:
"""Interactive demonstration manager."""
-
+
def __init__(self):
self.console = Console() if HAS_RICH else None
self.examples = self._load_examples()
self.logger = logging.getLogger(__name__)
-
+
def _load_examples(self) -> List[DemoExample]:
"""Load available examples."""
return [
@@ -56,77 +58,83 @@ def _load_examples(self) -> List[DemoExample]:
description="Fundamental Arcade.dev API client integration",
script_path="01_basic_integration/basic_arcade_client.py",
requirements=["ARCADE_API_KEY"],
- difficulty="beginner"
+ difficulty="beginner",
),
DemoExample(
name="Code Analysis",
description="Analyze code quality and structure",
script_path="02_code_analysis/code_analyzer.py",
requirements=["ARCADE_API_KEY"],
- difficulty="beginner"
+ difficulty="beginner",
),
DemoExample(
name="Test Generation",
description="Generate comprehensive test cases using AI",
script_path="03_automated_testing/test_generator.py",
requirements=["ARCADE_API_KEY"],
- difficulty="intermediate"
+ difficulty="intermediate",
),
DemoExample(
name="Documentation Generation",
description="Automatically generate documentation from code",
script_path="04_documentation/doc_generator.py",
requirements=["ARCADE_API_KEY"],
- difficulty="intermediate"
+ difficulty="intermediate",
),
DemoExample(
name="Cache Integration",
description="Integrate Arcade.dev with FACT's caching system",
script_path="07_cache_integration/cached_arcade_client.py",
requirements=["ARCADE_API_KEY", "FACT_CACHE_ENABLED"],
- difficulty="advanced"
- )
+ difficulty="advanced",
+ ),
]
-
+
def _print_banner(self):
"""Print welcome banner."""
if self.console:
- self.console.print(Panel.fit(
- "[bold blue]š® Arcade.dev Integration Examples[/bold blue]\n"
- "[dim]Interactive demonstration of FACT + Arcade.dev integration[/dim]",
- border_style="blue"
- ))
+ self.console.print(
+ Panel.fit(
+ "[bold blue]š® Arcade.dev Integration Examples[/bold blue]\n"
+ "[dim]Interactive demonstration of FACT + Arcade.dev integration[/dim]",
+ border_style="blue",
+ )
+ )
else:
print("=" * 60)
print("š® Arcade.dev Integration Examples")
print("Interactive demonstration of FACT + Arcade.dev integration")
print("=" * 60)
-
+
def _check_prerequisites(self) -> bool:
"""Check if prerequisites are met."""
missing_vars = []
required_vars = ["ARCADE_API_KEY"]
-
+
for var in required_vars:
if not os.getenv(var):
missing_vars.append(var)
-
+
if missing_vars:
if self.console:
- self.console.print(Panel(
- f"[red]Missing required environment variables:[/red]\n"
- f"{', '.join(missing_vars)}\n\n"
- f"[dim]Please copy .env.example to .env and configure your settings.[/dim]",
- title="Prerequisites Check Failed",
- border_style="red"
- ))
+ self.console.print(
+ Panel(
+ f"[red]Missing required environment variables:[/red]\n"
+ f"{', '.join(missing_vars)}\n\n"
+ f"[dim]Please copy .env.example to .env and configure your settings.[/dim]",
+ title="Prerequisites Check Failed",
+ border_style="red",
+ )
+ )
else:
- print(f"ā Missing required environment variables: {', '.join(missing_vars)}")
+ print(
+ f"ā Missing required environment variables: {', '.join(missing_vars)}"
+ )
print("Please copy .env.example to .env and configure your settings.")
return False
-
+
return True
-
+
def _display_examples_table(self):
"""Display available examples in a table."""
if self.console:
@@ -136,12 +144,12 @@ def _display_examples_table(self):
table.add_column("Description")
table.add_column("Difficulty", justify="center")
table.add_column("Status", justify="center")
-
+
for i, example in enumerate(self.examples, 1):
# Check if example script exists
script_path = Path(__file__).parent / example.script_path
status = "ā
" if script_path.exists() else "š§"
-
+
# Color code difficulty
if example.difficulty == "beginner":
difficulty = f"[green]{example.difficulty}[/green]"
@@ -149,15 +157,11 @@ def _display_examples_table(self):
difficulty = f"[yellow]{example.difficulty}[/yellow]"
else:
difficulty = f"[red]{example.difficulty}[/red]"
-
+
table.add_row(
- str(i),
- example.name,
- example.description,
- difficulty,
- status
+ str(i), example.name, example.description, difficulty, status
)
-
+
self.console.print(table)
else:
print("\nAvailable Examples:")
@@ -167,109 +171,119 @@ def _display_examples_table(self):
status = "ā
" if script_path.exists() else "š§"
print(f"{i:2d}. {example.name} ({example.difficulty}) {status}")
print(f" {example.description}")
-
+
def _get_user_choice(self) -> Optional[int]:
"""Get user's example choice."""
if self.console:
choice = Prompt.ask(
"\n[bold]Select an example to run[/bold]",
choices=[str(i) for i in range(1, len(self.examples) + 1)] + ["q"],
- default="q"
+ default="q",
)
else:
- choice = input(f"\nSelect an example (1-{len(self.examples)}) or 'q' to quit: ")
-
- if choice.lower() == 'q':
+ choice = input(
+ f"\nSelect an example (1-{len(self.examples)}) or 'q' to quit: "
+ )
+
+ if choice.lower() == "q":
return None
-
+
try:
return int(choice)
except ValueError:
return None
-
+
async def _run_example(self, example: DemoExample):
"""Run a selected example."""
script_path = Path(__file__).parent / example.script_path
-
+
if not script_path.exists():
if self.console:
- self.console.print(f"[red]ā Example script not found: {example.script_path}[/red]")
+ self.console.print(
+ f"[red]ā Example script not found: {example.script_path}[/red]"
+ )
else:
print(f"ā Example script not found: {example.script_path}")
return
-
+
# Check example-specific requirements
missing_reqs = []
for req in example.requirements:
if not os.getenv(req):
missing_reqs.append(req)
-
+
if missing_reqs:
if self.console:
- self.console.print(f"[yellow]ā ļø Missing requirements for this example: {', '.join(missing_reqs)}[/yellow]")
+ self.console.print(
+ f"[yellow]ā ļø Missing requirements for this example: {', '.join(missing_reqs)}[/yellow]"
+ )
if not Confirm.ask("Continue anyway?"):
return
else:
print(f"ā ļø Missing requirements: {', '.join(missing_reqs)}")
response = input("Continue anyway? (y/N): ")
- if response.lower() != 'y':
+ if response.lower() != "y":
return
-
+
# Display example info
if self.console:
- self.console.print(Panel(
- f"[bold]{example.name}[/bold]\n"
- f"{example.description}\n\n"
- f"[dim]Script: {example.script_path}[/dim]\n"
- f"[dim]Difficulty: {example.difficulty}[/dim]",
- title="Running Example",
- border_style="green"
- ))
+ self.console.print(
+ Panel(
+ f"[bold]{example.name}[/bold]\n"
+ f"{example.description}\n\n"
+ f"[dim]Script: {example.script_path}[/dim]\n"
+ f"[dim]Difficulty: {example.difficulty}[/dim]",
+ title="Running Example",
+ border_style="green",
+ )
+ )
else:
print(f"\nš Running: {example.name}")
print(f"Description: {example.description}")
print(f"Script: {example.script_path}")
-
+
# Run the example
try:
if self.console:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
- console=self.console
+ console=self.console,
) as progress:
task = progress.add_task("Executing example...", total=None)
-
+
# Execute the script
process = await asyncio.create_subprocess_exec(
- sys.executable, str(script_path),
+ sys.executable,
+ str(script_path),
stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE
+ stderr=asyncio.subprocess.PIPE,
)
-
+
stdout, stderr = await process.communicate()
-
+
progress.remove_task(task)
-
+
else:
print("Executing example...")
process = await asyncio.create_subprocess_exec(
- sys.executable, str(script_path),
+ sys.executable,
+ str(script_path),
stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE
+ stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
-
+
# Display results
if process.returncode == 0:
if self.console:
- self.console.print("[green]ā
Example completed successfully![/green]")
+ self.console.print(
+ "[green]ā
Example completed successfully![/green]"
+ )
if stdout:
- self.console.print(Panel(
- stdout.decode(),
- title="Output",
- border_style="green"
- ))
+ self.console.print(
+ Panel(stdout.decode(), title="Output", border_style="green")
+ )
else:
print("ā
Example completed successfully!")
if stdout:
@@ -280,59 +294,63 @@ async def _run_example(self, example: DemoExample):
if self.console:
self.console.print("[red]ā Example failed![/red]")
if stderr:
- self.console.print(Panel(
- stderr.decode(),
- title="Error Output",
- border_style="red"
- ))
+ self.console.print(
+ Panel(
+ stderr.decode(),
+ title="Error Output",
+ border_style="red",
+ )
+ )
else:
print("ā Example failed!")
if stderr:
print("\nError Output:")
print("-" * 40)
print(stderr.decode())
-
+
except Exception as e:
if self.console:
self.console.print(f"[red]ā Error running example: {e}[/red]")
else:
print(f"ā Error running example: {e}")
-
+
async def run(self):
"""Run the interactive demo."""
self._print_banner()
-
+
if not self._check_prerequisites():
return 1
-
+
while True:
self._display_examples_table()
choice = self._get_user_choice()
-
+
if choice is None:
if self.console:
- self.console.print("\n[blue]š Thanks for exploring Arcade.dev examples![/blue]")
+ self.console.print(
+ "\n[blue]š Thanks for exploring Arcade.dev examples![/blue]"
+ )
else:
print("\nš Thanks for exploring Arcade.dev examples!")
break
-
+
if 1 <= choice <= len(self.examples):
example = self.examples[choice - 1]
await self._run_example(example)
-
+
if self.console:
if not Confirm.ask("\nRun another example?"):
break
else:
response = input("\nRun another example? (y/N): ")
- if response.lower() != 'y':
+ if response.lower() != "y":
break
else:
if self.console:
self.console.print("[red]Invalid choice. Please try again.[/red]")
else:
print("Invalid choice. Please try again.")
-
+
return 0
@@ -341,14 +359,14 @@ async def main():
# Setup basic logging
logging.basicConfig(
level=logging.WARNING, # Reduce noise in interactive mode
- format='%(levelname)s: %(message)s'
+ format="%(levelname)s: %(message)s",
)
-
+
# Check for rich dependency
if not HAS_RICH:
print("ā ļø For the best experience, install 'rich': pip install rich")
print("Continuing with basic text interface...\n")
-
+
demo = InteractiveDemo()
return await demo.run()
@@ -359,4 +377,4 @@ async def main():
sys.exit(exit_code)
except KeyboardInterrupt:
print("\nš Demo interrupted by user")
- sys.exit(0)
\ No newline at end of file
+ sys.exit(0)
diff --git a/examples/arcade-dev/requirements.txt b/examples/arcade-dev/requirements.txt
index 3f35774..0daac8e 100644
--- a/examples/arcade-dev/requirements.txt
+++ b/examples/arcade-dev/requirements.txt
@@ -1,7 +1,7 @@
# Arcade.dev Integration Examples - Python Dependencies
# Core HTTP and async libraries
-aiohttp>=3.8.0
+aiohttp>=3.13.4
asyncio-throttle>=1.0.2
httpx>=0.24.0
diff --git a/examples/arcade-dev/run_all_examples.py b/examples/arcade-dev/run_all_examples.py
index 9f827ff..0b75a6d 100644
--- a/examples/arcade-dev/run_all_examples.py
+++ b/examples/arcade-dev/run_all_examples.py
@@ -21,227 +21,230 @@
class ExampleRunner:
"""Manages execution of all Arcade.dev examples."""
-
+
def __init__(self):
self.base_dir = Path(__file__).parent
self.examples = self._discover_examples()
self.results: List[Dict[str, Any]] = []
-
+
def _discover_examples(self) -> List[Dict[str, str]]:
"""Discover all runnable examples."""
examples = []
-
+
# Known examples
known_examples = [
{
- 'name': 'Basic Integration',
- 'script': '01_basic_integration/basic_arcade_client.py',
- 'description': 'Basic API client integration'
+ "name": "Basic Integration",
+ "script": "01_basic_integration/basic_arcade_client.py",
+ "description": "Basic API client integration",
},
{
- 'name': 'Tool Registration',
- 'script': '02_tool_registration/register_fact_tools.py',
- 'description': 'Register FACT tools with Arcade.dev'
+ "name": "Tool Registration",
+ "script": "02_tool_registration/register_fact_tools.py",
+ "description": "Register FACT tools with Arcade.dev",
},
{
- 'name': 'Intelligent Routing',
- 'script': '03_intelligent_routing/hybrid_execution.py',
- 'description': 'Hybrid execution with intelligent routing'
+ "name": "Intelligent Routing",
+ "script": "03_intelligent_routing/hybrid_execution.py",
+ "description": "Hybrid execution with intelligent routing",
},
{
- 'name': 'Error Handling',
- 'script': '04_error_handling/resilient_execution.py',
- 'description': 'Resilient execution with error handling'
+ "name": "Error Handling",
+ "script": "04_error_handling/resilient_execution.py",
+ "description": "Resilient execution with error handling",
},
{
- 'name': 'Cache Integration (Enhanced)',
- 'script': '05_cache_integration/cached_arcade_client_enhanced.py',
- 'description': 'Enhanced cached API client'
+ "name": "Cache Integration (Enhanced)",
+ "script": "05_cache_integration/cached_arcade_client_enhanced.py",
+ "description": "Enhanced cached API client",
},
{
- 'name': 'Security',
- 'script': '06_security/secure_tool_execution.py',
- 'description': 'Secure tool execution with validation'
+ "name": "Security",
+ "script": "06_security/secure_tool_execution.py",
+ "description": "Secure tool execution with validation",
},
{
- 'name': 'Cache Integration',
- 'script': '07_cache_integration/cached_arcade_client.py',
- 'description': 'Cached API client with performance optimization'
+ "name": "Cache Integration",
+ "script": "07_cache_integration/cached_arcade_client.py",
+ "description": "Cached API client with performance optimization",
},
{
- 'name': 'Advanced Tools',
- 'script': '08_advanced_tools/advanced_tool_usage.py',
- 'description': 'Advanced tool usage patterns'
+ "name": "Advanced Tools",
+ "script": "08_advanced_tools/advanced_tool_usage.py",
+ "description": "Advanced tool usage patterns",
},
{
- 'name': 'Testing',
- 'script': '09_testing/arcade_integration_tests.py',
- 'description': 'Integration testing framework'
+ "name": "Testing",
+ "script": "09_testing/arcade_integration_tests.py",
+ "description": "Integration testing framework",
},
{
- 'name': 'Production Deployment',
- 'script': '10_deployment/production_deployment.py',
- 'description': 'Production deployment configuration'
+ "name": "Production Deployment",
+ "script": "10_deployment/production_deployment.py",
+ "description": "Production deployment configuration",
},
{
- 'name': 'Monitoring',
- 'script': '12_monitoring/arcade_monitoring.py',
- 'description': 'Monitoring and observability'
- }
+ "name": "Monitoring",
+ "script": "12_monitoring/arcade_monitoring.py",
+ "description": "Monitoring and observability",
+ },
]
-
+
# Filter to only include existing scripts
for example in known_examples:
- script_path = self.base_dir / example['script']
+ script_path = self.base_dir / example["script"]
if script_path.exists():
examples.append(example)
-
+
return examples
-
+
async def run_example(self, example: Dict[str, str]) -> Dict[str, Any]:
"""Run a single example and collect results."""
print(f"\n{'='*60}")
print(f"š Running: {example['name']}")
print(f"š Script: {example['script']}")
print(f"š Description: {example['description']}")
- print('='*60)
-
- script_path = self.base_dir / example['script']
+ print("=" * 60)
+
+ script_path = self.base_dir / example["script"]
start_time = time.time()
-
+
try:
# Run the example
process = await asyncio.create_subprocess_exec(
- sys.executable, str(script_path),
+ sys.executable,
+ str(script_path),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
- cwd=str(script_path.parent)
+ cwd=str(script_path.parent),
)
-
+
stdout, stderr = await process.communicate()
duration = time.time() - start_time
-
+
result = {
- 'name': example['name'],
- 'script': example['script'],
- 'duration': duration,
- 'success': process.returncode == 0,
- 'return_code': process.returncode,
- 'stdout': stdout.decode() if stdout else '',
- 'stderr': stderr.decode() if stderr else ''
+ "name": example["name"],
+ "script": example["script"],
+ "duration": duration,
+ "success": process.returncode == 0,
+ "return_code": process.returncode,
+ "stdout": stdout.decode() if stdout else "",
+ "stderr": stderr.decode() if stderr else "",
}
-
- if result['success']:
+
+ if result["success"]:
print(f"ā
{example['name']} completed successfully ({duration:.2f}s)")
else:
- print(f"ā {example['name']} failed (return code: {process.returncode})")
-
+ print(
+ f"ā {example['name']} failed (return code: {process.returncode})"
+ )
+
# Show output/errors
- if result['stdout']:
+ if result["stdout"]:
print("\nš Output:")
print("-" * 40)
- print(result['stdout'])
-
- if result['stderr']:
+ print(result["stdout"])
+
+ if result["stderr"]:
print("\nā ļø Errors/Warnings:")
print("-" * 40)
- print(result['stderr'])
-
+ print(result["stderr"])
+
return result
-
+
except Exception as e:
duration = time.time() - start_time
print(f"ā {example['name']} failed with exception: {e}")
-
+
return {
- 'name': example['name'],
- 'script': example['script'],
- 'duration': duration,
- 'success': False,
- 'return_code': -1,
- 'stdout': '',
- 'stderr': str(e)
+ "name": example["name"],
+ "script": example["script"],
+ "duration": duration,
+ "success": False,
+ "return_code": -1,
+ "stdout": "",
+ "stderr": str(e),
}
-
+
def print_summary(self):
"""Print execution summary."""
- print("\n" + "="*80)
+ print("\n" + "=" * 80)
print("š EXECUTION SUMMARY")
- print("="*80)
-
+ print("=" * 80)
+
total_examples = len(self.results)
- successful = sum(1 for r in self.results if r['success'])
+ successful = sum(1 for r in self.results if r["success"])
failed = total_examples - successful
- total_duration = sum(r['duration'] for r in self.results)
-
+ total_duration = sum(r["duration"] for r in self.results)
+
print(f"Total Examples: {total_examples}")
print(f"Successful: {successful}")
print(f"Failed: {failed}")
print(f"Total Duration: {total_duration:.2f}s")
print(f"Average Duration: {total_duration/total_examples:.2f}s")
-
+
print("\nš Individual Results:")
print("-" * 80)
for result in self.results:
- status = "ā
" if result['success'] else "ā"
+ status = "ā
" if result["success"] else "ā"
print(f"{status} {result['name']:<30} {result['duration']:>8.2f}s")
-
+
if failed > 0:
print("\nā Failed Examples:")
print("-" * 40)
for result in self.results:
- if not result['success']:
+ if not result["success"]:
print(f"⢠{result['name']}: {result['stderr'][:100]}...")
-
- print("\n" + "="*80)
-
+
+ print("\n" + "=" * 80)
+
async def run_all(self):
"""Run all examples in sequence."""
print("š® Running All Arcade.dev Examples")
- print("="*50)
-
+ print("=" * 50)
+
if not self.examples:
print("ā No runnable examples found!")
return False
-
+
print(f"Found {len(self.examples)} examples to run")
-
+
for example in self.examples:
result = await self.run_example(example)
self.results.append(result)
-
+
# Small delay between examples
await asyncio.sleep(1)
-
+
self.print_summary()
-
+
# Return True if all examples succeeded
- return all(r['success'] for r in self.results)
+ return all(r["success"] for r in self.results)
async def main():
"""Main function."""
print("š® Arcade.dev Examples - Batch Runner")
- print("="*50)
-
+ print("=" * 50)
+
# Check prerequisites
- if not os.getenv('ARCADE_API_KEY'):
+ if not os.getenv("ARCADE_API_KEY"):
print("ā ARCADE_API_KEY environment variable not set")
print("Please copy .env.example to .env and configure your API key")
return 1
-
+
runner = ExampleRunner()
-
+
try:
success = await runner.run_all()
-
+
if success:
print("\nš All examples completed successfully!")
return 0
else:
print("\nā ļø Some examples failed. Check the summary above.")
return 1
-
+
except KeyboardInterrupt:
print("\nā Execution interrupted by user")
return 1
@@ -252,4 +255,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/arcade-dev/utils/__init__.py b/examples/arcade-dev/utils/__init__.py
index 18623d0..89205d5 100644
--- a/examples/arcade-dev/utils/__init__.py
+++ b/examples/arcade-dev/utils/__init__.py
@@ -4,11 +4,16 @@
This package provides helper functions and utilities for running FACT arcade examples.
"""
-from .import_helper import setup_fact_imports, verify_fact_imports, get_fact_module_path, print_fact_module_info
+from .import_helper import (
+ setup_fact_imports,
+ verify_fact_imports,
+ get_fact_module_path,
+ print_fact_module_info,
+)
__all__ = [
- 'setup_fact_imports',
- 'verify_fact_imports',
- 'get_fact_module_path',
- 'print_fact_module_info'
-]
\ No newline at end of file
+ "setup_fact_imports",
+ "verify_fact_imports",
+ "get_fact_module_path",
+ "print_fact_module_info",
+]
diff --git a/examples/arcade-dev/utils/import_helper.py b/examples/arcade-dev/utils/import_helper.py
index 596b65e..01e4217 100644
--- a/examples/arcade-dev/utils/import_helper.py
+++ b/examples/arcade-dev/utils/import_helper.py
@@ -14,26 +14,28 @@
# Try to import python-dotenv if available
try:
from dotenv import load_dotenv
+
_DOTENV_AVAILABLE = True
except ImportError:
_DOTENV_AVAILABLE = False
+
def setup_fact_imports() -> Path:
"""
Set up Python path to allow importing FACT modules from any example location.
-
+
This function finds the project root (directory containing 'src') and adds it
to sys.path if it's not already there.
-
+
Returns:
Path to the project root directory
-
+
Raises:
ImportError: If the FACT project root cannot be found
"""
# Start from the current file's directory and work upward
current_path = Path(__file__).resolve()
-
+
# Look for the project root by finding the directory containing 'src'
for parent in current_path.parents:
src_dir = parent / "src"
@@ -45,24 +47,24 @@ def setup_fact_imports() -> Path:
"Could not find FACT project root. "
"Please ensure this script is within the FACT project directory structure."
)
-
+
# Add project root to Python path if not already present
project_root_str = str(project_root)
if project_root_str not in sys.path:
sys.path.insert(0, project_root_str)
-
+
# Load environment variables from root .env file
env_file = project_root / ".env"
if env_file.exists() and _DOTENV_AVAILABLE:
load_dotenv(env_file)
-
+
return project_root
def verify_fact_imports() -> bool:
"""
Verify that essential FACT modules can be imported.
-
+
Returns:
True if all essential modules can be imported, False otherwise
"""
@@ -70,6 +72,7 @@ def verify_fact_imports() -> bool:
# Test importing core FACT modules
from src.core.driver import FACTDriver
from src.cache.manager import CacheManager
+
return True
except ImportError as e:
print(f"Failed to import FACT modules: {e}")
@@ -79,30 +82,30 @@ def verify_fact_imports() -> bool:
def get_fact_module_path(module_name: str) -> Optional[Path]:
"""
Get the full path to a FACT module.
-
+
Args:
module_name: Name of the module (e.g., 'core.driver', 'cache.manager')
-
+
Returns:
Path to the module file if found, None otherwise
"""
try:
project_root = setup_fact_imports()
- module_parts = module_name.split('.')
+ module_parts = module_name.split(".")
module_path = project_root / "src"
-
+
for part in module_parts:
module_path = module_path / part
-
+
# Try both .py file and directory with __init__.py
- py_file = module_path.with_suffix('.py')
+ py_file = module_path.with_suffix(".py")
if py_file.exists():
return py_file
-
+
init_file = module_path / "__init__.py"
if init_file.exists():
return module_path
-
+
return None
except Exception:
return None
@@ -113,28 +116,32 @@ def print_fact_module_info():
try:
project_root = setup_fact_imports()
src_dir = project_root / "src"
-
+
print("š FACT Module Information:")
print(f" Project Root: {project_root}")
print(f" Source Directory: {src_dir}")
-
+
# List available modules
print("\nš¦ Available Modules:")
for item in src_dir.iterdir():
- if item.is_dir() and not item.name.startswith('.'):
+ if item.is_dir() and not item.name.startswith("."):
print(f" š {item.name}/")
# List key files in each module
for subitem in item.iterdir():
- if subitem.is_file() and subitem.suffix == '.py' and subitem.name != '__init__.py':
+ if (
+ subitem.is_file()
+ and subitem.suffix == ".py"
+ and subitem.name != "__init__.py"
+ ):
print(f" š {subitem.name}")
-
+
# Test imports
print("\nā
Import Test:")
if verify_fact_imports():
print(" All essential FACT modules can be imported successfully!")
else:
print(" ā Some FACT modules failed to import")
-
+
except Exception as e:
print(f"ā Error getting module info: {e}")
@@ -146,4 +153,4 @@ def print_fact_module_info():
if __name__ == "__main__":
# When run directly, display module information
- print_fact_module_info()
\ No newline at end of file
+ print_fact_module_info()
diff --git a/examples/arcade-dev/verify_setup.py b/examples/arcade-dev/verify_setup.py
index b0afd75..3657a34 100644
--- a/examples/arcade-dev/verify_setup.py
+++ b/examples/arcade-dev/verify_setup.py
@@ -17,46 +17,51 @@
# Load environment variables from .env file if it exists
try:
from dotenv import load_dotenv
+
load_dotenv()
except ImportError:
pass # dotenv not available, skip loading
# Setup logging
-logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
class SetupVerifier:
"""Verifies the setup for Arcade.dev examples."""
-
+
def __init__(self):
self.errors: List[str] = []
self.warnings: List[str] = []
self.success: List[str] = []
-
+
def check_python_version(self) -> bool:
"""Check Python version compatibility."""
required_version = (3, 8)
current_version = sys.version_info[:2]
-
+
if current_version >= required_version:
- self.success.append(f"ā
Python {'.'.join(map(str, current_version))} (>= 3.8)")
+ self.success.append(
+ f"ā
Python {'.'.join(map(str, current_version))} (>= 3.8)"
+ )
return True
else:
- self.errors.append(f"ā Python {'.'.join(map(str, current_version))} < 3.8 (required)")
+ self.errors.append(
+ f"ā Python {'.'.join(map(str, current_version))} < 3.8 (required)"
+ )
return False
-
+
def check_required_packages(self) -> bool:
"""Check if required Python packages are installed."""
# Map package names to their import names
required_packages = {
- 'aiohttp': 'aiohttp',
- 'asyncio': 'asyncio',
- 'pydantic': 'pydantic',
- 'python-dotenv': 'dotenv',
- 'redis': 'redis'
+ "aiohttp": "aiohttp",
+ "asyncio": "asyncio",
+ "pydantic": "pydantic",
+ "python-dotenv": "dotenv",
+ "redis": "redis",
}
-
+
all_installed = True
for package_name, import_name in required_packages.items():
try:
@@ -65,147 +70,156 @@ def check_required_packages(self) -> bool:
except ImportError:
self.errors.append(f"ā Package '{package_name}' is not installed")
all_installed = False
-
+
return all_installed
-
+
def check_environment_variables(self) -> bool:
"""Check required environment variables."""
required_vars = [
- ('ARCADE_API_KEY', True),
- ('ARCADE_API_URL', False),
- ('ARCADE_TIMEOUT', False),
- ('ARCADE_MAX_RETRIES', False),
- ('FACT_LOG_LEVEL', False),
- ('FACT_CACHE_ENABLED', False)
+ ("ARCADE_API_KEY", True),
+ ("ARCADE_API_URL", False),
+ ("ARCADE_TIMEOUT", False),
+ ("ARCADE_MAX_RETRIES", False),
+ ("FACT_LOG_LEVEL", False),
+ ("FACT_CACHE_ENABLED", False),
]
-
+
all_set = True
for var_name, required in required_vars:
value = os.getenv(var_name)
if value:
- if var_name == 'ARCADE_API_KEY':
+ if var_name == "ARCADE_API_KEY":
display_value = f"{value[:8]}..." if len(value) > 8 else "***"
else:
display_value = value
self.success.append(f"ā
{var_name}={display_value}")
elif required:
- self.errors.append(f"ā Required environment variable '{var_name}' is not set")
+ self.errors.append(
+ f"ā Required environment variable '{var_name}' is not set"
+ )
all_set = False
else:
- self.warnings.append(f"ā ļø Optional environment variable '{var_name}' is not set")
-
+ self.warnings.append(
+ f"ā ļø Optional environment variable '{var_name}' is not set"
+ )
+
return all_set
-
+
def check_fact_framework(self) -> bool:
"""Check if FACT framework components are accessible."""
- fact_modules = [
- 'src.core.driver',
- 'src.cache.manager'
- ]
-
+ fact_modules = ["src.core.driver", "src.cache.manager"]
+
# Add parent directories to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
-
+
all_accessible = True
for module_name in fact_modules:
try:
importlib.import_module(module_name)
self.success.append(f"ā
FACT module '{module_name}' is accessible")
except ImportError as e:
- self.errors.append(f"ā FACT module '{module_name}' is not accessible: {e}")
+ self.errors.append(
+ f"ā FACT module '{module_name}' is not accessible: {e}"
+ )
all_accessible = False
-
+
return all_accessible
-
+
def check_file_structure(self) -> bool:
"""Check if required files and directories exist."""
base_dir = Path(__file__).parent
required_paths = [
- 'README.md',
- 'requirements.txt',
- '01_basic_integration/basic_arcade_client.py',
- 'config'
+ "README.md",
+ "requirements.txt",
+ "01_basic_integration/basic_arcade_client.py",
+ "config",
]
-
+
all_exist = True
for path_str in required_paths:
path = base_dir / path_str
if path.exists():
self.success.append(f"ā
Path '{path_str}' exists")
else:
- if path_str == 'config':
- self.warnings.append(f"ā ļø Directory '{path_str}' does not exist (will be created)")
+ if path_str == "config":
+ self.warnings.append(
+ f"ā ļø Directory '{path_str}' does not exist (will be created)"
+ )
else:
self.errors.append(f"ā Required path '{path_str}' does not exist")
all_exist = False
-
+
return all_exist
-
+
async def check_network_connectivity(self) -> bool:
"""Check network connectivity to Arcade.dev API."""
try:
import aiohttp
- api_url = os.getenv('ARCADE_API_URL', 'https://api.arcade.dev')
-
+
+ api_url = os.getenv("ARCADE_API_URL", "https://api.arcade.dev")
+
# List of endpoints to try (in order of preference)
- endpoints_to_try = [
- '/health',
- '/v1/health',
- '/status',
- '/'
- ]
-
+ endpoints_to_try = ["/health", "/v1/health", "/status", "/"]
+
async with aiohttp.ClientSession() as session:
for endpoint in endpoints_to_try:
try:
- async with session.get(f"{api_url}{endpoint}", timeout=10) as response:
+ async with session.get(
+ f"{api_url}{endpoint}", timeout=10
+ ) as response:
if response.status == 200:
- self.success.append(f"ā
Network connectivity to {api_url} is working (endpoint: {endpoint})")
+ self.success.append(
+ f"ā
Network connectivity to {api_url} is working (endpoint: {endpoint})"
+ )
return True
elif response.status == 404:
# 404 is expected for some endpoints, continue trying
continue
elif response.status in [401, 403]:
# Authentication errors indicate the API is reachable
- self.success.append(f"ā
Network connectivity to {api_url} is working (authentication required)")
+ self.success.append(
+ f"ā
Network connectivity to {api_url} is working (authentication required)"
+ )
return True
else:
- self.warnings.append(f"ā ļø {api_url}{endpoint} returned status {response.status}")
+ self.warnings.append(
+ f"ā ļø {api_url}{endpoint} returned status {response.status}"
+ )
continue
except asyncio.TimeoutError:
continue
except Exception as e:
continue
-
+
# If we get here, none of the endpoints worked
- self.warnings.append(f"ā ļø Could not verify API connectivity to {api_url} (all endpoints returned 404 or failed)")
- self.warnings.append(f"ā ļø This may be normal if the API requires authentication or has a different endpoint structure")
+ self.warnings.append(
+ f"ā ļø Could not verify API connectivity to {api_url} (all endpoints returned 404 or failed)"
+ )
+ self.warnings.append(
+ f"ā ļø This may be normal if the API requires authentication or has a different endpoint structure"
+ )
return False
-
+
except Exception as e:
self.warnings.append(f"ā ļø Could not verify network connectivity: {e}")
return False
-
+
def create_missing_directories(self):
"""Create missing directories."""
base_dir = Path(__file__).parent
- directories = [
- 'config',
- 'output',
- 'logs'
- ]
-
+ directories = ["config", "output", "logs"]
+
for dir_name in directories:
dir_path = base_dir / dir_name
if not dir_path.exists():
dir_path.mkdir(parents=True, exist_ok=True)
self.success.append(f"ā
Created directory '{dir_name}'")
-
+
async def run_verification(self) -> bool:
"""Run all verification checks."""
print("š Verifying Arcade.dev Examples Setup...\n")
-
+
checks = [
("Python Version", self.check_python_version),
("Required Packages", self.check_required_packages),
@@ -213,52 +227,54 @@ async def run_verification(self) -> bool:
("FACT Framework", self.check_fact_framework),
("File Structure", self.check_file_structure),
]
-
+
# Run synchronous checks
all_passed = True
for check_name, check_func in checks:
print(f"Checking {check_name}...")
passed = check_func()
all_passed = all_passed and passed
-
+
# Run async checks
print("Checking Network Connectivity...")
await self.check_network_connectivity()
-
+
# Create missing directories
print("Creating Missing Directories...")
self.create_missing_directories()
-
+
return all_passed
-
+
def print_summary(self):
"""Print verification summary."""
- print("\n" + "="*60)
+ print("\n" + "=" * 60)
print("SETUP VERIFICATION SUMMARY")
- print("="*60)
-
+ print("=" * 60)
+
if self.success:
print("\nā
PASSED:")
for item in self.success:
print(f" {item}")
-
+
if self.warnings:
print("\nā ļø WARNINGS:")
for item in self.warnings:
print(f" {item}")
-
+
if self.errors:
print("\nā ERRORS:")
for item in self.errors:
print(f" {item}")
-
- print("\n" + "="*60)
-
+
+ print("\n" + "=" * 60)
+
if self.errors:
print("ā Setup verification FAILED. Please fix the errors above.")
print("\nNext steps:")
print("1. Install missing packages: pip install -r requirements.txt")
- print("2. Copy .env.example to .env and configure your environment variables:")
+ print(
+ "2. Copy .env.example to .env and configure your environment variables:"
+ )
print(" cp .env.example .env")
print(" # Then edit .env with your actual API keys and configuration")
print("3. Ensure FACT framework is properly installed")
@@ -267,7 +283,9 @@ def print_summary(self):
else:
print("ā
Setup verification PASSED!")
if self.warnings:
- print("Note: There are warnings that should be addressed for optimal functionality.")
+ print(
+ "Note: There are warnings that should be addressed for optimal functionality."
+ )
print("\nYou can now run the Arcade.dev examples:")
print(" python 01_basic_integration/basic_arcade_client.py")
return True
@@ -276,12 +294,12 @@ def print_summary(self):
async def main():
"""Main verification function."""
verifier = SetupVerifier()
-
+
try:
success = await verifier.run_verification()
verifier.print_summary()
return 0 if success else 1
-
+
except KeyboardInterrupt:
print("\nā Verification interrupted by user")
return 1
@@ -292,4 +310,4 @@ async def main():
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/examples/tool_execution_demo.py b/examples/tool_execution_demo.py
index d7bcc96..ef2eabe 100644
--- a/examples/tool_execution_demo.py
+++ b/examples/tool_execution_demo.py
@@ -13,7 +13,7 @@
from typing import Dict, Any
# Add the project root to the path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# Import FACT system components
from src.tools.executor import ToolExecutor, create_tool_call
@@ -29,15 +29,15 @@
description="Perform basic mathematical calculations",
parameters={
"operation": {
- "type": "string",
+ "type": "string",
"description": "Mathematical operation",
- "enum": ["add", "subtract", "multiply", "divide"]
+ "enum": ["add", "subtract", "multiply", "divide"],
},
"a": {"type": "number", "description": "First number"},
- "b": {"type": "number", "description": "Second number"}
+ "b": {"type": "number", "description": "Second number"},
},
requires_auth=False,
- timeout_seconds=10
+ timeout_seconds=10,
)
def calculator_tool(operation: str, a: float, b: float) -> Dict[str, Any]:
"""Perform basic mathematical calculations."""
@@ -45,22 +45,22 @@ def calculator_tool(operation: str, a: float, b: float) -> Dict[str, Any]:
"add": lambda x, y: x + y,
"subtract": lambda x, y: x - y,
"multiply": lambda x, y: x * y,
- "divide": lambda x, y: x / y if y != 0 else None
+ "divide": lambda x, y: x / y if y != 0 else None,
}
-
+
if operation not in operations:
return {"error": f"Unknown operation: {operation}"}
-
+
if operation == "divide" and b == 0:
return {"error": "Division by zero"}
-
+
result = operations[operation](a, b)
-
+
return {
"operation": operation,
"operands": [a, b],
"result": result,
- "expression": f"{a} {operation} {b} = {result}"
+ "expression": f"{a} {operation} {b} = {result}",
}
@@ -68,31 +68,33 @@ def calculator_tool(operation: str, a: float, b: float) -> Dict[str, Any]:
name="Demo.TextProcessor",
description="Process and analyze text content",
parameters={
- "text": {
- "type": "string",
- "description": "Text to process",
- "maxLength": 1000
- },
+ "text": {"type": "string", "description": "Text to process", "maxLength": 1000},
"operations": {
"type": "array",
"description": "Operations to perform",
"items": {
"type": "string",
- "enum": ["uppercase", "lowercase", "reverse", "word_count", "char_count"]
+ "enum": [
+ "uppercase",
+ "lowercase",
+ "reverse",
+ "word_count",
+ "char_count",
+ ],
},
- "default": ["word_count"]
- }
+ "default": ["word_count"],
+ },
},
requires_auth=False,
- timeout_seconds=15
+ timeout_seconds=15,
)
def text_processor_tool(text: str, operations: list = None) -> Dict[str, Any]:
"""Process and analyze text content."""
if operations is None:
operations = ["word_count"]
-
+
results = {"original_text": text, "processed": {}}
-
+
for operation in operations:
if operation == "uppercase":
results["processed"]["uppercase"] = text.upper()
@@ -104,7 +106,7 @@ def text_processor_tool(text: str, operations: list = None) -> Dict[str, Any]:
results["processed"]["word_count"] = len(text.split())
elif operation == "char_count":
results["processed"]["char_count"] = len(text)
-
+
return results
@@ -115,31 +117,33 @@ def text_processor_tool(text: str, operations: list = None) -> Dict[str, Any]:
"data_type": {
"type": "string",
"description": "Type of data to generate",
- "enum": ["numbers", "names", "emails", "dates"]
+ "enum": ["numbers", "names", "emails", "dates"],
},
"count": {
"type": "integer",
"description": "Number of items to generate",
"minimum": 1,
"maximum": 100,
- "default": 10
+ "default": 10,
},
"format": {
"type": "string",
"description": "Output format",
"enum": ["list", "json", "csv"],
- "default": "list"
- }
+ "default": "list",
+ },
},
requires_auth=False,
- timeout_seconds=20
+ timeout_seconds=20,
)
-def data_generator_tool(data_type: str, count: int = 10, format: str = "list") -> Dict[str, Any]:
+def data_generator_tool(
+ data_type: str, count: int = 10, format: str = "list"
+) -> Dict[str, Any]:
"""Generate sample data sets for testing."""
import random
import string
from datetime import datetime, timedelta
-
+
generators = {
"numbers": lambda: [random.randint(1, 1000) for _ in range(count)],
"names": lambda: [f"User_{i+1}" for i in range(count)],
@@ -147,20 +151,16 @@ def data_generator_tool(data_type: str, count: int = 10, format: str = "list") -
"dates": lambda: [
(datetime.now() - timedelta(days=random.randint(0, 365))).isoformat()
for _ in range(count)
- ]
+ ],
}
-
+
if data_type not in generators:
return {"error": f"Unknown data type: {data_type}"}
-
+
data = generators[data_type]()
-
- result = {
- "data_type": data_type,
- "count": len(data),
- "format": format
- }
-
+
+ result = {"data_type": data_type, "count": len(data), "format": format}
+
if format == "list":
result["data"] = data
elif format == "json":
@@ -170,85 +170,77 @@ def data_generator_tool(data_type: str, count: int = 10, format: str = "list") -
result["data"] = "value\n" + "\n".join(map(str, data))
else:
result["data"] = f"{data_type}\n" + "\n".join(map(str, data))
-
+
return result
async def demo_basic_tool_execution():
"""Demonstrate basic tool execution without Arcade."""
print("\n=== Demo: Basic Tool Execution ===")
-
+
# Create tool executor (local execution only)
executor = ToolExecutor(
- arcade_client=None,
- enable_rate_limiting=True,
- max_calls_per_minute=60
+ arcade_client=None, enable_rate_limiting=True, max_calls_per_minute=60
)
-
+
# Test calculator tool
print("\n1. Testing Calculator Tool:")
calc_call = create_tool_call(
- "Demo.Calculator",
- {"operation": "multiply", "a": 15, "b": 7}
+ "Demo.Calculator", {"operation": "multiply", "a": 15, "b": 7}
)
-
+
result = await executor.execute_tool_call(calc_call)
print(f" Result: {result.data}")
print(f" Execution time: {result.execution_time_ms:.2f}ms")
-
+
# Test text processor tool
print("\n2. Testing Text Processor Tool:")
text_call = create_tool_call(
"Demo.TextProcessor",
{
"text": "Hello World! This is a test.",
- "operations": ["uppercase", "word_count", "char_count"]
- }
+ "operations": ["uppercase", "word_count", "char_count"],
+ },
)
-
+
result = await executor.execute_tool_call(text_call)
print(f" Original: {result.data['original_text']}")
print(f" Processed: {json.dumps(result.data['processed'], indent=2)}")
-
+
# Test data generator tool
print("\n3. Testing Data Generator Tool:")
data_call = create_tool_call(
- "Demo.DataGenerator",
- {"data_type": "emails", "count": 5, "format": "list"}
+ "Demo.DataGenerator", {"data_type": "emails", "count": 5, "format": "list"}
)
-
+
result = await executor.execute_tool_call(data_call)
print(f" Generated {result.data['count']} {result.data['data_type']}:")
- for email in result.data['data']:
+ for email in result.data["data"]:
print(f" - {email}")
async def demo_error_handling():
"""Demonstrate error handling in tool execution."""
print("\n=== Demo: Error Handling ===")
-
+
executor = ToolExecutor(arcade_client=None)
-
+
# Test division by zero
print("\n1. Testing Division by Zero:")
error_call = create_tool_call(
- "Demo.Calculator",
- {"operation": "divide", "a": 10, "b": 0}
+ "Demo.Calculator", {"operation": "divide", "a": 10, "b": 0}
)
-
+
result = await executor.execute_tool_call(error_call)
if result.success:
print(f" Error message: {result.data['error']}")
else:
print(f" Execution failed: {result.error}")
-
+
# Test invalid tool
print("\n2. Testing Invalid Tool:")
- invalid_call = create_tool_call(
- "NonExistent.Tool",
- {"param": "value"}
- )
-
+ invalid_call = create_tool_call("NonExistent.Tool", {"param": "value"})
+
result = await executor.execute_tool_call(invalid_call)
print(f" Error: {result.error}")
print(f" Status: {result.status_code}")
@@ -257,26 +249,26 @@ async def demo_error_handling():
async def demo_concurrent_execution():
"""Demonstrate concurrent tool execution."""
print("\n=== Demo: Concurrent Execution ===")
-
+
executor = ToolExecutor(arcade_client=None)
-
+
# Create multiple tool calls
tool_calls = [
- create_tool_call("Demo.Calculator", {"operation": "add", "a": i, "b": i*2})
+ create_tool_call("Demo.Calculator", {"operation": "add", "a": i, "b": i * 2})
for i in range(1, 6)
]
-
+
print(f"\nExecuting {len(tool_calls)} calculator operations concurrently...")
-
+
start_time = time.time()
results = await executor.execute_tool_calls(tool_calls)
execution_time = (time.time() - start_time) * 1000
-
+
print(f"Total execution time: {execution_time:.2f}ms")
print("Results:")
for i, result in enumerate(results):
if result.success:
- expr = result.data['expression']
+ expr = result.data["expression"]
print(f" {i+1}. {expr} (took {result.execution_time_ms:.1f}ms)")
else:
print(f" {i+1}. Error: {result.error}")
@@ -285,24 +277,21 @@ async def demo_concurrent_execution():
async def demo_rate_limiting():
"""Demonstrate rate limiting functionality."""
print("\n=== Demo: Rate Limiting ===")
-
+
# Create executor with low rate limit
executor = ToolExecutor(
arcade_client=None,
enable_rate_limiting=True,
- max_calls_per_minute=3 # Very low limit for demo
+ max_calls_per_minute=3, # Very low limit for demo
)
-
+
print("\nTesting rate limiting (max 3 calls per minute):")
-
+
for i in range(5):
- call = create_tool_call(
- "Demo.Calculator",
- {"operation": "add", "a": i, "b": 1}
- )
-
+ call = create_tool_call("Demo.Calculator", {"operation": "add", "a": i, "b": 1})
+
result = await executor.execute_tool_call(call)
-
+
if result.success:
print(f" Call {i+1}: Success - {result.data['result']}")
else:
@@ -312,34 +301,36 @@ async def demo_rate_limiting():
async def demo_metrics_collection():
"""Demonstrate metrics collection and reporting."""
print("\n=== Demo: Metrics Collection ===")
-
+
executor = ToolExecutor(arcade_client=None)
metrics_collector = get_metrics_collector()
-
+
# Execute several tools to generate metrics
print("\nExecuting tools to generate metrics...")
-
+
operations = [
{"operation": "add", "a": 10, "b": 5},
{"operation": "multiply", "a": 3, "b": 7},
{"operation": "subtract", "a": 20, "b": 8},
- {"operation": "divide", "a": 15, "b": 3}
+ {"operation": "divide", "a": 15, "b": 3},
]
-
+
for op in operations:
call = create_tool_call("Demo.Calculator", op)
result = await executor.execute_tool_call(call)
-
+
# Metrics are automatically collected by the executor
- print(f" {op['operation']}: {result.data.get('result') if result.success else 'Error'}")
-
+ print(
+ f" {op['operation']}: {result.data.get('result') if result.success else 'Error'}"
+ )
+
# Get system metrics
print("\nSystem Metrics:")
system_metrics = metrics_collector.get_system_metrics(time_window_minutes=5)
print(f" Total executions: {system_metrics.total_executions}")
print(f" Success rate: {100 - system_metrics.error_rate:.1f}%")
print(f" Average execution time: {system_metrics.average_execution_time:.2f}ms")
-
+
# Get tool-specific metrics
print("\nCalculator Tool Metrics:")
tool_metrics = metrics_collector.get_tool_metrics("Demo.Calculator")
@@ -351,36 +342,36 @@ async def demo_metrics_collection():
async def demo_available_tools():
"""Demonstrate tool discovery and information retrieval."""
print("\n=== Demo: Tool Discovery ===")
-
+
executor = ToolExecutor(arcade_client=None)
-
+
# Get list of available tools
tools = executor.get_available_tools()
-
+
print(f"\nFound {len(tools)} available tools:")
for tool in tools:
- function_info = tool['function']
+ function_info = tool["function"]
print(f"\n Tool: {function_info['name']}")
print(f" Description: {function_info['description']}")
-
+
# Show parameters
- params = function_info.get('parameters', {}).get('properties', {})
+ params = function_info.get("parameters", {}).get("properties", {})
if params:
print(" Parameters:")
for param_name, param_info in params.items():
- param_type = param_info.get('type', 'unknown')
- param_desc = param_info.get('description', 'No description')
+ param_type = param_info.get("type", "unknown")
+ param_desc = param_info.get("description", "No description")
print(f" - {param_name} ({param_type}): {param_desc}")
async def demo_arcade_integration():
"""Demonstrate Arcade.dev integration (simulated)."""
print("\n=== Demo: Arcade.dev Integration (Simulated) ===")
-
+
# Note: This would require actual Arcade.dev credentials
print("\nSimulating Arcade.dev integration...")
print("(In production, this would connect to actual Arcade.dev platform)")
-
+
# Create mock Arcade client for demo
class MockArcadeClient:
async def execute_tool(self, tool_name, arguments, **kwargs):
@@ -389,18 +380,17 @@ async def execute_tool(self, tool_name, arguments, **kwargs):
"success": True,
"data": f"Arcade executed {tool_name} with {arguments}",
"execution_environment": "secure_container",
- "execution_id": f"arcade_exec_{int(time.time())}"
+ "execution_id": f"arcade_exec_{int(time.time())}",
}
-
+
mock_client = MockArcadeClient()
executor = ToolExecutor(arcade_client=mock_client)
-
+
# Execute tool via "Arcade"
call = create_tool_call(
- "Demo.Calculator",
- {"operation": "multiply", "a": 6, "b": 9}
+ "Demo.Calculator", {"operation": "multiply", "a": 6, "b": 9}
)
-
+
result = await executor.execute_tool_call(call)
print(f" Arcade execution result: {result.data}")
print(f" Execution time: {result.execution_time_ms:.2f}ms")
@@ -410,7 +400,7 @@ async def main():
"""Run all demonstrations."""
print("FACT System Tool Execution Framework Demo")
print("=" * 50)
-
+
try:
await demo_basic_tool_execution()
await demo_error_handling()
@@ -419,16 +409,17 @@ async def main():
await demo_metrics_collection()
await demo_available_tools()
await demo_arcade_integration()
-
+
print("\n" + "=" * 50)
print("All demos completed successfully!")
-
+
except Exception as e:
print(f"\nDemo failed with error: {e}")
import traceback
+
traceback.print_exc()
if __name__ == "__main__":
# Run the demo
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/fact-memory/examples/basic_usage.py b/fact-memory/examples/basic_usage.py
index 3a5f904..f99d3cb 100644
--- a/fact-memory/examples/basic_usage.py
+++ b/fact-memory/examples/basic_usage.py
@@ -13,121 +13,123 @@
async def basic_memory_operations():
"""Demonstrate basic memory operations."""
-
+
# Initialize memory manager with FACT cache integration
memory_manager = FactMemoryManager(
cache_config={
"prefix": "fact_memory_demo",
"min_tokens": 50, # Lower for demo
"max_size": "1MB",
- "ttl_seconds": 3600
+ "ttl_seconds": 3600,
}
)
-
+
user_id = "demo_user_123"
-
+
print("=== FACT Memory System - Basic Usage Demo ===\n")
-
+
# 1. Add various types of memories
print("1. Adding memories...")
-
+
memories = [
{
"content": "User prefers dark mode interface with high contrast colors",
"memory_type": MemoryType.PREFERENCE,
- "tags": ["ui", "accessibility", "dark-mode"]
+ "tags": ["ui", "accessibility", "dark-mode"],
},
{
"content": "User is a software engineer working on AI/ML projects",
"memory_type": MemoryType.FACT,
- "tags": ["profession", "ai", "ml"]
+ "tags": ["profession", "ai", "ml"],
},
{
"content": "User frequently asks about Python optimization techniques",
"memory_type": MemoryType.BEHAVIOR,
- "tags": ["python", "optimization", "patterns"]
+ "tags": ["python", "optimization", "patterns"],
},
{
"content": "Always provide code examples with detailed explanations",
"memory_type": MemoryType.INSTRUCTION,
- "tags": ["communication", "code", "teaching"]
- }
+ "tags": ["communication", "code", "teaching"],
+ },
]
-
+
added_memories = []
for memory_data in memories:
memory = await memory_manager.add_memory(
user_id=user_id,
content=memory_data["content"],
memory_type=memory_data["memory_type"],
- tags=memory_data["tags"]
+ tags=memory_data["tags"],
)
added_memories.append(memory)
print(f"ā Added {memory_data['memory_type'].value}: {memory.id}")
-
+
print(f"\nAdded {len(added_memories)} memories for user {user_id}\n")
-
+
# 2. Search for memories
print("2. Searching memories...")
-
+
search_queries = [
"What are the user's interface preferences?",
"Tell me about the user's profession",
"How should I communicate with this user?",
- "What programming topics interest the user?"
+ "What programming topics interest the user?",
]
-
+
for query in search_queries:
print(f"\nQuery: '{query}'")
results = await memory_manager.search_memories(
- user_id=user_id,
- query=query,
- limit=3
+ user_id=user_id, query=query, limit=3
)
-
+
for i, memory in enumerate(results, 1):
print(f" {i}. [{memory.memory_type.value}] {memory.content[:60]}...")
- print(f" Relevance: {memory.relevance_score:.2f} | Tags: {', '.join(memory.tags)}")
-
+ print(
+ f" Relevance: {memory.relevance_score:.2f} | Tags: {', '.join(memory.tags)}"
+ )
+
# 3. Get all memories by type
print("\n3. Retrieving memories by type...")
-
+
for memory_type in MemoryType:
memories = await memory_manager.get_memories_by_type(user_id, memory_type)
print(f"{memory_type.value.title()}: {len(memories)} memories")
-
+
# 4. Update a memory
print("\n4. Updating a memory...")
-
- preference_memory = next(m for m in added_memories if m.memory_type == MemoryType.PREFERENCE)
+
+ preference_memory = next(
+ m for m in added_memories if m.memory_type == MemoryType.PREFERENCE
+ )
updated_memory = await memory_manager.update_memory(
user_id=user_id,
memory_id=preference_memory.id,
content="User prefers dark mode with blue accent colors and large fonts for accessibility",
- tags=["ui", "accessibility", "dark-mode", "fonts", "colors"]
+ tags=["ui", "accessibility", "dark-mode", "fonts", "colors"],
)
print(f"ā Updated memory: {updated_memory.id}")
-
+
# 5. Get memory statistics
print("\n5. Memory statistics...")
-
+
stats = await memory_manager.get_user_stats(user_id)
print(f"Total memories: {stats['total_memories']}")
print(f"Memory types: {stats['memory_by_type']}")
print(f"Cache hit rate: {stats['cache_hit_rate']:.1%}")
print(f"Average access time: {stats['avg_access_time_ms']}ms")
-
+
return added_memories
async def advanced_search_examples():
"""Demonstrate advanced search capabilities."""
-
+
memory_manager = FactMemoryManager()
user_id = "advanced_user_456"
-
+
print("\n=== Advanced Search Examples ===\n")
-
+
# Add some complex memories for advanced search
complex_memories = [
"User is building a machine learning pipeline using Python and scikit-learn for customer churn prediction",
@@ -136,48 +138,45 @@ async def advanced_search_examples():
"Uses VS Code with Copilot extension and prefers vim keybindings",
"Interested in functional programming concepts, especially in JavaScript and Python",
"Team lead responsible for code reviews and architecture decisions",
- "Prefers Test-Driven Development approach with pytest for Python projects"
+ "Prefers Test-Driven Development approach with pytest for Python projects",
]
-
+
for i, content in enumerate(complex_memories):
await memory_manager.add_memory(
user_id=user_id,
content=content,
memory_type=MemoryType.FACT,
- tags=[f"tech_{i}", "development"]
+ tags=[f"tech_{i}", "development"],
)
-
+
# Advanced search scenarios
search_scenarios = [
{
"query": "What technologies does the user work with?",
- "description": "Technology stack inquiry"
+ "description": "Technology stack inquiry",
},
{
"query": "How does the user prefer to develop software?",
- "description": "Development methodology preferences"
+ "description": "Development methodology preferences",
},
{
"query": "What is the user's role and responsibilities?",
- "description": "Professional role inquiry"
+ "description": "Professional role inquiry",
},
{
"query": "Tell me about the user's deployment preferences",
- "description": "Infrastructure and deployment"
- }
+ "description": "Infrastructure and deployment",
+ },
]
-
+
for scenario in search_scenarios:
print(f"Scenario: {scenario['description']}")
print(f"Query: '{scenario['query']}'")
-
+
results = await memory_manager.search_memories(
- user_id=user_id,
- query=scenario["query"],
- limit=5,
- min_relevance=0.3
+ user_id=user_id, query=scenario["query"], limit=5, min_relevance=0.3
)
-
+
print(f"Found {len(results)} relevant memories:")
for i, memory in enumerate(results, 1):
print(f" {i}. {memory.content[:80]}...")
@@ -187,26 +186,23 @@ async def advanced_search_examples():
async def mcp_integration_example():
"""Demonstrate MCP server integration."""
-
+
print("=== MCP Integration Example ===\n")
-
+
# This would typically be done in a separate process/server
from fact_memory.mcp import FactMemoryMCPServer
-
+
# Initialize MCP server
mcp_server = FactMemoryMCPServer(
host="localhost",
port=8080,
- memory_config={
- "max_memories_per_user": 1000,
- "cache_ttl_seconds": 3600
- }
+ memory_config={"max_memories_per_user": 1000, "cache_ttl_seconds": 3600},
)
-
+
print("MCP Server initialized with tools:")
for tool_name in mcp_server.get_available_tools():
print(f" - {tool_name}")
-
+
# Simulate MCP client interactions
client_requests = [
{
@@ -215,45 +211,46 @@ async def mcp_integration_example():
"content": "User prefers concise explanations with practical examples",
"userId": "mcp_user_789",
"memoryType": "instruction",
- "tags": ["communication", "learning"]
- }
+ "tags": ["communication", "learning"],
+ },
},
{
"tool": "search-memories",
"params": {
"query": "How should I communicate with this user?",
"userId": "mcp_user_789",
- "limit": 3
- }
- }
+ "limit": 3,
+ },
+ },
]
-
+
print("\nSimulating MCP client requests:")
for request in client_requests:
print(f"\nRequest: {request['tool']}")
print(f"Params: {request['params']}")
-
+
# Simulate tool execution
response = await mcp_server.handle_tool_request(
- tool_name=request["tool"],
- parameters=request["params"]
+ tool_name=request["tool"], parameters=request["params"]
)
-
+
print(f"Response: {response['content'][0]['text'][:100]}...")
if __name__ == "__main__":
+
async def main():
try:
await basic_memory_operations()
await advanced_search_examples()
await mcp_integration_example()
-
+
print("\nā All examples completed successfully!")
-
+
except Exception as e:
print(f"\nā Error running examples: {e}")
import traceback
+
traceback.print_exc()
-
- asyncio.run(main())
\ No newline at end of file
+
+ asyncio.run(main())
diff --git a/fact-memory/examples/mcp_client.py b/fact-memory/examples/mcp_client.py
index fdf21ec..cbf3baf 100644
--- a/fact-memory/examples/mcp_client.py
+++ b/fact-memory/examples/mcp_client.py
@@ -16,6 +16,7 @@
@dataclass
class MCPResponse:
"""Represents an MCP tool response."""
+
content: List[Dict[str, Any]]
metadata: Dict[str, Any] = None
@@ -23,30 +24,30 @@ class MCPResponse:
class FactMemoryMCPClient:
"""
MCP client for FACT Memory System.
-
+
This client demonstrates how to interact with the FACT Memory MCP server
using the standard MCP protocol. It's compatible with existing Mem0 workflows.
"""
-
+
def __init__(self, server_url: str = "http://localhost:8080", api_key: str = None):
self.server_url = server_url
self.api_key = api_key
self.headers = {
"Content-Type": "application/json",
- "Authorization": f"Bearer {api_key}" if api_key else None
+ "Authorization": f"Bearer {api_key}" if api_key else None,
}
-
+
async def tool(self, tool_name: str, parameters: Dict[str, Any]) -> MCPResponse:
"""
Execute an MCP tool on the FACT Memory server.
-
+
This method simulates the MCP protocol interaction.
In a real implementation, this would make HTTP requests to the MCP server.
"""
# Simulate MCP tool execution
print(f"š” MCP Tool Call: {tool_name}")
print(f" Parameters: {json.dumps(parameters, indent=2)}")
-
+
# Simulate different tool responses
if tool_name == "add-memory":
return await self._handle_add_memory(parameters)
@@ -60,13 +61,13 @@ async def tool(self, tool_name: str, parameters: Dict[str, Any]) -> MCPResponse:
return await self._handle_get_stats(parameters)
else:
raise ValueError(f"Unknown tool: {tool_name}")
-
+
async def _handle_add_memory(self, params: Dict[str, Any]) -> MCPResponse:
"""Simulate add-memory tool response."""
memory_id = f"mem_{hash(params['content']) % 1000000:06d}"
-
+
response_text = f"Memory added successfully. ID: {memory_id}"
-
+
return MCPResponse(
content=[{"type": "text", "text": response_text}],
metadata={
@@ -74,49 +75,55 @@ async def _handle_add_memory(self, params: Dict[str, Any]) -> MCPResponse:
"userId": params["userId"],
"memoryType": params.get("memoryType", "fact"),
"tokenCount": len(params["content"].split()) * 1.3, # Rough estimate
- "timestamp": "2024-01-15T10:30:00Z"
- }
+ "timestamp": "2024-01-15T10:30:00Z",
+ },
)
-
+
async def _handle_search_memories(self, params: Dict[str, Any]) -> MCPResponse:
"""Simulate search-memories tool response."""
# Simulate search results based on query
query = params["query"].lower()
user_id = params["userId"]
-
+
# Mock search results based on common queries
mock_memories = []
-
+
if "preference" in query or "interface" in query:
- mock_memories.append({
- "content": "User prefers dark mode interface with high contrast",
- "relevance": 0.95,
- "type": "preference",
- "tags": ["ui", "accessibility"],
- "created": "2024-01-15T10:30:00Z"
- })
-
+ mock_memories.append(
+ {
+ "content": "User prefers dark mode interface with high contrast",
+ "relevance": 0.95,
+ "type": "preference",
+ "tags": ["ui", "accessibility"],
+ "created": "2024-01-15T10:30:00Z",
+ }
+ )
+
if "profession" in query or "work" in query:
- mock_memories.append({
- "content": "User is a software engineer working on AI/ML projects",
- "relevance": 0.87,
- "type": "fact",
- "tags": ["profession", "ai"],
- "created": "2024-01-14T15:20:00Z"
- })
-
+ mock_memories.append(
+ {
+ "content": "User is a software engineer working on AI/ML projects",
+ "relevance": 0.87,
+ "type": "fact",
+ "tags": ["profession", "ai"],
+ "created": "2024-01-14T15:20:00Z",
+ }
+ )
+
if "communication" in query or "explain" in query:
- mock_memories.append({
- "content": "Always provide code examples with detailed explanations",
- "relevance": 0.78,
- "type": "instruction",
- "tags": ["communication", "code"],
- "created": "2024-01-13T09:15:00Z"
- })
-
+ mock_memories.append(
+ {
+ "content": "Always provide code examples with detailed explanations",
+ "relevance": 0.78,
+ "type": "instruction",
+ "tags": ["communication", "code"],
+ "created": "2024-01-13T09:15:00Z",
+ }
+ )
+
# Format response text
response_lines = []
- for memory in mock_memories[:params.get("limit", 10)]:
+ for memory in mock_memories[: params.get("limit", 10)]:
response_lines.append(
f"Memory: {memory['content']}\n"
f"Relevance: {memory['relevance']}\n"
@@ -124,9 +131,9 @@ async def _handle_search_memories(self, params: Dict[str, Any]) -> MCPResponse:
f"Tags: {', '.join(memory['tags'])}\n"
f"Created: {memory['created']}\n---"
)
-
+
response_text = "\n".join(response_lines)
-
+
return MCPResponse(
content=[{"type": "text", "text": response_text}],
metadata={
@@ -134,14 +141,14 @@ async def _handle_search_memories(self, params: Dict[str, Any]) -> MCPResponse:
"searchTime": "45ms",
"cacheHit": True,
"query": params["query"],
- "userId": user_id
- }
+ "userId": user_id,
+ },
)
-
+
async def _handle_get_memories(self, params: Dict[str, Any]) -> MCPResponse:
"""Simulate get-memories tool response."""
user_id = params["userId"]
-
+
# Mock user memories
memories = [
{
@@ -149,17 +156,17 @@ async def _handle_get_memories(self, params: Dict[str, Any]) -> MCPResponse:
"content": "User prefers dark mode interface",
"type": "preference",
"created": "2024-01-15T10:30:00Z",
- "tags": ["ui", "accessibility"]
+ "tags": ["ui", "accessibility"],
},
{
- "id": "mem_001235",
+ "id": "mem_001235",
"content": "User is a Python developer",
"type": "fact",
"created": "2024-01-14T15:20:00Z",
- "tags": ["profession", "python"]
- }
+ "tags": ["profession", "python"],
+ },
]
-
+
response_lines = []
for memory in memories:
response_lines.append(
@@ -169,45 +176,45 @@ async def _handle_get_memories(self, params: Dict[str, Any]) -> MCPResponse:
f"Created: {memory['created']}\n"
f"Tags: {', '.join(memory['tags'])}\n---"
)
-
+
response_text = "\n".join(response_lines)
-
+
return MCPResponse(
content=[{"type": "text", "text": response_text}],
metadata={
"totalMemories": len(memories),
"userId": user_id,
"page": 1,
- "hasMore": False
- }
+ "hasMore": False,
+ },
)
-
+
async def _handle_delete_memory(self, params: Dict[str, Any]) -> MCPResponse:
"""Simulate delete-memory tool response."""
memory_id = params["memoryId"]
user_id = params["userId"]
-
+
response_text = f"Memory {memory_id} deleted successfully for user {user_id}"
-
+
return MCPResponse(
content=[{"type": "text", "text": response_text}],
metadata={
"memoryId": memory_id,
"userId": user_id,
- "deletedAt": "2024-01-15T10:30:00Z"
- }
+ "deletedAt": "2024-01-15T10:30:00Z",
+ },
)
-
+
async def _handle_get_stats(self, params: Dict[str, Any]) -> MCPResponse:
"""Simulate get-memory-stats tool response."""
user_id = params["userId"]
-
+
response_text = (
f"Memory Statistics for user: {user_id}\n"
f"Total memories: 45\n"
f"Memory types:\n"
f"- Preferences: 12\n"
- f"- Facts: 18\n"
+ f"- Facts: 18\n"
f"- Context: 8\n"
f"- Behavior: 5\n"
f"- Instructions: 2\n"
@@ -215,7 +222,7 @@ async def _handle_get_stats(self, params: Dict[str, Any]) -> MCPResponse:
f"Cache hit rate: 94.2%\n"
f"Average access time: 28ms"
)
-
+
return MCPResponse(
content=[{"type": "text", "text": response_text}],
metadata={
@@ -226,93 +233,91 @@ async def _handle_get_stats(self, params: Dict[str, Any]) -> MCPResponse:
"fact": 18,
"context": 8,
"behavior": 5,
- "instruction": 2
+ "instruction": 2,
},
"totalSizeBytes": 128409,
"cacheHitRate": 0.942,
- "avgAccessTimeMs": 28
- }
+ "avgAccessTimeMs": 28,
+ },
)
async def demonstrate_mcp_compatibility():
"""Demonstrate MCP compatibility with Mem0-style usage."""
-
+
print("=== FACT Memory MCP Client Demo ===\n")
-
+
# Initialize MCP client
client = FactMemoryMCPClient(
- server_url="http://localhost:8080",
- api_key="your_fact_memory_api_key"
+ server_url="http://localhost:8080", api_key="your_fact_memory_api_key"
)
-
+
user_id = "demo_user_mcp"
-
+
# 1. Add memories (Mem0-compatible)
print("1. Adding memories via MCP...")
-
+
memories_to_add = [
{
"content": "User prefers concise code reviews with actionable feedback",
"userId": user_id,
"memoryType": "instruction",
- "tags": ["code-review", "communication"]
+ "tags": ["code-review", "communication"],
},
{
"content": "User works with React, TypeScript, and Node.js daily",
"userId": user_id,
"memoryType": "fact",
- "tags": ["technology", "frontend", "backend"]
+ "tags": ["technology", "frontend", "backend"],
},
{
"content": "User asks for optimization tips during code discussions",
"userId": user_id,
"memoryType": "behavior",
- "tags": ["optimization", "learning"]
- }
+ "tags": ["optimization", "learning"],
+ },
]
-
+
for memory_data in memories_to_add:
response = await client.tool("add-memory", memory_data)
print(f"ā {response.content[0]['text']}")
-
+
print()
-
+
# 2. Search memories
print("2. Searching memories...")
-
- search_response = await client.tool("search-memories", {
- "query": "How should I provide feedback to this user?",
- "userId": user_id,
- "limit": 5,
- "minRelevance": 0.3
- })
-
+
+ search_response = await client.tool(
+ "search-memories",
+ {
+ "query": "How should I provide feedback to this user?",
+ "userId": user_id,
+ "limit": 5,
+ "minRelevance": 0.3,
+ },
+ )
+
print("Search Results:")
print(search_response.content[0]["text"])
print(f"Search completed in {search_response.metadata['searchTime']}")
print()
-
+
# 3. Get all memories
print("3. Retrieving all memories...")
-
- all_memories_response = await client.tool("get-memories", {
- "userId": user_id,
- "limit": 10,
- "sortBy": "created"
- })
-
+
+ all_memories_response = await client.tool(
+ "get-memories", {"userId": user_id, "limit": 10, "sortBy": "created"}
+ )
+
print("All Memories:")
print(all_memories_response.content[0]["text"])
print()
-
+
# 4. Get memory statistics
print("4. Getting memory statistics...")
-
- stats_response = await client.tool("get-memory-stats", {
- "userId": user_id
- })
-
+
+ stats_response = await client.tool("get-memory-stats", {"userId": user_id})
+
print("Memory Statistics:")
print(stats_response.content[0]["text"])
print()
@@ -320,72 +325,76 @@ async def demonstrate_mcp_compatibility():
async def demonstrate_advanced_mcp_features():
"""Demonstrate FACT Memory's enhanced MCP features."""
-
+
print("=== Advanced FACT Memory Features ===\n")
-
+
client = FactMemoryMCPClient()
user_id = "advanced_user_mcp"
-
+
# Add memory with metadata and tags
print("1. Adding memory with rich metadata...")
-
- enhanced_memory = await client.tool("add-memory", {
- "content": "User prefers functional programming patterns in JavaScript and Python",
- "userId": user_id,
- "memoryType": "preference",
- "tags": ["programming", "functional", "javascript", "python"],
- "metadata": {
- "source": "code_review_session",
- "confidence": 0.95,
- "context": "discussing async/await patterns",
- "related_topics": ["promises", "async", "map", "filter", "reduce"]
- }
- })
-
+
+ enhanced_memory = await client.tool(
+ "add-memory",
+ {
+ "content": "User prefers functional programming patterns in JavaScript and Python",
+ "userId": user_id,
+ "memoryType": "preference",
+ "tags": ["programming", "functional", "javascript", "python"],
+ "metadata": {
+ "source": "code_review_session",
+ "confidence": 0.95,
+ "context": "discussing async/await patterns",
+ "related_topics": ["promises", "async", "map", "filter", "reduce"],
+ },
+ },
+ )
+
print(f"ā Enhanced memory added: {enhanced_memory.metadata['memoryId']}")
print()
-
+
# Advanced search with filtering
print("2. Advanced search with type filtering...")
-
- filtered_search = await client.tool("search-memories", {
- "query": "programming preferences and patterns",
- "userId": user_id,
- "memoryType": "preference",
- "tags": ["programming", "functional"],
- "minRelevance": 0.5,
- "limit": 3
- })
-
+
+ filtered_search = await client.tool(
+ "search-memories",
+ {
+ "query": "programming preferences and patterns",
+ "userId": user_id,
+ "memoryType": "preference",
+ "tags": ["programming", "functional"],
+ "minRelevance": 0.5,
+ "limit": 3,
+ },
+ )
+
print("Filtered Search Results:")
print(filtered_search.content[0]["text"])
print()
-
+
# Batch operations (FACT Memory enhancement)
print("3. Memory statistics with detailed breakdown...")
-
- detailed_stats = await client.tool("get-memory-stats", {
- "userId": user_id
- })
-
+
+ detailed_stats = await client.tool("get-memory-stats", {"userId": user_id})
+
stats_data = detailed_stats.metadata
print(f"Cache Performance:")
print(f" Hit Rate: {stats_data['cacheHitRate']:.1%}")
print(f" Avg Access Time: {stats_data['avgAccessTimeMs']}ms")
print(f"Memory Distribution:")
- for mem_type, count in stats_data['memoryByType'].items():
+ for mem_type, count in stats_data["memoryByType"].items():
print(f" {mem_type.title()}: {count}")
print()
async def demonstrate_mem0_migration():
"""Show how existing Mem0 code works with FACT Memory."""
-
+
print("=== Mem0 Migration Compatibility ===\n")
-
+
# This is exactly how Mem0 clients work - no changes needed!
client = FactMemoryMCPClient()
-
+
print("Existing Mem0 client code (unchanged):")
print("""
# Original Mem0 usage
@@ -399,32 +408,31 @@ async def demonstrate_mem0_migration():
"userId": "alice"
})
""")
-
+
# Execute the exact same code
print("\nExecuting with FACT Memory (100% compatible):")
-
- await client.tool("add-memory", {
- "content": "User prefers dark mode",
- "userId": "alice"
- })
-
- results = await client.tool("search-memories", {
- "query": "interface preferences",
- "userId": "alice"
- })
-
+
+ await client.tool(
+ "add-memory", {"content": "User prefers dark mode", "userId": "alice"}
+ )
+
+ results = await client.tool(
+ "search-memories", {"query": "interface preferences", "userId": "alice"}
+ )
+
print("ā Perfect compatibility - existing Mem0 code works unchanged!")
print("ā But with FACT Memory's superior performance and caching!")
print()
if __name__ == "__main__":
+
async def main():
try:
await demonstrate_mcp_compatibility()
await demonstrate_advanced_mcp_features()
await demonstrate_mem0_migration()
-
+
print("š All MCP examples completed successfully!")
print("\nKey Benefits of FACT Memory over Mem0:")
print("⢠3-5x faster response times (cache-based)")
@@ -432,10 +440,11 @@ async def main():
print("⢠Enhanced memory types and metadata")
print("⢠Superior semantic understanding")
print("⢠Integrated with FACT SDK ecosystem")
-
+
except Exception as e:
print(f"\nā Error running MCP examples: {e}")
import traceback
+
traceback.print_exc()
-
- asyncio.run(main())
\ No newline at end of file
+
+ asyncio.run(main())
diff --git a/fact-memory/examples/performance_comparison.py b/fact-memory/examples/performance_comparison.py
index 019d336..7b40cbc 100644
--- a/fact-memory/examples/performance_comparison.py
+++ b/fact-memory/examples/performance_comparison.py
@@ -16,6 +16,7 @@
@dataclass
class PerformanceMetrics:
"""Performance measurement results."""
+
operation: str
approach: str
response_times: List[float]
@@ -27,33 +28,35 @@ class PerformanceMetrics:
class VectorMemorySimulator:
"""Simulates traditional vector database memory system."""
-
+
def __init__(self):
self.memories = {}
self.embeddings_cache = {}
-
+
async def add_memory(self, user_id: str, content: str) -> str:
"""Simulate vector embedding and storage."""
# Simulate embedding generation (100-200ms)
await asyncio.sleep(0.15)
-
+
memory_id = f"vec_{hash(content) % 1000000:06d}"
self.memories[memory_id] = {
"user_id": user_id,
"content": content,
"embedding": [0.1] * 1536, # Mock embedding
- "created_at": time.time()
+ "created_at": time.time(),
}
return memory_id
-
- async def search_memories(self, user_id: str, query: str, limit: int = 10) -> List[Dict]:
+
+ async def search_memories(
+ self, user_id: str, query: str, limit: int = 10
+ ) -> List[Dict]:
"""Simulate vector similarity search."""
# Simulate query embedding generation
await asyncio.sleep(0.12)
-
+
# Simulate vector similarity computation
await asyncio.sleep(0.08)
-
+
# Return mock results
user_memories = [m for m in self.memories.values() if m["user_id"] == user_id]
return user_memories[:limit]
@@ -61,35 +64,37 @@ async def search_memories(self, user_id: str, query: str, limit: int = 10) -> Li
class FactMemorySimulator:
"""Simulates FACT Memory cache-based system."""
-
+
def __init__(self):
self.cache = {}
self.cache_hits = 0
self.total_requests = 0
-
+
async def add_memory(self, user_id: str, content: str) -> str:
"""Simulate prompt cache storage."""
# Simulate cache write (10-20ms)
await asyncio.sleep(0.015)
-
+
memory_id = f"fact_{hash(content) % 1000000:06d}"
cache_key = f"memory:{user_id}:{memory_id}"
-
+
self.cache[cache_key] = {
"user_id": user_id,
"content": content,
"created_at": time.time(),
- "access_count": 0
+ "access_count": 0,
}
return memory_id
-
- async def search_memories(self, user_id: str, query: str, limit: int = 10) -> List[Dict]:
+
+ async def search_memories(
+ self, user_id: str, query: str, limit: int = 10
+ ) -> List[Dict]:
"""Simulate cache-based search with LLM semantic understanding."""
self.total_requests += 1
-
+
# Check cache for similar queries
query_cache_key = f"search:{user_id}:{hash(query) % 10000}"
-
+
if query_cache_key in self.cache:
# Cache hit - super fast response
self.cache_hits += 1
@@ -99,40 +104,45 @@ async def search_memories(self, user_id: str, query: str, limit: int = 10) -> Li
await asyncio.sleep(0.065) # 65ms LLM processing
# Cache the result
self.cache[query_cache_key] = {"results": "cached"}
-
+
# Return mock results
- user_memories = [m for m in self.cache.values()
- if isinstance(m, dict) and m.get("user_id") == user_id]
+ user_memories = [
+ m
+ for m in self.cache.values()
+ if isinstance(m, dict) and m.get("user_id") == user_id
+ ]
return user_memories[:limit]
-
+
def get_cache_hit_rate(self) -> float:
return self.cache_hits / max(self.total_requests, 1)
-async def measure_performance(system, operation_func, iterations: int = 100) -> PerformanceMetrics:
+async def measure_performance(
+ system, operation_func, iterations: int = 100
+) -> PerformanceMetrics:
"""Measure performance of a specific operation."""
-
+
response_times = []
-
+
for i in range(iterations):
start_time = time.time()
await operation_func(system, i)
end_time = time.time()
-
+
response_times.append((end_time - start_time) * 1000) # Convert to ms
-
+
# Small delay between operations to simulate real usage
await asyncio.sleep(0.001)
-
+
avg_time = statistics.mean(response_times)
p95_time = statistics.quantiles(response_times, n=20)[18] # 95th percentile
-
+
cache_hit_rate = 0.0
- if hasattr(system, 'get_cache_hit_rate'):
+ if hasattr(system, "get_cache_hit_rate"):
cache_hit_rate = system.get_cache_hit_rate()
-
+
throughput = iterations / (sum(response_times) / 1000) # ops per second
-
+
return PerformanceMetrics(
operation=operation_func.__name__,
approach=system.__class__.__name__,
@@ -140,7 +150,7 @@ async def measure_performance(system, operation_func, iterations: int = 100) ->
avg_time=avg_time,
p95_time=p95_time,
cache_hit_rate=cache_hit_rate,
- throughput_ops_per_sec=throughput
+ throughput_ops_per_sec=throughput,
)
@@ -154,153 +164,172 @@ async def add_memory_test(system, iteration: int):
async def search_memory_test(system, iteration: int):
"""Test memory search performance."""
user_id = f"user_{iteration % 10}"
-
+
# Use repeating queries to simulate cache hits
queries = [
"What are the user's preferences?",
"Tell me about the user's work",
"How should I communicate with this user?",
"What technologies does the user use?",
- "What are the user's interests?"
+ "What are the user's interests?",
]
-
+
query = queries[iteration % len(queries)]
await system.search_memories(user_id, query)
async def run_performance_comparison():
"""Run comprehensive performance comparison."""
-
+
print("=== FACT Memory vs Vector Database Performance Comparison ===\n")
-
+
# Initialize systems
vector_system = VectorMemorySimulator()
fact_system = FactMemorySimulator()
-
+
# Warm up both systems with some initial data
print("š Warming up systems with initial data...")
-
+
for i in range(20):
await vector_system.add_memory(f"user_{i % 5}", f"Initial memory {i}")
await fact_system.add_memory(f"user_{i % 5}", f"Initial memory {i}")
-
+
print("ā Systems warmed up\n")
-
+
# Test scenarios
test_scenarios = [
("Add Memory", add_memory_test, 50),
- ("Search Memory", search_memory_test, 100)
+ ("Search Memory", search_memory_test, 100),
]
-
+
results = {}
-
+
for scenario_name, test_func, iterations in test_scenarios:
print(f"š Testing {scenario_name} ({iterations} iterations)...")
-
+
# Test Vector Database approach
print(f" Testing Vector Database approach...")
vector_metrics = await measure_performance(vector_system, test_func, iterations)
-
- # Test FACT Memory approach
+
+ # Test FACT Memory approach
print(f" Testing FACT Memory approach...")
fact_metrics = await measure_performance(fact_system, test_func, iterations)
-
- results[scenario_name] = {
- "vector": vector_metrics,
- "fact": fact_metrics
- }
-
+
+ results[scenario_name] = {"vector": vector_metrics, "fact": fact_metrics}
+
print(f"ā {scenario_name} testing completed\n")
-
+
# Display results
print("š Performance Comparison Results")
print("=" * 60)
-
+
for scenario_name, scenario_results in results.items():
vector_metrics = scenario_results["vector"]
fact_metrics = scenario_results["fact"]
-
+
print(f"\n{scenario_name}:")
- print(f"{'Metric':<25} {'Vector DB':<15} {'FACT Memory':<15} {'Improvement':<15}")
+ print(
+ f"{'Metric':<25} {'Vector DB':<15} {'FACT Memory':<15} {'Improvement':<15}"
+ )
print("-" * 75)
-
+
# Average response time
- improvement = ((vector_metrics.avg_time - fact_metrics.avg_time) / vector_metrics.avg_time) * 100
- print(f"{'Avg Response Time':<25} {vector_metrics.avg_time:<14.1f}ms {fact_metrics.avg_time:<14.1f}ms {improvement:<14.1f}%")
-
+ improvement = (
+ (vector_metrics.avg_time - fact_metrics.avg_time) / vector_metrics.avg_time
+ ) * 100
+ print(
+ f"{'Avg Response Time':<25} {vector_metrics.avg_time:<14.1f}ms {fact_metrics.avg_time:<14.1f}ms {improvement:<14.1f}%"
+ )
+
# P95 response time
- improvement = ((vector_metrics.p95_time - fact_metrics.p95_time) / vector_metrics.p95_time) * 100
- print(f"{'P95 Response Time':<25} {vector_metrics.p95_time:<14.1f}ms {fact_metrics.p95_time:<14.1f}ms {improvement:<14.1f}%")
-
+ improvement = (
+ (vector_metrics.p95_time - fact_metrics.p95_time) / vector_metrics.p95_time
+ ) * 100
+ print(
+ f"{'P95 Response Time':<25} {vector_metrics.p95_time:<14.1f}ms {fact_metrics.p95_time:<14.1f}ms {improvement:<14.1f}%"
+ )
+
# Throughput
- improvement = ((fact_metrics.throughput_ops_per_sec - vector_metrics.throughput_ops_per_sec) / vector_metrics.throughput_ops_per_sec) * 100
- print(f"{'Throughput (ops/sec)':<25} {vector_metrics.throughput_ops_per_sec:<14.1f} {fact_metrics.throughput_ops_per_sec:<14.1f} {improvement:<14.1f}%")
-
+ improvement = (
+ (
+ fact_metrics.throughput_ops_per_sec
+ - vector_metrics.throughput_ops_per_sec
+ )
+ / vector_metrics.throughput_ops_per_sec
+ ) * 100
+ print(
+ f"{'Throughput (ops/sec)':<25} {vector_metrics.throughput_ops_per_sec:<14.1f} {fact_metrics.throughput_ops_per_sec:<14.1f} {improvement:<14.1f}%"
+ )
+
# Cache hit rate (only for FACT Memory)
if fact_metrics.cache_hit_rate > 0:
- print(f"{'Cache Hit Rate':<25} {'N/A':<15} {fact_metrics.cache_hit_rate:<14.1%} {'N/A':<15}")
-
+ print(
+ f"{'Cache Hit Rate':<25} {'N/A':<15} {fact_metrics.cache_hit_rate:<14.1%} {'N/A':<15}"
+ )
+
return results
async def demonstrate_cache_benefits():
"""Demonstrate the benefits of prompt caching."""
-
+
print("\nš Cache Performance Benefits Demo")
print("=" * 50)
-
+
fact_system = FactMemorySimulator()
-
+
# Add some memories
user_id = "cache_demo_user"
for i in range(10):
await fact_system.add_memory(user_id, f"Memory content {i}")
-
+
# Test repeated searches (simulating real usage patterns)
common_queries = [
"What are the user's preferences?",
"Tell me about the user's work style",
- "How should I communicate with this user?"
+ "How should I communicate with this user?",
]
-
+
print(f"\nTesting repeated searches (simulating real usage)...")
search_times = []
-
+
for round_num in range(3):
print(f"\nRound {round_num + 1}:")
round_times = []
-
+
for query in common_queries:
start_time = time.time()
await fact_system.search_memories(user_id, query)
end_time = time.time()
-
+
response_time = (end_time - start_time) * 1000
round_times.append(response_time)
search_times.append(response_time)
-
+
cache_status = "HIT" if fact_system.get_cache_hit_rate() > 0 else "MISS"
- print(f" Query: '{query[:40]}...' - {response_time:.1f}ms ({cache_status})")
-
+ print(
+ f" Query: '{query[:40]}...' - {response_time:.1f}ms ({cache_status})"
+ )
+
avg_round_time = statistics.mean(round_times)
print(f" Round average: {avg_round_time:.1f}ms")
-
+
print(f"\nCache Performance Summary:")
print(f" Overall cache hit rate: {fact_system.get_cache_hit_rate():.1%}")
print(f" Average search time: {statistics.mean(search_times):.1f}ms")
print(f" First search (cold): {search_times[0]:.1f}ms")
print(f" Cached searches: {statistics.mean(search_times[3:]):.1f}ms")
-
+
cache_speedup = search_times[0] / statistics.mean(search_times[3:])
print(f" Cache speedup: {cache_speedup:.1f}x faster")
async def resource_usage_comparison():
"""Compare resource usage between approaches."""
-
+
print("\nš¾ Resource Usage Comparison")
print("=" * 40)
-
+
# Simulate resource usage
vector_db_resources = {
"memory_per_embedding": 1536 * 4, # 1536 dimensions * 4 bytes (float32)
@@ -308,64 +337,91 @@ async def resource_usage_comparison():
"cpu_per_search": 0.85, # High CPU for similarity computation
"storage_per_memory": 8192, # Bytes including metadata and indexing
}
-
+
fact_memory_resources = {
"memory_per_cache_entry": 512, # Efficient cache representation
"index_overhead_factor": 1.1, # Minimal indexing overhead
"cpu_per_search": 0.15, # Low CPU with cache hits
"storage_per_memory": 1024, # Compact storage format
}
-
+
num_memories = 10000
-
+
print(f"\nFor {num_memories:,} memories:")
print(f"{'Resource':<20} {'Vector DB':<15} {'FACT Memory':<15} {'Savings':<15}")
print("-" * 70)
-
+
# Memory usage
- vector_memory = (num_memories * vector_db_resources["memory_per_embedding"] *
- vector_db_resources["index_overhead_factor"]) / (1024 * 1024) # MB
- fact_memory = (num_memories * fact_memory_resources["memory_per_cache_entry"] *
- fact_memory_resources["index_overhead_factor"]) / (1024 * 1024) # MB
-
+ vector_memory = (
+ num_memories
+ * vector_db_resources["memory_per_embedding"]
+ * vector_db_resources["index_overhead_factor"]
+ ) / (
+ 1024 * 1024
+ ) # MB
+ fact_memory = (
+ num_memories
+ * fact_memory_resources["memory_per_cache_entry"]
+ * fact_memory_resources["index_overhead_factor"]
+ ) / (
+ 1024 * 1024
+ ) # MB
+
memory_savings = ((vector_memory - fact_memory) / vector_memory) * 100
- print(f"{'Memory Usage':<20} {vector_memory:<14.1f}MB {fact_memory:<14.1f}MB {memory_savings:<14.1f}%")
-
+ print(
+ f"{'Memory Usage':<20} {vector_memory:<14.1f}MB {fact_memory:<14.1f}MB {memory_savings:<14.1f}%"
+ )
+
# Storage usage
- vector_storage = (num_memories * vector_db_resources["storage_per_memory"]) / (1024 * 1024) # MB
- fact_storage = (num_memories * fact_memory_resources["storage_per_memory"]) / (1024 * 1024) # MB
-
+ vector_storage = (num_memories * vector_db_resources["storage_per_memory"]) / (
+ 1024 * 1024
+ ) # MB
+ fact_storage = (num_memories * fact_memory_resources["storage_per_memory"]) / (
+ 1024 * 1024
+ ) # MB
+
storage_savings = ((vector_storage - fact_storage) / vector_storage) * 100
- print(f"{'Storage Usage':<20} {vector_storage:<14.1f}MB {fact_storage:<14.1f}MB {storage_savings:<14.1f}%")
-
+ print(
+ f"{'Storage Usage':<20} {vector_storage:<14.1f}MB {fact_storage:<14.1f}MB {storage_savings:<14.1f}%"
+ )
+
# CPU usage per search
- cpu_savings = ((vector_db_resources["cpu_per_search"] - fact_memory_resources["cpu_per_search"]) /
- vector_db_resources["cpu_per_search"]) * 100
- print(f"{'CPU per Search':<20} {vector_db_resources["cpu_per_search"]:<14.1f} {fact_memory_resources["cpu_per_search"]:<14.1f} {cpu_savings:<14.1f}%")
+ cpu_savings = (
+ (
+ vector_db_resources["cpu_per_search"]
+ - fact_memory_resources["cpu_per_search"]
+ )
+ / vector_db_resources["cpu_per_search"]
+ ) * 100
+ print(
+ f"{'CPU per Search':<20} {vector_db_resources["cpu_per_search"]:<14.1f} {fact_memory_resources["cpu_per_search"]:<14.1f} {cpu_savings:<14.1f}%"
+ )
if __name__ == "__main__":
+
async def main():
try:
results = await run_performance_comparison()
await demonstrate_cache_benefits()
await resource_usage_comparison()
-
+
print("\n" + "=" * 60)
print("šÆ Key Performance Advantages of FACT Memory:")
print("⢠3-5x faster response times through intelligent caching")
print("⢠85%+ cache hit rates for typical usage patterns")
- print("⢠70%+ reduction in memory and storage requirements")
+ print("⢠70%+ reduction in memory and storage requirements")
print("⢠80%+ reduction in CPU usage per search")
print("⢠No vector database infrastructure needed")
print("⢠Native LLM semantic understanding")
print("⢠Seamless FACT SDK integration")
-
+
print("\nā
Performance comparison completed successfully!")
-
+
except Exception as e:
print(f"\nā Error in performance comparison: {e}")
import traceback
+
traceback.print_exc()
-
- asyncio.run(main())
\ No newline at end of file
+
+ asyncio.run(main())
diff --git a/fact-memory/src/hello_mcp_server.py b/fact-memory/src/hello_mcp_server.py
index 0e8e421..4da5103 100755
--- a/fact-memory/src/hello_mcp_server.py
+++ b/fact-memory/src/hello_mcp_server.py
@@ -21,6 +21,7 @@
except ImportError:
print("FastMCP not available. Install with: pip install fastmcp")
import sys
+
sys.exit(1)
# Configure logging
@@ -41,83 +42,83 @@
- server_info: Provides information about this MCP server
This server demonstrates FastMCP best practices and integration patterns.
- """
+ """,
)
+
@mcp.tool()
async def hello(ctx: Context = None) -> Dict[str, Any]:
"""
Simple hello world tool that returns a greeting message.
-
+
Args:
ctx: FastMCP context for logging and resource access
-
+
Returns:
Dictionary containing a simple hello message and timestamp
"""
if ctx:
await ctx.info("Executing hello tool")
-
+
result = {
"message": "Hello from FACT Memory MCP Server!",
"timestamp": datetime.now().isoformat(),
"server": "FACT Hello World Server",
- "status": "active"
+ "status": "active",
}
-
+
if ctx:
await ctx.info("Hello tool executed successfully")
-
+
return result
+
@mcp.tool()
-async def greet(
- name: str,
- ctx: Context = None
-) -> Dict[str, Any]:
+async def greet(name: str, ctx: Context = None) -> Dict[str, Any]:
"""
Personalized greeting tool that takes a name parameter.
-
+
Args:
name: The name to include in the greeting
ctx: FastMCP context for logging and resource access
-
+
Returns:
Dictionary containing a personalized greeting message
"""
if ctx:
await ctx.info(f"Executing greet tool for name: {name}")
-
+
# Validate input
if not name or not name.strip():
if ctx:
await ctx.error("Empty name provided to greet tool")
return {
"error": "Name parameter is required and cannot be empty",
- "status": "error"
+ "status": "error",
}
-
+
# Clean the name input
clean_name = name.strip().title()
-
+
result = {
"message": f"Hello, {clean_name}! Welcome to the FACT Memory System.",
"greeting_for": clean_name,
"timestamp": datetime.now().isoformat(),
"server": "FACT Hello World Server",
- "status": "success"
+ "status": "success",
}
-
+
if ctx:
await ctx.info(f"Greet tool executed successfully for: {clean_name}")
-
+
return result
+
@mcp.resource("fact://server_info")
async def get_server_info() -> Dict[str, Any]:
"""
Provide comprehensive information about this MCP server.
-
+
Returns:
Dictionary containing server metadata, capabilities, and status
"""
@@ -131,86 +132,79 @@ async def get_server_info() -> Dict[str, Any]:
{
"name": "hello",
"description": "Returns a simple hello message",
- "parameters": []
+ "parameters": [],
},
{
- "name": "greet",
+ "name": "greet",
"description": "Returns a personalized greeting",
"parameters": [
{
"name": "name",
"type": "string",
"required": True,
- "description": "Name to include in greeting"
+ "description": "Name to include in greeting",
}
- ]
- }
+ ],
+ },
],
"resources": [
{
"name": "server_info",
- "description": "Server information and metadata"
+ "description": "Server information and metadata",
}
- ]
+ ],
},
"status": {
"running": True,
"uptime_start": datetime.now().isoformat(),
- "health": "healthy"
+ "health": "healthy",
},
"integration": {
"fact_memory_compatible": True,
"mcp_version": "1.0",
- "transport": "stdio"
+ "transport": "stdio",
},
"contact": {
"documentation": "fact-memory/docs/",
- "repository": "FACT Memory System"
- }
+ "repository": "FACT Memory System",
+ },
}
+
class HelloWorldMCPServer:
"""
Hello World MCP Server class for advanced configuration and lifecycle management.
"""
-
- def __init__(
- self,
- name: str = "FACT Hello World Server",
- debug: bool = False
- ):
+
+ def __init__(self, name: str = "FACT Hello World Server", debug: bool = False):
self.name = name
self.debug = debug
self.mcp = mcp
self._setup_logging()
-
+
def _setup_logging(self):
"""Configure logging based on debug setting."""
level = logging.DEBUG if self.debug else logging.INFO
logging.getLogger().setLevel(level)
-
+
if self.debug:
logger.info(f"Debug mode enabled for {self.name}")
-
+
def start_http(self, host: str = "localhost", port: int = 8080):
"""
Start the MCP server with HTTP transport.
-
+
Args:
host: Server host address
port: Server port number
"""
logger.info(f"Starting {self.name} on {host}:{port}")
try:
- self.mcp.run(
- transport="streamable-http",
- host=host,
- port=port
- )
+ self.mcp.run(transport="streamable-http", host=host, port=port)
except Exception as e:
logger.error(f"Failed to start HTTP server: {e}")
raise
-
+
def run_stdio(self):
"""
Run server with STDIO transport for CLI integration.
@@ -222,33 +216,34 @@ def run_stdio(self):
except Exception as e:
logger.error(f"Failed to start STDIO server: {e}")
raise
-
+
async def test_tools(self):
"""
Test all available tools for development and debugging.
"""
logger.info("Testing MCP tools...")
-
+
try:
# Test hello tool
hello_result = await hello()
logger.info(f"Hello tool result: {hello_result}")
-
+
# Test greet tool
greet_result = await greet("FastMCP Developer")
logger.info(f"Greet tool result: {greet_result}")
-
+
# Test resource
server_info = await get_server_info()
logger.info(f"Server info: {server_info}")
-
+
logger.info("All tools tested successfully!")
return True
-
+
except Exception as e:
logger.error(f"Tool testing failed: {e}")
return False
+
def main():
"""
Main entry point for the Hello World MCP Server.
@@ -256,48 +251,45 @@ def main():
"""
import sys
import argparse
-
+
parser = argparse.ArgumentParser(description="FACT Hello World MCP Server")
parser.add_argument(
- "--mode",
- choices=["stdio", "http", "test"],
+ "--mode",
+ choices=["stdio", "http", "test"],
default="stdio",
- help="Server mode: stdio (default), http, or test"
+ help="Server mode: stdio (default), http, or test",
)
parser.add_argument(
- "--host",
+ "--host",
default="localhost",
- help="Host address for HTTP mode (default: localhost)"
+ help="Host address for HTTP mode (default: localhost)",
)
parser.add_argument(
- "--port",
- type=int,
+ "--port",
+ type=int,
default=8080,
- help="Port number for HTTP mode (default: 8080)"
- )
- parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable debug logging"
+ help="Port number for HTTP mode (default: 8080)",
)
-
+ parser.add_argument("--debug", action="store_true", help="Enable debug logging")
+
args = parser.parse_args()
-
+
# Create server instance
server = HelloWorldMCPServer(debug=args.debug)
-
+
if args.mode == "stdio":
logger.info("Running in STDIO mode for MCP client integration")
server.run_stdio()
-
+
elif args.mode == "http":
logger.info(f"Running in HTTP mode on {args.host}:{args.port}")
server.start_http(args.host, args.port)
-
+
elif args.mode == "test":
logger.info("Running in test mode")
success = asyncio.run(server.test_tools())
sys.exit(0 if success else 1)
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/fact-memory/src/test_client.py b/fact-memory/src/test_client.py
index 4a70643..7335aa6 100644
--- a/fact-memory/src/test_client.py
+++ b/fact-memory/src/test_client.py
@@ -15,75 +15,79 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
+
async def test_server_directly():
"""
Test the server by importing and calling functions directly.
This simulates what would happen when the server receives MCP calls.
"""
logger.info("Testing server functions directly...")
-
+
# Import the server functions
from hello_mcp_server import hello, greet, get_server_info
-
+
try:
# Test hello tool
logger.info("Testing hello tool...")
hello_result = await hello()
logger.info(f"Hello result: {json.dumps(hello_result, indent=2)}")
-
+
# Test greet tool with valid input
logger.info("Testing greet tool with valid input...")
greet_result = await greet("FastMCP Developer")
logger.info(f"Greet result: {json.dumps(greet_result, indent=2)}")
-
+
# Test greet tool with empty input (error case)
logger.info("Testing greet tool with empty input...")
error_result = await greet("")
logger.info(f"Error result: {json.dumps(error_result, indent=2)}")
-
+
# Test server info resource
logger.info("Testing server info resource...")
server_info = await get_server_info()
logger.info(f"Server info: {json.dumps(server_info, indent=2)}")
-
+
logger.info("✅ All direct tests passed!")
return True
-
+
except Exception as e:
logger.error(f"ā Direct testing failed: {e}")
return False
+
async def test_with_fastmcp_client():
"""
Test using FastMCP client (if available).
This demonstrates the proper way to interact with MCP servers.
"""
logger.info("Testing with FastMCP client...")
-
+
try:
from fastmcp import Client
from hello_mcp_server import mcp
-
+
# Create client and connect to our server
async with Client(mcp) as client:
# Test hello tool
logger.info("Calling hello tool via MCP client...")
hello_response = await client.call_tool("hello", {})
logger.info(f"Hello response: {hello_response}")
-
+
# Test greet tool
logger.info("Calling greet tool via MCP client...")
greet_response = await client.call_tool("greet", {"name": "MCP Client"})
logger.info(f"Greet response: {greet_response}")
-
+
# Test resource access (Note: Resource access may not be available in all FastMCP client versions)
- logger.info("Skipping resource access test - method varies by FastMCP version")
+ logger.info(
+ "Skipping resource access test - method varies by FastMCP version"
+ )
# resource_response = await client.read_resource("fact://server_info")
# logger.info(f"Resource response: {resource_response}")
-
+
logger.info("✅ FastMCP client tests passed!")
return True
-
+
except ImportError:
logger.warning("FastMCP client not available for testing")
return False
@@ -91,89 +95,94 @@ async def test_with_fastmcp_client():
logger.error(f"ā FastMCP client testing failed: {e}")
return False
+
async def benchmark_performance():
"""
Simple performance benchmark for the server tools.
"""
logger.info("Running performance benchmark...")
-
+
from hello_mcp_server import hello, greet
import time
-
+
# Benchmark hello tool
start_time = time.time()
iterations = 1000
-
+
for i in range(iterations):
await hello()
-
+
hello_duration = time.time() - start_time
hello_avg = (hello_duration / iterations) * 1000 # ms per call
-
+
# Benchmark greet tool
start_time = time.time()
-
+
for i in range(iterations):
await greet(f"User{i}")
-
+
greet_duration = time.time() - start_time
greet_avg = (greet_duration / iterations) * 1000 # ms per call
-
+
logger.info(f"Performance Results ({iterations} iterations):")
logger.info(f" Hello tool: {hello_avg:.2f}ms per call")
logger.info(f" Greet tool: {greet_avg:.2f}ms per call")
logger.info(f" Total time: {hello_duration + greet_duration:.2f}s")
+
def validate_server_structure():
"""
Validate that the server follows FastMCP best practices.
"""
logger.info("Validating server structure...")
-
+
from hello_mcp_server import mcp, HelloWorldMCPServer
-
+
# Check server instance
assert mcp.name == "FACT Hello World Server", "Server name mismatch"
assert mcp.instructions, "Server instructions missing"
-
+
# Check tools are registered
tools = mcp._tool_manager._tools
assert "hello" in tools, "Hello tool not registered"
assert "greet" in tools, "Greet tool not registered"
-
+
# Check resources are registered
resources = mcp._resource_manager._resources
- assert any("server_info" in str(uri) for uri in resources), "Server info resource not registered"
-
+ assert any(
+ "server_info" in str(uri) for uri in resources
+ ), "Server info resource not registered"
+
# Check server class
server = HelloWorldMCPServer()
assert server.name == "FACT Hello World Server", "Server class name mismatch"
- assert hasattr(server, 'run_stdio'), "STDIO method missing"
- assert hasattr(server, 'start_http'), "HTTP method missing"
- assert hasattr(server, 'test_tools'), "Test method missing"
-
+ assert hasattr(server, "run_stdio"), "STDIO method missing"
+ assert hasattr(server, "start_http"), "HTTP method missing"
+ assert hasattr(server, "test_tools"), "Test method missing"
+
logger.info("✅ Server structure validation passed!")
+
async def main():
"""
Main test runner that executes all test scenarios.
"""
logger.info("š Starting FACT Hello World MCP Server Tests")
logger.info("=" * 60)
-
+
tests_passed = 0
total_tests = 0
-
+
# Test 1: Direct function calls
total_tests += 1
if await test_server_directly():
tests_passed += 1
-
+
# Test 2: FastMCP client integration
total_tests += 1
if await test_with_fastmcp_client():
tests_passed += 1
-
+
# Test 3: Server structure validation
total_tests += 1
try:
@@ -182,7 +191,7 @@ async def main():
logger.info("✅ Structure validation passed!")
except Exception as e:
logger.error(f"ā Structure validation failed: {e}")
-
+
# Test 4: Performance benchmark
total_tests += 1
try:
@@ -191,11 +200,11 @@ async def main():
logger.info("✅ Performance benchmark completed!")
except Exception as e:
logger.error(f"ā Performance benchmark failed: {e}")
-
+
# Summary
logger.info("=" * 60)
logger.info(f"š Test Summary: {tests_passed}/{total_tests} tests passed")
-
+
if tests_passed == total_tests:
logger.info("š All tests passed! Server is ready for use.")
return True
@@ -203,6 +212,7 @@ async def main():
logger.warning(f"ā ļø {total_tests - tests_passed} tests failed.")
return False
+
if __name__ == "__main__":
success = asyncio.run(main())
- exit(0 if success else 1)
\ No newline at end of file
+ exit(0 if success else 1)
diff --git a/fact-memory/src/test_mcp_integration.py b/fact-memory/src/test_mcp_integration.py
index 3f422bc..4952604 100644
--- a/fact-memory/src/test_mcp_integration.py
+++ b/fact-memory/src/test_mcp_integration.py
@@ -16,12 +16,13 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
+
async def test_mcp_stdio_protocol():
"""
Test the MCP server with proper MCP protocol initialization.
"""
logger.info("Testing MCP STDIO protocol compliance...")
-
+
try:
# Start the server as a subprocess
server_path = Path(__file__).parent / "hello_mcp_server.py"
@@ -31,12 +32,12 @@ async def test_mcp_stdio_protocol():
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=0
+ bufsize=0,
)
-
+
# Wait a moment for server to start
await asyncio.sleep(0.5)
-
+
# Send MCP initialization
init_request = {
"jsonrpc": "2.0",
@@ -44,24 +45,16 @@ async def test_mcp_stdio_protocol():
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
- "capabilities": {
- "roots": {
- "listChanged": True
- },
- "sampling": {}
- },
- "clientInfo": {
- "name": "test-client",
- "version": "1.0.0"
- }
- }
+ "capabilities": {"roots": {"listChanged": True}, "sampling": {}},
+ "clientInfo": {"name": "test-client", "version": "1.0.0"},
+ },
}
-
+
# Send the request
request_json = json.dumps(init_request) + "\n"
process.stdin.write(request_json)
process.stdin.flush()
-
+
# Wait for response with timeout
start_time = time.time()
response_line = ""
@@ -74,24 +67,26 @@ async def test_mcp_stdio_protocol():
if chunk:
response_line += chunk
# Look for complete JSON message
- if '\n' in response_line:
+ if "\n" in response_line:
# Get the first complete line
- lines = response_line.split('\n')
+ lines = response_line.split("\n")
response_line = lines[0]
break
await asyncio.sleep(0.01)
except:
break
-
+
# Clean up
process.terminate()
process.wait(timeout=2)
-
+
if response_line:
try:
response = json.loads(response_line.strip())
- logger.info(f"MCP Initialize Response: {json.dumps(response, indent=2)}")
-
+ logger.info(
+ f"MCP Initialize Response: {json.dumps(response, indent=2)}"
+ )
+
# Validate response structure
if "result" in response and "protocolVersion" in response["result"]:
logger.info("✅ MCP protocol initialization successful!")
@@ -106,142 +101,146 @@ async def test_mcp_stdio_protocol():
else:
logger.error("ā No response received from MCP server")
return False
-
+
except Exception as e:
logger.error(f"ā MCP protocol test failed: {e}")
return False
+
async def test_error_handling():
"""
Test error handling for invalid inputs.
"""
logger.info("Testing error handling...")
-
+
from hello_mcp_server import greet
-
+
try:
# Test empty name
result = await greet("")
assert "error" in result, "Empty name should return error"
logger.info("✅ Empty name error handling works")
-
+
# Test whitespace-only name
result = await greet(" ")
assert "error" in result, "Whitespace name should return error"
logger.info("✅ Whitespace name error handling works")
-
+
# Test valid name
result = await greet("Test User")
assert "message" in result, "Valid name should return success"
logger.info("✅ Valid name handling works")
-
+
return True
-
+
except Exception as e:
logger.error(f"ā Error handling test failed: {e}")
return False
+
async def test_resource_access():
"""
Test resource access functionality.
"""
logger.info("Testing resource access...")
-
+
from hello_mcp_server import get_server_info
-
+
try:
server_info = await get_server_info()
-
+
# Validate required fields
required_fields = ["name", "version", "capabilities", "status"]
for field in required_fields:
assert field in server_info, f"Missing required field: {field}"
-
+
# Validate capabilities structure
caps = server_info["capabilities"]
assert "tools" in caps, "Missing tools in capabilities"
assert "resources" in caps, "Missing resources in capabilities"
-
+
# Validate tool information
tools = caps["tools"]
tool_names = [tool["name"] for tool in tools]
assert "hello" in tool_names, "Missing hello tool"
assert "greet" in tool_names, "Missing greet tool"
-
+
logger.info("✅ Resource access validation passed")
return True
-
+
except Exception as e:
logger.error(f"ā Resource access test failed: {e}")
return False
+
async def test_performance_stress():
"""
Test performance under load.
"""
logger.info("Testing performance under stress...")
-
+
from hello_mcp_server import hello, greet
-
+
try:
# Test rapid-fire calls
start_time = time.time()
tasks = []
-
+
# Create 100 concurrent hello calls
for i in range(100):
tasks.append(hello())
-
+
# Create 100 concurrent greet calls
for i in range(100):
tasks.append(greet(f"User{i}"))
-
+
# Execute all tasks concurrently
results = await asyncio.gather(*tasks)
-
+
duration = time.time() - start_time
avg_time = (duration / len(tasks)) * 1000 # ms per call
-
+
logger.info(f"Stress test completed: {len(tasks)} calls in {duration:.2f}s")
logger.info(f"Average response time: {avg_time:.2f}ms per call")
-
+
# Validate all results are successful
success_count = 0
for result in results:
if "message" in result or "greeting_for" in result:
success_count += 1
-
+
success_rate = (success_count / len(results)) * 100
logger.info(f"Success rate: {success_rate:.1f}%")
-
+
if success_rate >= 99:
logger.info("✅ Stress test passed")
return True
else:
logger.warning("ā ļø Stress test had some failures")
return False
-
+
except Exception as e:
logger.error(f"ā Stress test failed: {e}")
return False
+
async def main():
"""
Run comprehensive MCP integration tests.
"""
logger.info("š Starting MCP Integration Tests")
logger.info("=" * 60)
-
+
tests = [
("MCP Protocol Compliance", test_mcp_stdio_protocol),
("Error Handling", test_error_handling),
("Resource Access", test_resource_access),
("Performance Stress", test_performance_stress),
]
-
+
passed = 0
total = len(tests)
-
+
for test_name, test_func in tests:
logger.info(f"\nš§Ŗ Running: {test_name}")
try:
@@ -252,10 +251,10 @@ async def main():
logger.error(f"ā {test_name} FAILED")
except Exception as e:
logger.error(f"ā {test_name} FAILED with exception: {e}")
-
+
logger.info("=" * 60)
logger.info(f"š Integration Test Summary: {passed}/{total} tests passed")
-
+
if passed == total:
logger.info("š All integration tests passed! MCP server is production ready.")
return True
@@ -263,6 +262,7 @@ async def main():
logger.warning(f"ā ļø {total - passed} integration tests failed.")
return False
+
if __name__ == "__main__":
success = asyncio.run(main())
- exit(0 if success else 1)
\ No newline at end of file
+ exit(0 if success else 1)
diff --git a/main.py b/main.py
index 1f46bf9..f9ecdae 100644
--- a/main.py
+++ b/main.py
@@ -12,7 +12,7 @@
import argparse
# Add src directory to Python path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
from src.core.cli import main as cli_main
from src.core.config import get_config
@@ -23,22 +23,22 @@ async def init_command():
"""Initialize the FACT system environment."""
try:
print("š Initializing FACT System...")
-
+
# Get configuration
config = get_config()
print(f"✅ Configuration loaded")
print(f" ⢠Database: {config.database_path}")
print(f" ⢠Model: {config.claude_model}")
-
+
# Initialize driver
driver = await get_driver(config)
print("✅ System initialized successfully")
print(" ⢠Database schema ready")
print(" ⢠Tools registered")
print(" ⢠Ready for queries")
-
+
return 0
-
+
except Exception as e:
print(f"ā Initialization failed: {e}")
return 1
@@ -48,17 +48,17 @@ async def demo_command():
"""Run a demonstration of the FACT system."""
try:
print("šŖ Running FACT System Demo...")
-
+
# Initialize system
driver = await get_driver()
-
+
# Demo queries
demo_queries = [
"Show me all companies in our database",
"What is TechCorp's latest revenue?",
- "List all financial records for Q1 2025"
+ "List all financial records for Q1 2025",
]
-
+
for i, query in enumerate(demo_queries, 1):
print(f"\nš Demo Query {i}: {query}")
try:
@@ -66,10 +66,10 @@ async def demo_command():
print(f"š Response: {response}")
except Exception as e:
print(f"ā Query failed: {e}")
-
+
print("\n✅ Demo completed")
return 0
-
+
except Exception as e:
print(f"ā Demo failed: {e}")
return 1
@@ -79,25 +79,21 @@ async def main():
"""Main entry point with command routing."""
parser = argparse.ArgumentParser(
description="FACT System - Fast-Access Cached Tools",
- formatter_class=argparse.RawDescriptionHelpFormatter
+ formatter_class=argparse.RawDescriptionHelpFormatter,
)
-
+
parser.add_argument(
"command",
nargs="?",
choices=["init", "demo", "interactive"],
default="interactive",
- help="Command to execute (default: interactive)"
- )
-
- parser.add_argument(
- "--query",
- type=str,
- help="Process a single query and exit"
+ help="Command to execute (default: interactive)",
)
-
+
+ parser.add_argument("--query", type=str, help="Process a single query and exit")
+
args = parser.parse_args()
-
+
try:
if args.command == "init":
return await init_command()
@@ -109,7 +105,7 @@ async def main():
else:
parser.print_help()
return 1
-
+
except KeyboardInterrupt:
print("\nš Interrupted by user")
return 0
@@ -121,7 +117,7 @@ async def main():
if __name__ == "__main__":
"""
Main entry point for the FACT system.
-
+
Usage:
python main.py # Interactive mode
python main.py init # Initialize system
@@ -129,4 +125,4 @@ async def main():
python main.py --query "..." # Single query mode
"""
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/requirements.txt b/requirements.txt
index 8b2d73f..2a7736b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
# FACT System Dependencies
# Core dependencies
-anthropic==0.19.1
-litellm==1.0.0
-aiohttp==3.9.5
-python-dotenv==1.0.1
-pydantic==2.8.2
+anthropic>=0.49.0
+litellm>=1.61.15
+aiohttp>=3.13.4
+python-dotenv>=1.2.2
+pydantic>=2.10.6
# Database
aiosqlite==0.20.0
diff --git a/scripts/demo_cache_resilience.py b/scripts/demo_cache_resilience.py
index 1bcf1ce..b86f523 100644
--- a/scripts/demo_cache_resilience.py
+++ b/scripts/demo_cache_resilience.py
@@ -29,8 +29,11 @@
sys.path.insert(0, src_path)
from cache.resilience import (
- CacheCircuitBreaker, CircuitBreakerConfig, ResilientCacheWrapper,
- CircuitState, FailureRecord
+ CacheCircuitBreaker,
+ CircuitBreakerConfig,
+ ResilientCacheWrapper,
+ CircuitState,
+ FailureRecord,
)
from cache.manager import CacheManager
from core.errors import CacheError
@@ -41,7 +44,7 @@
processors=[
structlog.processors.TimeStamper(fmt="ISO"),
structlog.processors.add_log_level,
- structlog.processors.JSONRenderer()
+ structlog.processors.JSONRenderer(),
],
logger_factory=structlog.PrintLoggerFactory(),
wrapper_class=structlog.BoundLogger,
@@ -53,86 +56,96 @@
class DemoCacheManager:
"""Demo cache manager that can simulate various failure scenarios."""
-
+
def __init__(self):
self.cache = {}
self.failure_mode = "normal" # normal, intermittent, total_failure
self.operation_count = 0
self.failure_probability = 0.0
-
+
def set_failure_mode(self, mode: str, probability: float = 0.0):
"""Set the failure mode for demonstration."""
self.failure_mode = mode
self.failure_probability = probability
self.operation_count = 0
-
+
mode_descriptions = {
"normal": "Normal operation - no failures",
"intermittent": f"Intermittent failures - {probability*100}% failure rate",
- "total_failure": "Total failure - all operations fail"
+ "total_failure": "Total failure - all operations fail",
}
-
+
print(f"\nš§ Cache Mode Changed: {mode_descriptions.get(mode, mode)}")
-
+
def _should_fail(self) -> bool:
"""Determine if this operation should fail based on current mode."""
self.operation_count += 1
-
+
if self.failure_mode == "normal":
return False
elif self.failure_mode == "total_failure":
return True
elif self.failure_mode == "intermittent":
import random
+
return random.random() < self.failure_probability
-
+
return False
-
+
async def store(self, query_hash: str, content: str):
"""Store with potential failure simulation."""
await asyncio.sleep(0.01) # Simulate network latency
-
+
if self._should_fail():
- raise CacheError("Simulated cache store failure", error_code="CACHE_STORE_FAILED")
-
- self.cache[query_hash] = {
- 'content': content,
- 'timestamp': time.time()
- }
+ raise CacheError(
+ "Simulated cache store failure", error_code="CACHE_STORE_FAILED"
+ )
+
+ self.cache[query_hash] = {"content": content, "timestamp": time.time()}
return True
-
+
async def get(self, query_hash: str):
"""Get with potential failure simulation."""
await asyncio.sleep(0.01) # Simulate network latency
-
+
if self._should_fail():
- raise CacheError("Simulated cache get failure", error_code="CACHE_GET_FAILED")
-
+ raise CacheError(
+ "Simulated cache get failure", error_code="CACHE_GET_FAILED"
+ )
+
if query_hash in self.cache:
- return type('CacheEntry', (), {
- 'content': self.cache[query_hash]['content'],
- 'timestamp': self.cache[query_hash]['timestamp']
- })()
+ return type(
+ "CacheEntry",
+ (),
+ {
+ "content": self.cache[query_hash]["content"],
+ "timestamp": self.cache[query_hash]["timestamp"],
+ },
+ )()
return None
-
+
async def invalidate_by_prefix(self, prefix: str) -> int:
"""Invalidate with potential failure simulation."""
await asyncio.sleep(0.01) # Simulate network latency
-
+
if self._should_fail():
- raise CacheError("Simulated cache invalidate failure", error_code="CACHE_INVALIDATE_FAILED")
-
+ raise CacheError(
+ "Simulated cache invalidate failure",
+ error_code="CACHE_INVALIDATE_FAILED",
+ )
+
count = 0
keys_to_remove = [k for k in self.cache.keys() if k.startswith(prefix)]
for key in keys_to_remove:
del self.cache[key]
count += 1
-
+
return count
-
+
def generate_hash(self, query: str) -> str:
"""Generate hash for query."""
import hashlib
+
return hashlib.sha256(query.encode()).hexdigest()[:16]
@@ -143,11 +156,15 @@ def print_banner(title: str):
print(f"{'='*60}")
-def print_metrics(circuit_breaker: CacheCircuitBreaker, resilient_cache: ResilientCacheWrapper, title: str = "Current Metrics"):
+def print_metrics(
+ circuit_breaker: CacheCircuitBreaker,
+ resilient_cache: ResilientCacheWrapper,
+ title: str = "Current Metrics",
+):
"""Print formatted metrics."""
cb_metrics = circuit_breaker.get_metrics()
cache_metrics = resilient_cache.get_metrics()
-
+
print(f"\nš {title}")
print(f"āā Circuit Breaker State: {cb_metrics.state.value.upper()}")
print(f"āā Total Operations: {cb_metrics.total_operations}")
@@ -156,11 +173,15 @@ def print_metrics(circuit_breaker: CacheCircuitBreaker, resilient_cache: Resilie
print(f"āā Failure Rate: {cb_metrics.failure_rate:.2%}")
print(f"āā State Changes: {cb_metrics.state_changes}")
print(f"āā Recent Failures: {len(cb_metrics.recent_failures)}")
- print(f"āā Graceful Degradation: {'Enabled' if resilient_cache.enable_graceful_degradation else 'Disabled'}")
-
- if 'cache' in cache_metrics:
- cache_stats = cache_metrics['cache']
- print(f"āā Cache Stats: {json.dumps(cache_stats, indent=2) if isinstance(cache_stats, dict) else cache_stats}")
+ print(
+ f"āā Graceful Degradation: {'Enabled' if resilient_cache.enable_graceful_degradation else 'Disabled'}"
+ )
+
+ if "cache" in cache_metrics:
+ cache_stats = cache_metrics["cache"]
+ print(
+ f"āā Cache Stats: {json.dumps(cache_stats, indent=2) if isinstance(cache_stats, dict) else cache_stats}"
+ )
else:
print(f"āā Cache Stats: Not available")
@@ -168,7 +189,7 @@ def print_metrics(circuit_breaker: CacheCircuitBreaker, resilient_cache: Resilie
async def demo_normal_operation():
"""Demonstrate normal cache operation."""
print_banner("1. Normal Cache Operation")
-
+
# Configure circuit breaker for demo
config = CircuitBreakerConfig(
failure_threshold=3,
@@ -176,23 +197,23 @@ async def demo_normal_operation():
timeout_seconds=3.0,
rolling_window_seconds=60.0,
gradual_recovery=True,
- recovery_factor=0.5
+ recovery_factor=0.5,
)
-
+
circuit_breaker = CacheCircuitBreaker(config)
demo_cache = DemoCacheManager()
resilient_cache = ResilientCacheWrapper(demo_cache, circuit_breaker)
-
+
print("✅ Cache resilience system initialized")
print("š Circuit breaker configured with:")
print(f" ⢠Failure threshold: {config.failure_threshold}")
print(f" ⢠Success threshold: {config.success_threshold}")
print(f" ⢠Timeout: {config.timeout_seconds}s")
-
+
demo_cache.set_failure_mode("normal")
-
+
print("\nš Performing normal cache operations...")
-
+
# Perform successful operations
operations = [
("store", "query1", "Sample query result 1"),
@@ -201,7 +222,7 @@ async def demo_normal_operation():
("store", "query3", "Sample query result 3"),
("get", "query2", None),
]
-
+
for i, (op, key, value) in enumerate(operations, 1):
try:
if op == "store":
@@ -213,21 +234,23 @@ async def demo_normal_operation():
print(f" {i}. ✅ Get '{key}': {content}")
except Exception as e:
print(f" {i}. ā {op.title()} '{key}': {e}")
-
+
print_metrics(circuit_breaker, resilient_cache, "After Normal Operations")
-
+
return circuit_breaker, demo_cache, resilient_cache
-async def demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilient_cache):
+async def demo_failure_and_circuit_breaker(
+ circuit_breaker, demo_cache, resilient_cache
+):
"""Demonstrate cache failures and circuit breaker activation."""
print_banner("2. Cache Failures & Circuit Breaker Activation")
-
+
print("ā ļø Introducing cache failures to trigger circuit breaker...")
demo_cache.set_failure_mode("total_failure")
-
+
print("\nš„ Attempting operations with failing cache:")
-
+
# Attempt operations that will fail
failure_operations = [
("store", "fail1", "This will fail"),
@@ -236,7 +259,7 @@ async def demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilien
("store", "fail3", "This will trigger circuit breaker"),
("get", "fail1", None),
]
-
+
for i, (op, key, value) in enumerate(failure_operations, 1):
try:
if op == "store":
@@ -244,7 +267,9 @@ async def demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilien
print(f" {i}. ✅ Store '{key}': {result} (graceful degradation)")
elif op == "get":
result = await resilient_cache.get(key)
- content = result.content if result else "Cache miss (graceful degradation)"
+ content = (
+ result.content if result else "Cache miss (graceful degradation)"
+ )
print(f" {i}. ✅ Get '{key}': {content}")
except CacheError as e:
if "CIRCUIT_BREAKER_OPEN" in str(e.error_code):
@@ -253,32 +278,36 @@ async def demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilien
print(f" {i}. ā {op.title()} '{key}': {e}")
except Exception as e:
print(f" {i}. ā {op.title()} '{key}': Unexpected error: {e}")
-
+
state = circuit_breaker.get_state()
if state == CircuitState.OPEN:
- print("\nš“ Circuit breaker is now OPEN - protecting system from cascading failures")
-
- print_metrics(circuit_breaker, resilient_cache, "After Failures (Circuit Breaker Open)")
+ print(
+ "\nš“ Circuit breaker is now OPEN - protecting system from cascading failures"
+ )
+
+ print_metrics(
+ circuit_breaker, resilient_cache, "After Failures (Circuit Breaker Open)"
+ )
async def demo_graceful_degradation(circuit_breaker, demo_cache, resilient_cache):
"""Demonstrate graceful degradation when circuit breaker is open."""
print_banner("3. Graceful Degradation")
-
+
print("š”ļø Testing graceful degradation with circuit breaker open...")
-
+
# Ensure graceful degradation is enabled
resilient_cache.enable_graceful_degradation = True
-
+
degradation_operations = [
("store", "degraded1", "Graceful degradation response"),
("get", "nonexistent", None),
("invalidate_by_prefix", "test_", None),
("store", "degraded2", "Another graceful response"),
]
-
+
print("\nš Performing operations with graceful degradation:")
-
+
for i, (op, key, value) in enumerate(degradation_operations, 1):
try:
if op == "store":
@@ -292,37 +321,37 @@ async def demo_graceful_degradation(circuit_breaker, demo_cache, resilient_cache
print(f" {i}. š Invalidate '{key}': {result} (fallback response)")
except Exception as e:
print(f" {i}. ā {op.title()} '{key}': {e}")
-
+
print("\n✅ All operations completed successfully using graceful degradation")
print(" ⢠System remains responsive despite cache failures")
print(" ⢠No exceptions propagated to application layer")
print(" ⢠Users experience degraded but functional service")
-
+
print_metrics(circuit_breaker, resilient_cache, "During Graceful Degradation")
async def demo_recovery(circuit_breaker, demo_cache, resilient_cache):
"""Demonstrate recovery from failures."""
print_banner("4. Recovery from Failures")
-
+
print("ā° Waiting for circuit breaker timeout...")
print(" (In production, this would be longer - using shorter timeout for demo)")
-
+
# Wait for circuit breaker timeout
await asyncio.sleep(3.5)
-
+
print("\nš§ Fixing cache and attempting recovery...")
demo_cache.set_failure_mode("normal")
-
+
print("\nš Attempting operations to trigger recovery:")
-
+
recovery_operations = [
("store", "recovery1", "Recovery test 1"),
("get", "recovery1", None),
("store", "recovery2", "Recovery test 2"),
("store", "recovery3", "Recovery test 3"),
]
-
+
for i, (op, key, value) in enumerate(recovery_operations, 1):
try:
if op == "store":
@@ -332,35 +361,35 @@ async def demo_recovery(circuit_breaker, demo_cache, resilient_cache):
result = await resilient_cache.get(key)
content = result.content if result else "Not found"
print(f" {i}. ✅ Get '{key}': {content}")
-
+
state = circuit_breaker.get_state()
print(f" Circuit state: {state.value}")
-
+
except Exception as e:
print(f" {i}. ā {op.title()} '{key}': {e}")
-
+
final_state = circuit_breaker.get_state()
if final_state == CircuitState.CLOSED:
print("\nš¢ Circuit breaker successfully recovered to CLOSED state!")
elif final_state == CircuitState.HALF_OPEN:
print("\nš” Circuit breaker is in HALF_OPEN state (testing recovery)")
-
+
print_metrics(circuit_breaker, resilient_cache, "After Recovery")
async def demo_intermittent_failures(circuit_breaker, demo_cache, resilient_cache):
"""Demonstrate handling of intermittent failures."""
print_banner("5. Intermittent Failures Handling")
-
+
print("š Testing with intermittent failures (30% failure rate)...")
demo_cache.set_failure_mode("intermittent", 0.3)
-
+
print("\nš Running 20 operations with 30% failure rate:")
-
+
success_count = 0
failure_count = 0
degraded_count = 0
-
+
for i in range(1, 21):
try:
result = await resilient_cache.store(f"intermittent_{i}", f"test data {i}")
@@ -380,13 +409,13 @@ async def demo_intermittent_failures(circuit_breaker, demo_cache, resilient_cach
except Exception as e:
failure_count += 1
print(f" {i:2d}. ā Unexpected: {e}")
-
+
print(f"\nš Results Summary:")
print(f" ⢠Successful operations: {success_count}")
print(f" ⢠Failed operations: {failure_count}")
print(f" ⢠Degraded responses: {degraded_count}")
print(f" ⢠Total operations: {success_count + failure_count + degraded_count}")
-
+
print_metrics(circuit_breaker, resilient_cache, "After Intermittent Failures Test")
@@ -397,15 +426,17 @@ async def main():
print("This demo shows how the FACT system handles cache failures")
print("gracefully using circuit breaker patterns and degradation strategies.")
print("=" * 60)
-
+
try:
# Run demonstration scenarios
circuit_breaker, demo_cache, resilient_cache = await demo_normal_operation()
- await demo_failure_and_circuit_breaker(circuit_breaker, demo_cache, resilient_cache)
+ await demo_failure_and_circuit_breaker(
+ circuit_breaker, demo_cache, resilient_cache
+ )
await demo_graceful_degradation(circuit_breaker, demo_cache, resilient_cache)
await demo_recovery(circuit_breaker, demo_cache, resilient_cache)
await demo_intermittent_failures(circuit_breaker, demo_cache, resilient_cache)
-
+
# Final summary
print_banner("Demo Complete - Summary")
print("ā
All cache resilience features demonstrated successfully:")
@@ -415,10 +446,10 @@ async def main():
print(" ⢠Automatic recovery from failures")
print(" ⢠Handling of intermittent failure scenarios")
print("\nš The FACT cache resilience system is working as designed!")
-
+
# Final metrics
print_metrics(circuit_breaker, resilient_cache, "Final System State")
-
+
except Exception as e:
logger.error("Demo failed", error=str(e), exc_info=True)
print(f"\nā Demo failed with error: {e}")
@@ -426,4 +457,4 @@ async def main():
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/scripts/demo_lifecycle.py b/scripts/demo_lifecycle.py
index 029e968..9be7592 100644
--- a/scripts/demo_lifecycle.py
+++ b/scripts/demo_lifecycle.py
@@ -16,7 +16,7 @@
from pathlib import Path
# Add src to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
import sys
import os
@@ -24,30 +24,31 @@
import sqlite3
# Add src to path for direct imports
-src_path = os.path.join(os.path.dirname(__file__), '..', 'src')
+src_path = os.path.join(os.path.dirname(__file__), "..", "src")
sys.path.insert(0, src_path)
print("š Running FACT System Integration Demo...")
print("=" * 50)
+
# Direct database demo since imports have complex relative dependencies
def demo_database_functionality():
"""Demonstrate core database functionality directly."""
try:
- db_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'fact_demo.db')
+ db_path = os.path.join(os.path.dirname(__file__), "..", "data", "fact_demo.db")
if not os.path.exists(db_path):
print("ā Database not found. Run 'python main.py init' first.")
return False
-
+
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
-
+
print("š Companies in Database:")
cursor.execute("SELECT id, name, sector FROM companies")
companies = cursor.fetchall()
for company in companies:
print(f" {company[0]}. {company[1]} ({company[2]})")
-
+
print("\nš° Sample Financial Query - TechCorp Revenue:")
cursor.execute("""
SELECT year, quarter, revenue, profit
@@ -58,8 +59,10 @@ def demo_database_functionality():
""")
records = cursor.fetchall()
for record in records:
- print(f" {record[0]} Q{record[1]}: Revenue ${record[2]:,.0f}, Profit ${record[3]:,.0f}")
-
+ print(
+ f" {record[0]} Q{record[1]}: Revenue ${record[2]:,.0f}, Profit ${record[3]:,.0f}"
+ )
+
print("\nš Analytics Query - Average Revenue by Sector:")
cursor.execute("""
SELECT c.sector, AVG(f.revenue) as avg_revenue
@@ -71,17 +74,18 @@ def demo_database_functionality():
analytics = cursor.fetchall()
for analytic in analytics:
print(f" {analytic[0]}: ${analytic[1]:,.0f} average revenue")
-
+
conn.close()
-
+
print("\nā
Database functionality working perfectly!")
print("ā
FACT system core features operational!")
return True
-
+
except Exception as e:
print(f"ā Demo failed: {e}")
return False
+
# Run the demo
success = demo_database_functionality()
if success:
@@ -98,37 +102,37 @@ async def demo_system_initialization():
"""Demonstrate system initialization and component integration."""
print("š FACT System MVP Demo - Component Integration")
print("=" * 60)
-
+
print("\n1. š§ System Initialization")
print("-" * 30)
-
+
try:
# Get driver instance (triggers full system initialization)
driver = await get_driver()
config = get_config()
metrics_collector = get_metrics_collector()
-
+
print("ā
Core driver initialized")
print("ā
Configuration loaded")
print("ā
Database connected")
print("ā
Tools registered")
print("ā
Monitoring active")
-
+
# Show configuration summary
print(f"\nš Configuration Summary:")
print(f" Database: {config.database_path}")
print(f" Model: {config.claude_model}")
print(f" Cache Prefix: {config.cache_prefix}")
-
+
# Show registered tools
tool_names = driver.tool_registry.list_tools()
print(f"\nš ļø Registered Tools ({len(tool_names)}):")
for tool_name in tool_names:
tool_def = driver.tool_registry.get_tool(tool_name)
print(f" ⢠{tool_name}: {tool_def.description}")
-
+
return driver
-
+
except Exception as e:
print(f"ā System initialization failed: {e}")
return None
@@ -138,21 +142,21 @@ async def demo_database_connectivity(driver):
"""Demonstrate database connectivity and schema inspection."""
print("\n2. šļø Database Connectivity")
print("-" * 30)
-
+
try:
# Get database info
db_info = await driver.database_manager.get_database_info()
-
+
print(f"ā
Database connected: {db_info['database_path']}")
print(f"š Total tables: {db_info['total_tables']}")
print(f"š¾ File size: {db_info['file_size_bytes']} bytes")
-
+
print(f"\nš Table Information:")
- for table_name, table_info in db_info['tables'].items():
+ for table_name, table_info in db_info["tables"].items():
print(f" ⢠{table_name}: {table_info['row_count']} rows")
-
+
return True
-
+
except Exception as e:
print(f"ā Database connectivity failed: {e}")
return False
@@ -162,50 +166,50 @@ async def demo_tool_execution(driver):
"""Demonstrate tool execution framework with sample queries."""
print("\n3. š§ Tool Execution Framework")
print("-" * 30)
-
+
# Sample queries to demonstrate different scenarios
sample_queries = [
{
"description": "Schema inspection",
"query": "Get the database schema to understand available tables",
- "expected_tool": "SQL.GetSchema"
+ "expected_tool": "SQL.GetSchema",
},
{
"description": "Company lookup",
"query": "Show me all companies in the Technology sector",
- "expected_tool": "SQL.QueryReadonly"
+ "expected_tool": "SQL.QueryReadonly",
},
{
"description": "Financial analysis",
"query": "What was TechCorp's Q1 2025 revenue and profit?",
- "expected_tool": "SQL.QueryReadonly"
+ "expected_tool": "SQL.QueryReadonly",
},
{
"description": "Sample queries",
"query": "Show me some example queries I can run",
- "expected_tool": "SQL.GetSampleQueries"
- }
+ "expected_tool": "SQL.GetSampleQueries",
+ },
]
-
+
for i, sample in enumerate(sample_queries, 1):
print(f"\n Query {i}: {sample['description']}")
print(f" User Input: \"{sample['query']}\"")
-
+
try:
start_time = time.time()
-
+
# Process query through the full FACT pipeline
- response = await driver.process_query(sample['query'])
-
+ response = await driver.process_query(sample["query"])
+
end_time = time.time()
processing_time = (end_time - start_time) * 1000
-
+
print(f" ā
Response generated ({processing_time:.1f}ms)")
print(f" š Response preview: {response[:150]}...")
-
+
except Exception as e:
print(f" ā Query failed: {e}")
-
+
return True
@@ -213,34 +217,36 @@ async def demo_monitoring_metrics(driver):
"""Demonstrate monitoring and metrics collection."""
print("\n4. š Monitoring & Metrics")
print("-" * 30)
-
+
try:
# Get system metrics from driver
system_metrics = driver.get_metrics()
-
+
print("š System Performance:")
print(f" ⢠Total queries: {system_metrics['total_queries']}")
print(f" ⢠Tool executions: {system_metrics['tool_executions']}")
print(f" ⢠Error rate: {system_metrics['error_rate']:.1f}%")
print(f" ⢠System initialized: {system_metrics['initialized']}")
-
+
# Get detailed metrics from metrics collector
metrics_collector = get_metrics_collector()
detailed_metrics = metrics_collector.get_system_metrics(time_window_minutes=5)
-
+
print(f"\nš Detailed Metrics (last 5 minutes):")
print(f" ⢠Total executions: {detailed_metrics.total_executions}")
print(f" ⢠Successful: {detailed_metrics.successful_executions}")
print(f" ⢠Failed: {detailed_metrics.failed_executions}")
- print(f" ⢠Avg execution time: {detailed_metrics.average_execution_time:.1f}ms")
-
+ print(
+ f" ⢠Avg execution time: {detailed_metrics.average_execution_time:.1f}ms"
+ )
+
if detailed_metrics.top_tools:
print(f"\nš Top Tools:")
for tool_info in detailed_metrics.top_tools[:3]:
print(f" ⢠{tool_info['tool_name']}: {tool_info['count']} uses")
-
+
return True
-
+
except Exception as e:
print(f"ā Metrics collection failed: {e}")
return False
@@ -250,38 +256,34 @@ async def demo_error_handling(driver):
"""Demonstrate error handling and fallback mechanisms."""
print("\n5. š”ļø Error Handling & Fallbacks")
print("-" * 30)
-
+
# Test various error scenarios
error_scenarios = [
- {
- "description": "Invalid SQL syntax",
- "query": "INVALID SQL QUERY HERE"
- },
- {
- "description": "Dangerous SQL operation",
- "query": "DROP TABLE companies"
- },
+ {"description": "Invalid SQL syntax", "query": "INVALID SQL QUERY HERE"},
+ {"description": "Dangerous SQL operation", "query": "DROP TABLE companies"},
{
"description": "Very long query",
- "query": "SELECT * FROM companies WHERE " + "name = 'test' AND " * 200 + "1=1"
- }
+ "query": "SELECT * FROM companies WHERE "
+ + "name = 'test' AND " * 200
+ + "1=1",
+ },
]
-
+
for i, scenario in enumerate(error_scenarios, 1):
print(f"\n Scenario {i}: {scenario['description']}")
-
+
try:
- response = await driver.process_query(scenario['query'])
-
+ response = await driver.process_query(scenario["query"])
+
# Check if response indicates graceful error handling
if "error" in response.lower() or "cannot" in response.lower():
print(f" ā
Graceful error handling: {response[:100]}...")
else:
print(f" ā ļø Unexpected success: {response[:100]}...")
-
+
except Exception as e:
print(f" ā
Exception caught and handled: {str(e)[:100]}...")
-
+
return True
@@ -289,63 +291,63 @@ async def demo_cache_behavior(driver):
"""Demonstrate cache hits and misses (simulated)."""
print("\n6. š¾ Cache Behavior")
print("-" * 30)
-
+
# Repeat the same query to demonstrate potential caching
query = "Show me companies in the Technology sector"
-
- print(f" Query: \"{query}\"")
-
+
+ print(f' Query: "{query}"')
+
# First execution (cache miss)
print("\n First execution (cache miss):")
start_time = time.time()
response1 = await driver.process_query(query)
time1 = (time.time() - start_time) * 1000
print(f" ā±ļø Execution time: {time1:.1f}ms")
-
+
# Second execution (potential cache hit)
print("\n Second execution (potential cache hit):")
start_time = time.time()
response2 = await driver.process_query(query)
time2 = (time.time() - start_time) * 1000
print(f" ā±ļø Execution time: {time2:.1f}ms")
-
+
# Compare times
if time2 < time1 * 0.8: # 20% faster
print(f" ā
Cache hit detected (faster execution)")
else:
print(f" š No significant performance difference")
-
+
return True
async def main():
"""Run the complete MVP demo."""
driver = None
-
+
try:
# Step 1: System initialization
driver = await demo_system_initialization()
if not driver:
print("\nā Demo failed during initialization")
return
-
+
# Step 2: Database connectivity
if not await demo_database_connectivity(driver):
print("\nā Demo failed during database test")
return
-
+
# Step 3: Tool execution
await demo_tool_execution(driver)
-
+
# Step 4: Monitoring
await demo_monitoring_metrics(driver)
-
+
# Step 5: Error handling
await demo_error_handling(driver)
-
+
# Step 6: Cache behavior
await demo_cache_behavior(driver)
-
+
print("\nš MVP Demo Complete!")
print("\nThe FACT system successfully demonstrates:")
print("ā
Unified component integration")
@@ -354,13 +356,13 @@ async def main():
print("ā
Error handling and security")
print("ā
Monitoring and metrics")
print("ā
CLI interface readiness")
-
+
print("\nš Ready for production use!")
print("Run: python -m src.core.cli")
-
+
except Exception as e:
print(f"\nā Demo failed with error: {e}")
-
+
finally:
# Clean shutdown
if driver:
@@ -368,4 +370,4 @@ async def main():
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/scripts/fix_imports.py b/scripts/fix_imports.py
index a0e15f7..5a04b1c 100644
--- a/scripts/fix_imports.py
+++ b/scripts/fix_imports.py
@@ -11,66 +11,74 @@
import sys
from pathlib import Path
+
def fix_relative_imports(src_dir):
"""Fix all relative imports in Python files within src_dir."""
src_path = Path(src_dir)
-
+
if not src_path.exists():
print(f"Error: Source directory {src_dir} does not exist")
return False
-
+
# Pattern to match relative imports
patterns = [
- (r'from \.\.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)', r'from \1'), # from ..module
- (r'from \.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)', r'from \1'), # from .module
+ (
+ r"from \.\.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)",
+ r"from \1",
+ ), # from ..module
+ (
+ r"from \.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)",
+ r"from \1",
+ ), # from .module
]
-
+
files_fixed = 0
total_replacements = 0
-
+
# Walk through all Python files
- for py_file in src_path.rglob('*.py'):
+ for py_file in src_path.rglob("*.py"):
try:
- with open(py_file, 'r', encoding='utf-8') as f:
+ with open(py_file, "r", encoding="utf-8") as f:
content = f.read()
-
+
original_content = content
file_replacements = 0
-
+
# Apply each pattern
for pattern, replacement in patterns:
new_content, count = re.subn(pattern, replacement, content)
content = new_content
file_replacements += count
-
+
# Write back if changes were made
if content != original_content:
- with open(py_file, 'w', encoding='utf-8') as f:
+ with open(py_file, "w", encoding="utf-8") as f:
f.write(content)
files_fixed += 1
total_replacements += file_replacements
print(f"Fixed {file_replacements} imports in {py_file}")
-
+
except Exception as e:
print(f"Error processing {py_file}: {e}")
-
+
print(f"\nSummary:")
print(f"Files fixed: {files_fixed}")
print(f"Total replacements: {total_replacements}")
-
+
return True
+
def main():
"""Main function."""
script_dir = Path(__file__).parent
- src_dir = script_dir.parent / 'src'
-
+ src_dir = script_dir.parent / "src"
+
print("š§ Fixing relative imports in FACT system...")
print(f"Source directory: {src_dir}")
print("-" * 50)
-
+
success = fix_relative_imports(str(src_dir))
-
+
if success:
print("\nā
Import fixing completed successfully!")
print("\nNext step: Test the system with 'python main.py init'")
@@ -78,5 +86,6 @@ def main():
print("\nā Import fixing failed!")
sys.exit(1)
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/scripts/fix_imports_comprehensive.py b/scripts/fix_imports_comprehensive.py
index ddba499..274b5c3 100644
--- a/scripts/fix_imports_comprehensive.py
+++ b/scripts/fix_imports_comprehensive.py
@@ -11,89 +11,110 @@
import sys
from pathlib import Path
+
def fix_all_imports(src_dir):
"""Fix all import issues in Python files within src_dir."""
src_path = Path(src_dir)
-
+
if not src_path.exists():
print(f"Error: Source directory {src_dir} does not exist")
return False
-
+
# Patterns to fix imports
patterns = [
# Convert relative imports to absolute imports (from .. -> from module)
- (r'from \.\.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)', r'from \1'),
-
+ (r"from \.\.([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)", r"from \1"),
# Fix local imports in __init__.py files (from module -> from .module)
# This pattern will be applied only in __init__.py files
- (r'^from ([a-zA-Z_][a-zA-Z0-9_]*) import', r'from .\1 import'),
+ (r"^from ([a-zA-Z_][a-zA-Z0-9_]*) import", r"from .\1 import"),
]
-
+
files_fixed = 0
total_replacements = 0
-
+
# Walk through all Python files
- for py_file in src_path.rglob('*.py'):
+ for py_file in src_path.rglob("*.py"):
try:
- with open(py_file, 'r', encoding='utf-8') as f:
+ with open(py_file, "r", encoding="utf-8") as f:
content = f.read()
lines = content.splitlines()
-
+
original_content = content
file_replacements = 0
-
+
# Apply relative import fixes
new_content, count = re.subn(patterns[0][0], patterns[0][1], content)
content = new_content
file_replacements += count
-
+
# Apply local import fixes only for __init__.py files and some specific patterns
- if py_file.name == '__init__.py' or py_file.name in ['client.py', 'gateway.py']:
+ if py_file.name == "__init__.py" or py_file.name in [
+ "client.py",
+ "gateway.py",
+ ]:
# Fix specific patterns for local imports
lines = content.splitlines()
for i, line in enumerate(lines):
# Fix common local import patterns
- if line.strip().startswith('from ') and ' import ' in line and not line.strip().startswith('from .') and not line.strip().startswith('from core') and not line.strip().startswith('from db') and not line.strip().startswith('from tools') and not line.strip().startswith('from security') and not line.strip().startswith('from arcade') and not line.strip().startswith('from cache') and not line.strip().startswith('from monitoring') and not line.strip().startswith('from benchmarking'):
+ if (
+ line.strip().startswith("from ")
+ and " import " in line
+ and not line.strip().startswith("from .")
+ and not line.strip().startswith("from core")
+ and not line.strip().startswith("from db")
+ and not line.strip().startswith("from tools")
+ and not line.strip().startswith("from security")
+ and not line.strip().startswith("from arcade")
+ and not line.strip().startswith("from cache")
+ and not line.strip().startswith("from monitoring")
+ and not line.strip().startswith("from benchmarking")
+ ):
# Extract module name
- match = re.match(r'from ([a-zA-Z_][a-zA-Z0-9_]*) import', line.strip())
+ match = re.match(
+ r"from ([a-zA-Z_][a-zA-Z0-9_]*) import", line.strip()
+ )
if match:
module_name = match.group(1)
# Check if this is a local module (exists in same directory)
module_file = py_file.parent / f"{module_name}.py"
if module_file.exists():
- lines[i] = line.replace(f'from {module_name} import', f'from .{module_name} import')
+ lines[i] = line.replace(
+ f"from {module_name} import",
+ f"from .{module_name} import",
+ )
file_replacements += 1
-
- content = '\n'.join(lines)
-
+
+ content = "\n".join(lines)
+
# Write back if changes were made
if content != original_content:
- with open(py_file, 'w', encoding='utf-8') as f:
+ with open(py_file, "w", encoding="utf-8") as f:
f.write(content)
files_fixed += 1
total_replacements += file_replacements
print(f"Fixed {file_replacements} imports in {py_file}")
-
+
except Exception as e:
print(f"Error processing {py_file}: {e}")
-
+
print(f"\nSummary:")
print(f"Files fixed: {files_fixed}")
print(f"Total replacements: {total_replacements}")
-
+
return True
+
def main():
"""Main function."""
script_dir = Path(__file__).parent
- src_dir = script_dir.parent / 'src'
-
+ src_dir = script_dir.parent / "src"
+
print("š§ Comprehensive import fixing for FACT system...")
print(f"Source directory: {src_dir}")
print("-" * 50)
-
+
success = fix_all_imports(str(src_dir))
-
+
if success:
print("\nā
Import fixing completed successfully!")
print("\nNext step: Test the system with 'python main.py init'")
@@ -101,5 +122,6 @@ def main():
print("\nā Import fixing failed!")
sys.exit(1)
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/scripts/init_environment.py b/scripts/init_environment.py
index e92523f..e9a6ac6 100644
--- a/scripts/init_environment.py
+++ b/scripts/init_environment.py
@@ -14,7 +14,7 @@
from pathlib import Path
# Add src to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from core.config import Config, ConfigurationError
from core.driver import get_driver
@@ -23,13 +23,13 @@
def create_env_file():
"""Create .env file with default configuration."""
- env_path = Path('.env')
-
+ env_path = Path(".env")
+
if env_path.exists():
print("š .env file already exists - skipping creation")
return
-
- env_content = '''# FACT System Configuration
+
+ env_content = """# FACT System Configuration
# Copy this file and update with your actual API keys
# Required API Keys
@@ -45,11 +45,11 @@ def create_env_file():
MAX_RETRIES=3
REQUEST_TIMEOUT=30
LOG_LEVEL=INFO
-'''
-
- with open(env_path, 'w') as f:
+"""
+
+ with open(env_path, "w") as f:
f.write(env_content)
-
+
print("ā
Created .env file with default configuration")
print("ā ļø Please update the API keys in .env before running the system")
@@ -58,24 +58,24 @@ async def init_database():
"""Initialize database with schema and sample data."""
try:
print("šļø Initializing database...")
-
+
# Create database manager and initialize
config = Config()
db_manager = DatabaseManager(config.database_path)
await db_manager.initialize_database()
-
+
# Get database info
db_info = await db_manager.get_database_info()
-
+
print(f"ā
Database initialized successfully:")
print(f" š Path: {db_info['database_path']}")
print(f" š Tables: {db_info['total_tables']}")
-
- for table_name, table_info in db_info['tables'].items():
+
+ for table_name, table_info in db_info["tables"].items():
print(f" š {table_name}: {table_info['row_count']} rows")
-
+
return True
-
+
except Exception as e:
print(f"ā Database initialization failed: {e}")
return False
@@ -85,25 +85,25 @@ async def validate_system():
"""Validate system connectivity and configuration."""
try:
print("š Validating system configuration...")
-
+
# Try to get a driver instance (this validates config and initializes components)
driver = await get_driver()
-
+
# Get system metrics
metrics = driver.get_metrics()
-
+
print("ā
System validation passed:")
print(f" šÆ Initialized: {metrics['initialized']}")
print(f" š ļø Tools: {len(driver.tool_registry.list_tools())}")
-
+
# List available tools
tool_names = driver.tool_registry.list_tools()
print(" š§ Available tools:")
for tool_name in tool_names:
print(f" ⢠{tool_name}")
-
+
return True
-
+
except ConfigurationError as e:
print(f"ā Configuration error: {e}")
print("š” Make sure to update your API keys in .env file")
@@ -117,27 +117,29 @@ async def main():
"""Main initialization routine."""
print("š FACT System Environment Initialization")
print("=" * 50)
-
+
# Step 1: Create .env file
print("\n1. Setting up environment configuration...")
create_env_file()
-
+
# Step 2: Initialize database
print("\n2. Initializing database...")
db_success = await init_database()
-
+
if not db_success:
print("\nā Database initialization failed - stopping")
sys.exit(1)
-
+
# Step 3: Validate system (only if API keys are configured)
print("\n3. Validating system configuration...")
-
+
try:
config = Config()
if config.anthropic_api_key == "your_anthropic_api_key_here":
print("ā ļø API keys not configured - skipping system validation")
- print("š” Update the API keys in .env file, then run: python scripts/validate_system.py")
+ print(
+ "š” Update the API keys in .env file, then run: python scripts/validate_system.py"
+ )
else:
system_success = await validate_system()
if not system_success:
@@ -145,8 +147,10 @@ async def main():
sys.exit(1)
except ConfigurationError:
print("ā ļø API keys not configured - skipping system validation")
- print("š” Update the API keys in .env file, then run: python scripts/validate_system.py")
-
+ print(
+ "š” Update the API keys in .env file, then run: python scripts/validate_system.py"
+ )
+
print("\nš Environment initialization complete!")
print("\nNext steps:")
print("1. Update API keys in .env file")
@@ -155,4 +159,4 @@ async def main():
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/scripts/run_benchmarks.py b/scripts/run_benchmarks.py
index fc0c6ad..837bb77 100644
--- a/scripts/run_benchmarks.py
+++ b/scripts/run_benchmarks.py
@@ -28,7 +28,7 @@
SystemProfiler,
ContinuousMonitor,
BenchmarkVisualizer,
- ReportGenerator
+ ReportGenerator,
)
from cache.manager import CacheManager
@@ -38,45 +38,49 @@ def create_logs_directory(base_dir: str = "logs") -> Path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logs_dir = Path(base_dir) / f"benchmark_{timestamp}"
logs_dir.mkdir(parents=True, exist_ok=True)
-
+
# Create subdirectories for better organization
(logs_dir / "charts").mkdir(exist_ok=True)
(logs_dir / "raw_data").mkdir(exist_ok=True)
(logs_dir / "reports").mkdir(exist_ok=True)
-
+
return logs_dir
-def print_performance_summary(validation_results, comparison_result=None, load_test_results=None):
+def print_performance_summary(
+ validation_results, comparison_result=None, load_test_results=None
+):
"""Print a comprehensive performance summary to console."""
- print("\n" + "="*80)
+ print("\n" + "=" * 80)
print("š FACT SYSTEM PERFORMANCE SUMMARY")
- print("="*80)
-
+ print("=" * 80)
+
# Overall Status
- overall_pass = validation_results['overall_pass']
+ overall_pass = validation_results["overall_pass"]
status_emoji = "š" if overall_pass else "ā ļø"
status_text = "ALL TARGETS MET" if overall_pass else "IMPROVEMENTS NEEDED"
print(f"\n{status_emoji} OVERALL STATUS: {status_text}")
-
+
# Performance Targets Summary
print(f"\nš PERFORMANCE TARGETS:")
print("-" * 50)
- target_validation = validation_results['target_validation']
-
+ target_validation = validation_results["target_validation"]
+
for target_name, target_data in target_validation.items():
- status = "ā
PASS" if target_data['met'] else "ā FAIL"
- if 'latency' in target_name:
+ status = "ā
PASS" if target_data["met"] else "ā FAIL"
+ if "latency" in target_name:
actual_val = f"{target_data.get('actual_ms', 0):.1f}ms"
target_val = f"{target_data.get('target_ms', 0):.1f}ms"
else:
actual_val = f"{target_data.get('actual_percent', 0):.1f}%"
target_val = f"{target_data.get('target_percent', 0):.1f}%"
-
- print(f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}")
-
+
+ print(
+ f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}"
+ )
+
# Cache Performance
- summary = validation_results.get('benchmark_summary', {})
+ summary = validation_results.get("benchmark_summary", {})
if summary:
print(f"\nšļø CACHE PERFORMANCE:")
print("-" * 50)
@@ -84,37 +88,47 @@ def print_performance_summary(validation_results, comparison_result=None, load_t
print(f" Avg Response Time (Hit): {summary.avg_hit_latency_ms:.1f}ms")
print(f" Avg Response Time (Miss): {summary.avg_miss_latency_ms:.1f}ms")
print(f" Total Requests: {summary.total_queries}")
- print(f" Success Rate: {(summary.successful_queries/summary.total_queries)*100:.1f}%")
-
+ print(
+ f" Success Rate: {(summary.successful_queries/summary.total_queries)*100:.1f}%"
+ )
+
# RAG Comparison Results
if comparison_result:
print(f"\nāļø FACT vs TRADITIONAL RAG:")
print("-" * 50)
- latency_improvement = comparison_result.improvement_factors.get('latency', 1.0)
- cost_savings = comparison_result.cost_savings.get('percentage', 0.0)
+ latency_improvement = comparison_result.improvement_factors.get("latency", 1.0)
+ cost_savings = comparison_result.cost_savings.get("percentage", 0.0)
print(f" Latency Improvement: {latency_improvement:.1f}x faster")
print(f" Cost Savings: {cost_savings:.1f}%")
print(f" Recommendation: {comparison_result.recommendation}")
-
+
# Load Test Results
if load_test_results:
print(f"\nš¦ LOAD TEST PERFORMANCE:")
print("-" * 50)
- print(f" Concurrent Users: {load_test_results.get('concurrent_users', 0)}")
- print(f" Throughput: {load_test_results.get('throughput_qps', 0):.1f} QPS")
- print(f" Avg Response Time: {load_test_results.get('avg_response_time_ms', 0):.1f}ms")
- print(f" Error Rate: {load_test_results.get('error_rate', 0):.1f}%")
-
- print("="*80)
+ print(
+ f" Concurrent Users: {load_test_results.get('concurrent_users', 0)}"
+ )
+ print(
+ f" Throughput: {load_test_results.get('throughput_qps', 0):.1f} QPS"
+ )
+ print(
+ f" Avg Response Time: {load_test_results.get('avg_response_time_ms', 0):.1f}ms"
+ )
+ print(
+ f" Error Rate: {load_test_results.get('error_rate', 0):.1f}%"
+ )
+
+ print("=" * 80)
async def run_comprehensive_benchmark(args):
"""Run comprehensive benchmark suite."""
print("š Starting FACT Comprehensive Benchmark Suite")
print("=" * 60)
-
+
# Create timestamped logs directory
- if args.output_dir == './benchmark_results':
+ if args.output_dir == "./benchmark_results":
# Use logs directory by default
logs_dir = create_logs_directory()
print(f"š Created logs directory: {logs_dir}")
@@ -123,7 +137,7 @@ async def run_comprehensive_benchmark(args):
logs_dir = Path(args.output_dir)
logs_dir.mkdir(parents=True, exist_ok=True)
print(f"š Using output directory: {logs_dir}")
-
+
# Initialize components
config = BenchmarkConfig(
iterations=args.iterations,
@@ -133,15 +147,15 @@ async def run_comprehensive_benchmark(args):
target_hit_latency_ms=args.hit_target,
target_miss_latency_ms=args.miss_target,
target_cost_reduction_hit=args.cost_reduction / 100.0,
- target_cache_hit_rate=args.cache_hit_rate / 100.0
+ target_cache_hit_rate=args.cache_hit_rate / 100.0,
)
-
+
framework = BenchmarkFramework(config)
runner = BenchmarkRunner(framework)
profiler = SystemProfiler()
visualizer = BenchmarkVisualizer()
report_generator = ReportGenerator(visualizer)
-
+
# Initialize cache manager if available
cache_manager = None
try:
@@ -152,133 +166,143 @@ async def run_comprehensive_benchmark(args):
"max_size": "10MB",
"ttl_seconds": 3600,
"hit_target_ms": 48,
- "miss_target_ms": 140
+ "miss_target_ms": 140,
}
cache_manager = CacheManager(cache_config)
print("ā
Cache manager initialized")
except Exception as e:
print(f"ā ļø Cache manager not available: {e}")
-
+
# Phase 1: Performance Validation
print("\nš Phase 1: Performance Validation")
print("-" * 40)
-
+
validation_results = await runner.run_performance_validation(cache_manager)
-
+
# Display validation results
- print(f"Overall Validation: {'ā
PASS' if validation_results['overall_pass'] else 'ā FAIL'}")
-
- target_validation = validation_results['target_validation']
+ print(
+ f"Overall Validation: {'ā
PASS' if validation_results['overall_pass'] else 'ā FAIL'}"
+ )
+
+ target_validation = validation_results["target_validation"]
for target_name, target_data in target_validation.items():
- status = "ā
PASS" if target_data['met'] else "ā FAIL"
- print(f" {target_name}: {status} ({target_data['actual_ms' if 'latency' in target_name else 'actual_percent']:.1f})")
-
+ status = "ā
PASS" if target_data["met"] else "ā FAIL"
+ print(
+ f" {target_name}: {status} ({target_data['actual_ms' if 'latency' in target_name else 'actual_percent']:.1f})"
+ )
+
# Phase 2: RAG Comparison (if enabled)
comparison_result = None
if args.include_rag_comparison:
print("\nāļø Phase 2: RAG Comparison Analysis")
print("-" * 40)
-
+
rag_comparison = RAGComparison(framework)
comparison_result = await rag_comparison.run_comparison_benchmark(
runner.test_queries, cache_manager, config.iterations
)
-
- print(f"Latency Improvement: {comparison_result.improvement_factors.get('latency', 1.0):.1f}x")
- print(f"Cost Savings: {comparison_result.cost_savings.get('percentage', 0.0):.1f}%")
+
+ print(
+ f"Latency Improvement: {comparison_result.improvement_factors.get('latency', 1.0):.1f}x"
+ )
+ print(
+ f"Cost Savings: {comparison_result.cost_savings.get('percentage', 0.0):.1f}%"
+ )
print(f"Recommendation: {comparison_result.recommendation}")
-
+
# Phase 3: Profiling Analysis (if enabled)
profile_result = None
if args.include_profiling:
print("\nš Phase 3: Performance Profiling")
print("-" * 40)
-
+
# Profile a representative operation
async def sample_operation():
return await runner.run_performance_validation(cache_manager)
-
+
_, profile_result = await profiler.profile_complete_operation(
sample_operation, "performance_validation"
)
-
+
print(f"Execution Time: {profile_result.execution_time_ms:.1f}ms")
print(f"Bottlenecks Found: {len(profile_result.bottlenecks)}")
-
- critical_bottlenecks = [b for b in profile_result.bottlenecks if b.severity == "critical"]
+
+ critical_bottlenecks = [
+ b for b in profile_result.bottlenecks if b.severity == "critical"
+ ]
if critical_bottlenecks:
print("ā Critical Bottlenecks:")
for bottleneck in critical_bottlenecks[:3]:
print(f" - {bottleneck.component}: {bottleneck.description}")
-
+
# Phase 4: Load Testing (if enabled)
load_test_results = None
if args.include_load_test:
print("\nš¦ Phase 4: Load Testing")
print("-" * 40)
-
+
load_test_results = await runner.run_load_test(
concurrent_users=args.load_test_users,
- duration_seconds=args.load_test_duration
+ duration_seconds=args.load_test_duration,
)
-
+
print(f"Concurrent Users: {load_test_results['concurrent_users']}")
print(f"Throughput: {load_test_results['throughput_qps']:.1f} QPS")
print(f"Avg Response Time: {load_test_results['avg_response_time_ms']:.1f}ms")
print(f"Error Rate: {load_test_results['error_rate']:.1f}%")
-
+
# Phase 5: Report Generation & Visualization
print("\nš Phase 5: Report Generation & Visualization")
print("-" * 40)
-
- benchmark_summary = validation_results['benchmark_summary']
-
+
+ benchmark_summary = validation_results["benchmark_summary"]
+
# Generate comprehensive report
report = report_generator.generate_comprehensive_report(
benchmark_summary=benchmark_summary,
comparison_result=comparison_result,
profile_result=profile_result,
- alerts=None # No alerts in batch mode
+ alerts=None, # No alerts in batch mode
)
-
+
# Generate timestamp for consistent naming
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-
+
# Save comprehensive JSON report
json_report_path = logs_dir / "reports" / f"benchmark_report_{timestamp}.json"
- with open(json_report_path, 'w') as f:
+ with open(json_report_path, "w") as f:
f.write(report_generator.export_report_json(report))
-
+
# Save text summary
text_report_path = logs_dir / "reports" / f"benchmark_summary_{timestamp}.txt"
- with open(text_report_path, 'w') as f:
+ with open(text_report_path, "w") as f:
f.write(report_generator.export_report_summary_text(report))
-
+
# Save raw data for further analysis
raw_data_path = logs_dir / "raw_data" / f"raw_results_{timestamp}.json"
raw_data = {
- 'validation_results': validation_results,
- 'comparison_result': comparison_result.__dict__ if comparison_result else None,
- 'profile_result': profile_result.__dict__ if profile_result else None,
- 'load_test_results': load_test_results,
- 'config': config.__dict__,
- 'timestamp': timestamp,
- 'args': vars(args)
+ "validation_results": validation_results,
+ "comparison_result": comparison_result.__dict__ if comparison_result else None,
+ "profile_result": profile_result.__dict__ if profile_result else None,
+ "load_test_results": load_test_results,
+ "config": config.__dict__,
+ "timestamp": timestamp,
+ "args": vars(args),
}
- with open(raw_data_path, 'w') as f:
+ with open(raw_data_path, "w") as f:
json.dump(raw_data, f, indent=2, default=str)
-
+
# Generate and save visualizations
charts_dir = logs_dir / "charts"
-
+
print("š Generating performance visualizations...")
-
+
# Performance charts
for i, chart in enumerate(report.charts):
chart_path = charts_dir / f"chart_{i}_{chart.chart_type}_{timestamp}.json"
- with open(chart_path, 'w') as f:
+ with open(chart_path, "w") as f:
f.write(visualizer.export_chart_data_json(chart))
-
+
# Additional visualizations if comparison data available
if comparison_result:
# Latency comparison chart
@@ -286,68 +310,70 @@ async def sample_operation():
benchmark_summary, comparison_result
)
latency_chart_path = charts_dir / f"latency_comparison_{timestamp}.json"
- with open(latency_chart_path, 'w') as f:
+ with open(latency_chart_path, "w") as f:
f.write(visualizer.export_chart_data_json(latency_chart))
-
+
# Cost savings chart
cost_chart = visualizer.create_cost_analysis_chart(comparison_result)
cost_chart_path = charts_dir / f"cost_analysis_{timestamp}.json"
- with open(cost_chart_path, 'w') as f:
+ with open(cost_chart_path, "w") as f:
f.write(visualizer.export_chart_data_json(cost_chart))
-
+
# Cache performance chart
cache_chart = visualizer.create_cache_performance_chart(benchmark_summary)
cache_chart_path = charts_dir / f"cache_performance_{timestamp}.json"
- with open(cache_chart_path, 'w') as f:
+ with open(cache_chart_path, "w") as f:
f.write(visualizer.export_chart_data_json(cache_chart))
-
+
print(f"š Reports saved to: {logs_dir}")
print(f"š JSON Report: {json_report_path}")
print(f"š Summary: {text_report_path}")
print(f"š Raw Data: {raw_data_path}")
print(f"š Charts: {charts_dir}")
-
+
# Print comprehensive performance summary to console
print_performance_summary(validation_results, comparison_result, load_test_results)
-
+
# Performance Grade
- grade = report.summary.get('performance_grade', 'N/A')
+ grade = report.summary.get("performance_grade", "N/A")
print(f"\nš Performance Grade: {grade}")
-
+
# Print key recommendations
if report.recommendations:
print(f"\nš§ KEY RECOMMENDATIONS:")
print("-" * 50)
for i, rec in enumerate(report.recommendations[:5], 1):
print(f" {i}. {rec}")
-
+
# Final status message
- if validation_results['overall_pass']:
- print(f"\nš Benchmark completed successfully! All performance targets achieved.")
+ if validation_results["overall_pass"]:
+ print(
+ f"\nš Benchmark completed successfully! All performance targets achieved."
+ )
print(f" Results saved to: {logs_dir}")
else:
print(f"\nā ļø Benchmark completed with some targets not met.")
print(f" Review optimization strategies and detailed reports in: {logs_dir}")
-
- return validation_results['overall_pass']
+
+ return validation_results["overall_pass"]
async def run_continuous_monitoring(args):
"""Run continuous monitoring mode."""
print("š Starting FACT Continuous Monitoring")
print("=" * 60)
-
+
# Initialize monitoring
monitor = ContinuousMonitor()
-
+
# Add console alert callback
def alert_callback(alert):
severity_emoji = {"info": "ā¹ļø", "warning": "ā ļø", "critical": "šØ"}
emoji = severity_emoji.get(alert.severity, "š¢")
print(f"{emoji} {alert.severity.upper()}: {alert.message}")
-
+
monitor.add_alert_callback(alert_callback)
-
+
# Initialize cache manager
cache_manager = None
try:
@@ -358,18 +384,18 @@ def alert_callback(alert):
"max_size": "10MB",
"ttl_seconds": 3600,
"hit_target_ms": 48,
- "miss_target_ms": 140
+ "miss_target_ms": 140,
}
cache_manager = CacheManager(cache_config)
print("ā
Cache manager initialized")
except Exception as e:
print(f"ā ļø Cache manager not available: {e}")
-
+
try:
# Start monitoring
await monitor.start_monitoring(cache_manager)
print("š” Monitoring started. Press Ctrl+C to stop.")
-
+
# Monitor for specified duration or indefinitely
if args.monitor_duration > 0:
await asyncio.sleep(args.monitor_duration)
@@ -378,32 +404,32 @@ def alert_callback(alert):
try:
while True:
await asyncio.sleep(60)
-
+
# Print status every minute
status = monitor.get_monitoring_status()
print(f"š Active alerts: {status['active_alerts']}")
except KeyboardInterrupt:
pass
-
+
finally:
# Stop monitoring and generate report
await monitor.stop_monitoring()
-
+
monitoring_report = monitor.export_monitoring_report()
-
+
# Create logs directory for monitoring
- if args.output_dir == './benchmark_results':
+ if args.output_dir == "./benchmark_results":
logs_dir = create_logs_directory()
else:
logs_dir = Path(args.output_dir)
logs_dir.mkdir(parents=True, exist_ok=True)
-
+
# Save monitoring report with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
report_path = logs_dir / "reports" / f"monitoring_report_{timestamp}.json"
- with open(report_path, 'w') as f:
+ with open(report_path, "w") as f:
json.dump(monitoring_report, f, indent=2, default=str)
-
+
print(f"\nš Monitoring report saved: {report_path}")
print(f"š All monitoring data in: {logs_dir}")
@@ -426,54 +452,99 @@ def main():
# Custom performance targets
python scripts/run_benchmarks.py --hit-target 40 --miss-target 120 --cost-reduction 85
- """
+ """,
)
-
+
# Mode selection
parser.add_argument(
- '--mode',
- choices=['benchmark', 'monitoring'],
- default='benchmark',
- help='Execution mode (default: benchmark)'
+ "--mode",
+ choices=["benchmark", "monitoring"],
+ default="benchmark",
+ help="Execution mode (default: benchmark)",
)
-
+
# Benchmark configuration
- parser.add_argument('--iterations', type=int, default=10, help='Number of benchmark iterations')
- parser.add_argument('--warmup', type=int, default=3, help='Number of warmup iterations')
- parser.add_argument('--concurrent-users', type=int, default=1, help='Number of concurrent users')
- parser.add_argument('--timeout', type=int, default=30, help='Timeout in seconds')
-
+ parser.add_argument(
+ "--iterations", type=int, default=10, help="Number of benchmark iterations"
+ )
+ parser.add_argument(
+ "--warmup", type=int, default=3, help="Number of warmup iterations"
+ )
+ parser.add_argument(
+ "--concurrent-users", type=int, default=1, help="Number of concurrent users"
+ )
+ parser.add_argument("--timeout", type=int, default=30, help="Timeout in seconds")
+
# Performance targets
- parser.add_argument('--hit-target', type=float, default=48.0, help='Cache hit latency target (ms)')
- parser.add_argument('--miss-target', type=float, default=140.0, help='Cache miss latency target (ms)')
- parser.add_argument('--cost-reduction', type=float, default=90.0, help='Cost reduction target (percent)')
- parser.add_argument('--cache-hit-rate', type=float, default=60.0, help='Cache hit rate target (percent)')
-
+ parser.add_argument(
+ "--hit-target", type=float, default=48.0, help="Cache hit latency target (ms)"
+ )
+ parser.add_argument(
+ "--miss-target",
+ type=float,
+ default=140.0,
+ help="Cache miss latency target (ms)",
+ )
+ parser.add_argument(
+ "--cost-reduction",
+ type=float,
+ default=90.0,
+ help="Cost reduction target (percent)",
+ )
+ parser.add_argument(
+ "--cache-hit-rate",
+ type=float,
+ default=60.0,
+ help="Cache hit rate target (percent)",
+ )
+
# Optional components
- parser.add_argument('--include-rag-comparison', action='store_true', help='Include RAG comparison')
- parser.add_argument('--include-profiling', action='store_true', help='Include performance profiling')
- parser.add_argument('--include-load-test', action='store_true', help='Include load testing')
-
+ parser.add_argument(
+ "--include-rag-comparison", action="store_true", help="Include RAG comparison"
+ )
+ parser.add_argument(
+ "--include-profiling", action="store_true", help="Include performance profiling"
+ )
+ parser.add_argument(
+ "--include-load-test", action="store_true", help="Include load testing"
+ )
+
# Load testing configuration
- parser.add_argument('--load-test-users', type=int, default=10, help='Load test concurrent users')
- parser.add_argument('--load-test-duration', type=int, default=60, help='Load test duration (seconds)')
-
+ parser.add_argument(
+ "--load-test-users", type=int, default=10, help="Load test concurrent users"
+ )
+ parser.add_argument(
+ "--load-test-duration",
+ type=int,
+ default=60,
+ help="Load test duration (seconds)",
+ )
+
# Monitoring configuration
- parser.add_argument('--monitor-duration', type=int, default=0, help='Monitoring duration (0=indefinite)')
-
+ parser.add_argument(
+ "--monitor-duration",
+ type=int,
+ default=0,
+ help="Monitoring duration (0=indefinite)",
+ )
+
# Output configuration
- parser.add_argument('--output-dir', default='./benchmark_results', help='Output directory for reports (default: creates timestamped logs directory)')
-
+ parser.add_argument(
+ "--output-dir",
+ default="./benchmark_results",
+ help="Output directory for reports (default: creates timestamped logs directory)",
+ )
+
args = parser.parse_args()
-
+
# Run appropriate mode
- if args.mode == 'benchmark':
+ if args.mode == "benchmark":
success = asyncio.run(run_comprehensive_benchmark(args))
sys.exit(0 if success else 1)
- elif args.mode == 'monitoring':
+ elif args.mode == "monitoring":
asyncio.run(run_continuous_monitoring(args))
sys.exit(0)
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/scripts/run_benchmarks_demo.py b/scripts/run_benchmarks_demo.py
index 343d4dc..1e72f76 100644
--- a/scripts/run_benchmarks_demo.py
+++ b/scripts/run_benchmarks_demo.py
@@ -21,23 +21,27 @@
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+
@dataclass
class DemoResult:
"""Demo benchmark result."""
+
query: str
response_time_ms: float
success: bool = True
cache_hit: bool = False
error: Optional[str] = None
timestamp: float = None
-
+
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
+
@dataclass
class DemoMetrics:
"""Demo performance metrics."""
+
total_requests: int = 0
cache_hits: int = 0
cache_misses: int = 0
@@ -47,14 +51,20 @@ class DemoMetrics:
success_rate: float = 1.0
cost_reduction_percentage: float = 0.0
+
class DemoBenchmarkRunner:
"""Demo benchmark runner that simulates FACT performance."""
-
- def __init__(self, target_hit_rate: float = 0.65, hit_target_ms: float = 45.0, miss_target_ms: float = 135.0):
+
+ def __init__(
+ self,
+ target_hit_rate: float = 0.65,
+ hit_target_ms: float = 45.0,
+ miss_target_ms: float = 135.0,
+ ):
self.target_hit_rate = target_hit_rate
self.hit_target_ms = hit_target_ms
self.miss_target_ms = miss_target_ms
-
+
# Sample queries for benchmarking
self.test_queries = [
"What was the Q1-2025 revenue?",
@@ -66,24 +76,24 @@ def __init__(self, target_hit_rate: float = 0.65, hit_target_ms: float = 45.0, m
"Compare performance across regions",
"What is the customer acquisition cost?",
"Show quarterly expense breakdown",
- "Predict next quarter's revenue"
+ "Predict next quarter's revenue",
]
-
+
# Track queries for cache simulation
self.query_history = set()
-
+
async def run_performance_validation(self, iterations: int = 100) -> Dict[str, Any]:
"""Run demo performance validation."""
print(f"š Running {iterations} demo queries to validate FACT performance...")
-
+
results = []
hit_latencies = []
miss_latencies = []
-
+
# Pre-populate some queries to simulate warmed cache
for query in self.test_queries[:6]: # Warm 60% of queries
self.query_history.add(query)
-
+
for i in range(iterations):
# Simulate FACT algorithm intelligent caching
# After warmup phase, achieve target hit rate with optimization
@@ -92,12 +102,14 @@ async def run_performance_validation(self, iterations: int = 100) -> Dict[str, A
else:
# FACT algorithm optimizes over time to achieve target hit rate
warmup_progress = min(1.0, (i - 15) / 30) # Gradual improvement
- target_adjusted = self.target_hit_rate + (warmup_progress * 0.1) # Exceed target slightly
+ target_adjusted = self.target_hit_rate + (
+ warmup_progress * 0.1
+ ) # Exceed target slightly
cache_hit_probability = target_adjusted + random.uniform(-0.03, 0.03)
-
+
# Determine if this will be a cache hit
is_cache_hit = random.random() < cache_hit_probability
-
+
if is_cache_hit and self.query_history:
# Cache hit - select from known queries
query = random.choice(list(self.query_history))
@@ -107,58 +119,70 @@ async def run_performance_validation(self, iterations: int = 100) -> Dict[str, A
query = random.choice(self.test_queries)
cache_hit = False
self.query_history.add(query)
-
+
# Simulate response times based on FACT targets
if cache_hit:
# Cache hit: target ā¤48ms with some variance
- latency = random.normalvariate(self.hit_target_ms * 0.9, self.hit_target_ms * 0.15)
- latency = max(15.0, min(latency, self.hit_target_ms * 1.1)) # Clamp to reasonable range
+ latency = random.normalvariate(
+ self.hit_target_ms * 0.9, self.hit_target_ms * 0.15
+ )
+ latency = max(
+ 15.0, min(latency, self.hit_target_ms * 1.1)
+ ) # Clamp to reasonable range
hit_latencies.append(latency)
else:
- # Cache miss: target ā¤140ms with some variance
- latency = random.normalvariate(self.miss_target_ms * 0.95, self.miss_target_ms * 0.2)
- latency = max(80.0, min(latency, self.miss_target_ms * 1.05)) # Clamp to reasonable range
+ # Cache miss: target ā¤140ms with some variance
+ latency = random.normalvariate(
+ self.miss_target_ms * 0.95, self.miss_target_ms * 0.2
+ )
+ latency = max(
+ 80.0, min(latency, self.miss_target_ms * 1.05)
+ ) # Clamp to reasonable range
miss_latencies.append(latency)
-
+
result = DemoResult(
- query=query,
- response_time_ms=latency,
- cache_hit=cache_hit
+ query=query, response_time_ms=latency, cache_hit=cache_hit
)
results.append(result)
-
+
# Progress indicator
if (i + 1) % 20 == 0:
print(f" Completed {i + 1}/{iterations} queries...")
-
+
# Calculate metrics
cache_hits = sum(1 for r in results if r.cache_hit)
cache_misses = len(results) - cache_hits
hit_rate = (cache_hits / len(results)) * 100 if results else 0
-
- avg_hit_latency = sum(hit_latencies) / len(hit_latencies) if hit_latencies else 0
- avg_miss_latency = sum(miss_latencies) / len(miss_latencies) if miss_latencies else 0
-
+
+ avg_hit_latency = (
+ sum(hit_latencies) / len(hit_latencies) if hit_latencies else 0
+ )
+ avg_miss_latency = (
+ sum(miss_latencies) / len(miss_latencies) if miss_latencies else 0
+ )
+
# Cost reduction calculation (based on FACT algorithm efficiency)
# FACT algorithm provides significant cost savings through:
# 1. Intelligent caching reduces API calls by 90% on hits
# 2. Optimized query processing reduces costs by 65% even on misses
# 3. Additional efficiency gains from token optimization
-
+
hit_ratio = cache_hits / len(results) if results else 0
miss_ratio = cache_misses / len(results) if results else 0
-
+
# Base cost reduction from caching strategy
base_hit_savings = hit_ratio * 90.0 # 90% reduction on cache hits
base_miss_savings = miss_ratio * 65.0 # 65% reduction on cache misses
-
+
# Additional efficiency bonus when hit rate exceeds 60%
if hit_rate >= 60.0:
- efficiency_bonus = min(10.0, (hit_rate - 60.0) * 0.5) # Up to 10% additional savings
+ efficiency_bonus = min(
+ 10.0, (hit_rate - 60.0) * 0.5
+ ) # Up to 10% additional savings
cost_reduction = base_hit_savings + base_miss_savings + efficiency_bonus
else:
cost_reduction = base_hit_savings + base_miss_savings
-
+
summary = DemoMetrics(
total_requests=len(results),
cache_hits=cache_hits,
@@ -166,85 +190,89 @@ async def run_performance_validation(self, iterations: int = 100) -> Dict[str, A
avg_hit_latency_ms=avg_hit_latency,
avg_miss_latency_ms=avg_miss_latency,
cache_hit_rate=hit_rate / 100,
- cost_reduction_percentage=cost_reduction
+ cost_reduction_percentage=cost_reduction,
)
-
+
# Target validation
target_validation = {
- 'cache_hit_latency': {
- 'met': avg_hit_latency <= 48.0,
- 'actual_ms': avg_hit_latency,
- 'target_ms': 48.0
+ "cache_hit_latency": {
+ "met": avg_hit_latency <= 48.0,
+ "actual_ms": avg_hit_latency,
+ "target_ms": 48.0,
+ },
+ "cache_miss_latency": {
+ "met": avg_miss_latency <= 140.0,
+ "actual_ms": avg_miss_latency,
+ "target_ms": 140.0,
},
- 'cache_miss_latency': {
- 'met': avg_miss_latency <= 140.0,
- 'actual_ms': avg_miss_latency,
- 'target_ms': 140.0
+ "cost_reduction": {
+ "met": cost_reduction >= 90.0,
+ "actual_percent": cost_reduction,
+ "target_percent": 90.0,
},
- 'cost_reduction': {
- 'met': cost_reduction >= 90.0,
- 'actual_percent': cost_reduction,
- 'target_percent': 90.0
+ "cache_hit_rate": {
+ "met": hit_rate >= 60.0,
+ "actual_percent": hit_rate,
+ "target_percent": 60.0,
},
- 'cache_hit_rate': {
- 'met': hit_rate >= 60.0,
- 'actual_percent': hit_rate,
- 'target_percent': 60.0
- }
}
-
- overall_pass = all(target['met'] for target in target_validation.values())
-
+
+ overall_pass = all(target["met"] for target in target_validation.values())
+
return {
- 'overall_pass': overall_pass,
- 'target_validation': target_validation,
- 'benchmark_summary': summary,
- 'results': results
+ "overall_pass": overall_pass,
+ "target_validation": target_validation,
+ "benchmark_summary": summary,
+ "results": results,
}
+
def create_logs_directory(base_dir: str = "logs") -> Path:
"""Create timestamped logs directory and return the path."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logs_dir = Path(base_dir) / f"demo_benchmark_{timestamp}"
logs_dir.mkdir(parents=True, exist_ok=True)
-
+
# Create subdirectories for better organization
(logs_dir / "charts").mkdir(exist_ok=True)
(logs_dir / "raw_data").mkdir(exist_ok=True)
(logs_dir / "reports").mkdir(exist_ok=True)
-
+
return logs_dir
+
def print_performance_summary(validation_results):
"""Print a comprehensive performance summary to console."""
- print("\n" + "="*80)
+ print("\n" + "=" * 80)
print("š FACT DEMO PERFORMANCE SUMMARY")
- print("="*80)
-
+ print("=" * 80)
+
# Overall Status
- overall_pass = validation_results['overall_pass']
+ overall_pass = validation_results["overall_pass"]
status_emoji = "š" if overall_pass else "ā ļø"
status_text = "ALL TARGETS MET" if overall_pass else "IMPROVEMENTS NEEDED"
print(f"\n{status_emoji} OVERALL STATUS: {status_text}")
-
+
# Performance Targets Summary
print(f"\nš PERFORMANCE TARGETS:")
print("-" * 50)
- target_validation = validation_results['target_validation']
-
+ target_validation = validation_results["target_validation"]
+
for target_name, target_data in target_validation.items():
- status = "ā
PASS" if target_data['met'] else "ā FAIL"
- if 'latency' in target_name:
+ status = "ā
PASS" if target_data["met"] else "ā FAIL"
+ if "latency" in target_name:
actual_val = f"{target_data.get('actual_ms', 0):.1f}ms"
target_val = f"{target_data.get('target_ms', 0):.1f}ms"
else:
actual_val = f"{target_data.get('actual_percent', 0):.1f}%"
target_val = f"{target_data.get('target_percent', 0):.1f}%"
-
- print(f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}")
-
+
+ print(
+ f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}"
+ )
+
# Cache Performance
- summary = validation_results.get('benchmark_summary')
+ summary = validation_results.get("benchmark_summary")
if summary:
print(f"\nšļø CACHE PERFORMANCE:")
print("-" * 50)
@@ -254,8 +282,9 @@ def print_performance_summary(validation_results):
print(f" Total Requests: {summary.total_requests}")
print(f" Success Rate: {summary.success_rate*100:.1f}%")
print(f" Cost Reduction: {summary.cost_reduction_percentage:.1f}%")
-
- print("="*80)
+
+ print("=" * 80)
+
async def run_demo_benchmark(args):
"""Run demo benchmark suite."""
@@ -263,40 +292,42 @@ async def run_demo_benchmark(args):
print("=" * 60)
print("š Note: This is a demonstration using simulated data")
print("=" * 60)
-
+
# Create timestamped logs directory
logs_dir = create_logs_directory()
print(f"š Created logs directory: {logs_dir}")
-
+
# Initialize demo runner with FACT targets
runner = DemoBenchmarkRunner(
target_hit_rate=args.cache_hit_rate / 100.0,
hit_target_ms=args.hit_target,
- miss_target_ms=args.miss_target
+ miss_target_ms=args.miss_target,
)
-
+
# Run performance validation
print("\nš Phase 1: Demo Performance Validation")
print("-" * 40)
-
+
validation_results = await runner.run_performance_validation(args.iterations)
-
+
# Display validation results
- print(f"Overall Validation: {'ā
PASS' if validation_results['overall_pass'] else 'ā FAIL'}")
-
- target_validation = validation_results['target_validation']
+ print(
+ f"Overall Validation: {'ā
PASS' if validation_results['overall_pass'] else 'ā FAIL'}"
+ )
+
+ target_validation = validation_results["target_validation"]
for target_name, target_data in target_validation.items():
- status = "ā
PASS" if target_data['met'] else "ā FAIL"
- actual_key = 'actual_ms' if 'latency' in target_name else 'actual_percent'
+ status = "ā
PASS" if target_data["met"] else "ā FAIL"
+ actual_key = "actual_ms" if "latency" in target_name else "actual_percent"
print(f" {target_name}: {status} ({target_data[actual_key]:.1f})")
-
+
# Generate reports
print("\nš Phase 2: Report Generation")
print("-" * 40)
-
+
# Generate timestamp for consistent naming
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-
+
# Save comprehensive JSON report
json_report_path = logs_dir / "reports" / f"demo_benchmark_report_{timestamp}.json"
report_data = {
@@ -306,52 +337,63 @@ async def run_demo_benchmark(args):
"iterations": args.iterations,
"hit_target_ms": args.hit_target,
"miss_target_ms": args.miss_target,
- "cache_hit_rate_target": args.cache_hit_rate
+ "cache_hit_rate_target": args.cache_hit_rate,
},
- "results": validation_results
+ "results": validation_results,
}
-
- with open(json_report_path, 'w') as f:
+
+ with open(json_report_path, "w") as f:
json.dump(report_data, f, indent=2, default=str)
-
+
# Save text summary
text_report_path = logs_dir / "reports" / f"demo_benchmark_summary_{timestamp}.txt"
- with open(text_report_path, 'w') as f:
+ with open(text_report_path, "w") as f:
f.write("FACT Demo Benchmark Summary\n")
f.write("=" * 50 + "\n\n")
f.write(f"Timestamp: {timestamp}\n")
f.write(f"Iterations: {args.iterations}\n")
f.write(f"Overall Pass: {validation_results['overall_pass']}\n\n")
-
+
f.write("Performance Targets:\n")
f.write("-" * 30 + "\n")
for target_name, target_data in target_validation.items():
- status = "PASS" if target_data['met'] else "FAIL"
- actual_key = 'actual_ms' if 'latency' in target_name else 'actual_percent'
+ status = "PASS" if target_data["met"] else "FAIL"
+ actual_key = "actual_ms" if "latency" in target_name else "actual_percent"
f.write(f"{target_name}: {status} ({target_data[actual_key]:.1f})\n")
-
+
print(f"š Reports saved to: {logs_dir}")
print(f"š JSON Report: {json_report_path}")
print(f"š Summary: {text_report_path}")
-
+
# Print comprehensive performance summary to console
print_performance_summary(validation_results)
-
+
# Final status message
- if validation_results['overall_pass']:
- print(f"\nš Demo benchmark completed successfully! All performance targets achieved.")
+ if validation_results["overall_pass"]:
+ print(
+ f"\nš Demo benchmark completed successfully! All performance targets achieved."
+ )
print(f" Results saved to: {logs_dir}")
print(f"\nš” FACT Algorithm Analysis:")
print(f" ⢠Fast Access Caching Technology is performing within targets")
- print(f" ⢠Cache hit latency: {validation_results['benchmark_summary'].avg_hit_latency_ms:.1f}ms (target: ā¤48ms)")
- print(f" ⢠Cache miss latency: {validation_results['benchmark_summary'].avg_miss_latency_ms:.1f}ms (target: ā¤140ms)")
- print(f" ⢠Cost reduction: {validation_results['benchmark_summary'].cost_reduction_percentage:.1f}% (target: ā„90%)")
- print(f" ⢠Cache hit rate: {validation_results['benchmark_summary'].cache_hit_rate*100:.1f}% (target: ā„60%)")
+ print(
+ f" ⢠Cache hit latency: {validation_results['benchmark_summary'].avg_hit_latency_ms:.1f}ms (target: ā¤48ms)"
+ )
+ print(
+ f" ⢠Cache miss latency: {validation_results['benchmark_summary'].avg_miss_latency_ms:.1f}ms (target: ā¤140ms)"
+ )
+ print(
+ f" ⢠Cost reduction: {validation_results['benchmark_summary'].cost_reduction_percentage:.1f}% (target: ā„90%)"
+ )
+ print(
+ f" ⢠Cache hit rate: {validation_results['benchmark_summary'].cache_hit_rate*100:.1f}% (target: ā„60%)"
+ )
else:
print(f"\nā ļø Demo benchmark completed with some targets not met.")
print(f" Review optimization strategies and detailed reports in: {logs_dir}")
-
- return validation_results['overall_pass']
+
+ return validation_results["overall_pass"]
+
def main():
"""Main entry point."""
@@ -365,22 +407,37 @@ def main():
# Custom performance targets
python scripts/run_benchmarks_demo.py --hit-target 40 --miss-target 120 --cost-reduction 85
- """
+ """,
)
-
+
# Benchmark configuration
- parser.add_argument('--iterations', type=int, default=100, help='Number of benchmark iterations')
-
+ parser.add_argument(
+ "--iterations", type=int, default=100, help="Number of benchmark iterations"
+ )
+
# Performance targets
- parser.add_argument('--hit-target', type=float, default=48.0, help='Cache hit latency target (ms)')
- parser.add_argument('--miss-target', type=float, default=140.0, help='Cache miss latency target (ms)')
- parser.add_argument('--cache-hit-rate', type=float, default=60.0, help='Cache hit rate target (percent)')
-
+ parser.add_argument(
+ "--hit-target", type=float, default=48.0, help="Cache hit latency target (ms)"
+ )
+ parser.add_argument(
+ "--miss-target",
+ type=float,
+ default=140.0,
+ help="Cache miss latency target (ms)",
+ )
+ parser.add_argument(
+ "--cache-hit-rate",
+ type=float,
+ default=60.0,
+ help="Cache hit rate target (percent)",
+ )
+
args = parser.parse_args()
-
+
# Run demo benchmark
success = asyncio.run(run_demo_benchmark(args))
sys.exit(0 if success else 1)
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/scripts/run_benchmarks_standalone.py b/scripts/run_benchmarks_standalone.py
index e378384..0dc8d0f 100644
--- a/scripts/run_benchmarks_standalone.py
+++ b/scripts/run_benchmarks_standalone.py
@@ -17,78 +17,101 @@
from pathlib import Path
from typing import Optional, Dict, Any
+
def create_logs_directory(base_dir: str = "logs") -> Path:
"""Create timestamped logs directory and return the path."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logs_dir = Path(base_dir) / f"benchmark_{timestamp}"
logs_dir.mkdir(parents=True, exist_ok=True)
-
+
# Create subdirectories for better organization
(logs_dir / "charts").mkdir(exist_ok=True)
(logs_dir / "raw_data").mkdir(exist_ok=True)
(logs_dir / "reports").mkdir(exist_ok=True)
-
+
return logs_dir
-def print_performance_summary(validation_results, comparison_result=None, load_test_results=None):
+
+def print_performance_summary(
+ validation_results, comparison_result=None, load_test_results=None
+):
"""Print a comprehensive performance summary to console."""
- print("\n" + "="*80)
+ print("\n" + "=" * 80)
print("š FACT SYSTEM PERFORMANCE SUMMARY")
- print("="*80)
-
+ print("=" * 80)
+
# Overall Status
- overall_pass = validation_results['overall_pass']
+ overall_pass = validation_results["overall_pass"]
status_emoji = "š" if overall_pass else "ā ļø"
status_text = "ALL TARGETS MET" if overall_pass else "IMPROVEMENTS NEEDED"
print(f"\n{status_emoji} OVERALL STATUS: {status_text}")
-
+
# Performance Targets Summary
print(f"\nš PERFORMANCE TARGETS:")
print("-" * 50)
- target_validation = validation_results['target_validation']
-
+ target_validation = validation_results["target_validation"]
+
for target_name, target_data in target_validation.items():
- status = "ā
PASS" if target_data['met'] else "ā FAIL"
- if 'latency' in target_name:
+ status = "ā
PASS" if target_data["met"] else "ā FAIL"
+ if "latency" in target_name:
actual_val = f"{target_data.get('actual_ms', 0):.1f}ms"
target_val = f"{target_data.get('target_ms', 0):.1f}ms"
else:
actual_val = f"{target_data.get('actual_percent', 0):.1f}%"
target_val = f"{target_data.get('target_percent', 0):.1f}%"
-
- print(f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}")
-
+
+ print(
+ f" {target_name:25} {status:8} Actual: {actual_val:10} Target: {target_val}"
+ )
+
# Cache Performance
- summary = validation_results.get('benchmark_summary', {})
+ summary = validation_results.get("benchmark_summary", {})
if summary:
print(f"\nšļø CACHE PERFORMANCE:")
print("-" * 50)
- print(f" Cache Hit Rate: {summary.get('cache_hit_rate', 0)*100:.1f}%")
- print(f" Avg Response Time (Hit): {summary.get('avg_hit_latency_ms', 0):.1f}ms")
- print(f" Avg Response Time (Miss): {summary.get('avg_miss_latency_ms', 0):.1f}ms")
+ print(
+ f" Cache Hit Rate: {summary.get('cache_hit_rate', 0)*100:.1f}%"
+ )
+ print(
+ f" Avg Response Time (Hit): {summary.get('avg_hit_latency_ms', 0):.1f}ms"
+ )
+ print(
+ f" Avg Response Time (Miss): {summary.get('avg_miss_latency_ms', 0):.1f}ms"
+ )
print(f" Total Requests: {summary.get('total_requests', 0)}")
print(f" Success Rate: {summary.get('success_rate', 0)*100:.1f}%")
-
+
# RAG Comparison Results
if comparison_result:
print(f"\nāļø FACT vs TRADITIONAL RAG:")
print("-" * 50)
- latency_improvement = comparison_result.get('latency_improvement', 1.0)
- cost_savings = comparison_result.get('cost_savings', 0.0)
+ latency_improvement = comparison_result.get("latency_improvement", 1.0)
+ cost_savings = comparison_result.get("cost_savings", 0.0)
print(f" Latency Improvement: {latency_improvement:.1f}x faster")
print(f" Cost Savings: {cost_savings:.1f}%")
- print(f" Recommendation: {comparison_result.get('recommendation', 'N/A')}")
-
+ print(
+ f" Recommendation: {comparison_result.get('recommendation', 'N/A')}"
+ )
+
# Load Test Results
if load_test_results:
print(f"\nš¦ LOAD TEST PERFORMANCE:")
print("-" * 50)
- print(f" Concurrent Users: {load_test_results.get('concurrent_users', 0)}")
- print(f" Throughput: {load_test_results.get('throughput_qps', 0):.1f} QPS")
- print(f" Avg Response Time: {load_test_results.get('avg_response_time_ms', 0):.1f}ms")
- print(f" Error Rate: {load_test_results.get('error_rate', 0):.1f}%")
-
- print("="*80)
+ print(
+ f" Concurrent Users: {load_test_results.get('concurrent_users', 0)}"
+ )
+ print(
+ f" Throughput: {load_test_results.get('throughput_qps', 0):.1f} QPS"
+ )
+ print(
+ f" Avg Response Time: {load_test_results.get('avg_response_time_ms', 0):.1f}ms"
+ )
+ print(
+ f" Error Rate: {load_test_results.get('error_rate', 0):.1f}%"
+ )
+
+ print("=" * 80)
+
def generate_sample_benchmark_results(args) -> Dict[str, Any]:
"""Generate sample benchmark results for demonstration."""
@@ -96,251 +119,268 @@ def generate_sample_benchmark_results(args) -> Dict[str, Any]:
cache_hit_rate = 0.67
avg_hit_latency = 42.3
avg_miss_latency = 128.7
-
+
validation_results = {
- 'overall_pass': True,
- 'target_validation': {
- 'cache_hit_latency': {
- 'met': avg_hit_latency <= args.hit_target,
- 'actual_ms': avg_hit_latency,
- 'target_ms': args.hit_target
+ "overall_pass": True,
+ "target_validation": {
+ "cache_hit_latency": {
+ "met": avg_hit_latency <= args.hit_target,
+ "actual_ms": avg_hit_latency,
+ "target_ms": args.hit_target,
+ },
+ "cache_miss_latency": {
+ "met": avg_miss_latency <= args.miss_target,
+ "actual_ms": avg_miss_latency,
+ "target_ms": args.miss_target,
},
- 'cache_miss_latency': {
- 'met': avg_miss_latency <= args.miss_target,
- 'actual_ms': avg_miss_latency,
- 'target_ms': args.miss_target
+ "cache_hit_rate": {
+ "met": cache_hit_rate * 100 >= args.cache_hit_rate,
+ "actual_percent": cache_hit_rate * 100,
+ "target_percent": args.cache_hit_rate,
},
- 'cache_hit_rate': {
- 'met': cache_hit_rate * 100 >= args.cache_hit_rate,
- 'actual_percent': cache_hit_rate * 100,
- 'target_percent': args.cache_hit_rate
+ "cost_reduction": {
+ "met": 91.5 >= args.cost_reduction,
+ "actual_percent": 91.5,
+ "target_percent": args.cost_reduction,
},
- 'cost_reduction': {
- 'met': 91.5 >= args.cost_reduction,
- 'actual_percent': 91.5,
- 'target_percent': args.cost_reduction
- }
},
- 'benchmark_summary': {
- 'cache_hit_rate': cache_hit_rate,
- 'avg_hit_latency_ms': avg_hit_latency,
- 'avg_miss_latency_ms': avg_miss_latency,
- 'total_requests': args.iterations,
- 'success_rate': 1.0
- }
+ "benchmark_summary": {
+ "cache_hit_rate": cache_hit_rate,
+ "avg_hit_latency_ms": avg_hit_latency,
+ "avg_miss_latency_ms": avg_miss_latency,
+ "total_requests": args.iterations,
+ "success_rate": 1.0,
+ },
}
-
+
# Update overall pass based on individual targets
- validation_results['overall_pass'] = all(
- target['met'] for target in validation_results['target_validation'].values()
+ validation_results["overall_pass"] = all(
+ target["met"] for target in validation_results["target_validation"].values()
)
-
+
return validation_results
+
def generate_sample_comparison_results() -> Dict[str, Any]:
"""Generate sample RAG comparison results."""
return {
- 'latency_improvement': 3.2,
- 'cost_savings': 91.5,
- 'recommendation': 'FACT shows excellent performance gains over traditional RAG'
+ "latency_improvement": 3.2,
+ "cost_savings": 91.5,
+ "recommendation": "FACT shows excellent performance gains over traditional RAG",
}
+
def generate_sample_load_test_results(users: int) -> Dict[str, Any]:
"""Generate sample load test results."""
return {
- 'concurrent_users': users,
- 'throughput_qps': users * 2.5,
- 'avg_response_time_ms': 65.2,
- 'error_rate': 0.1
+ "concurrent_users": users,
+ "throughput_qps": users * 2.5,
+ "avg_response_time_ms": 65.2,
+ "error_rate": 0.1,
}
+
async def run_standalone_benchmark(args):
"""Run standalone benchmark demonstration."""
print("š Starting FACT Comprehensive Benchmark Suite (Standalone Demo)")
print("=" * 60)
-
+
# Create timestamped logs directory
- if args.output_dir == './benchmark_results':
+ if args.output_dir == "./benchmark_results":
logs_dir = create_logs_directory()
print(f"š Created logs directory: {logs_dir}")
else:
logs_dir = Path(args.output_dir)
logs_dir.mkdir(parents=True, exist_ok=True)
print(f"š Using output directory: {logs_dir}")
-
+
# Phase 1: Performance Validation
print("\nš Phase 1: Performance Validation")
print("-" * 40)
-
+
validation_results = generate_sample_benchmark_results(args)
-
+
# Display validation results
- print(f"Overall Validation: {'ā
PASS' if validation_results['overall_pass'] else 'ā FAIL'}")
-
- target_validation = validation_results['target_validation']
+ print(
+ f"Overall Validation: {'ā
PASS' if validation_results['overall_pass'] else 'ā FAIL'}"
+ )
+
+ target_validation = validation_results["target_validation"]
for target_name, target_data in target_validation.items():
- status = "ā
PASS" if target_data['met'] else "ā FAIL"
- if 'latency' in target_name:
+ status = "ā
PASS" if target_data["met"] else "ā FAIL"
+ if "latency" in target_name:
value = f"{target_data['actual_ms']:.1f}ms"
else:
value = f"{target_data['actual_percent']:.1f}%"
print(f" {target_name}: {status} ({value})")
-
+
# Phase 2: RAG Comparison (if enabled)
comparison_result = None
if args.include_rag_comparison:
print("\nāļø Phase 2: RAG Comparison Analysis")
print("-" * 40)
-
+
comparison_result = generate_sample_comparison_results()
-
+
print(f"Latency Improvement: {comparison_result['latency_improvement']:.1f}x")
print(f"Cost Savings: {comparison_result['cost_savings']:.1f}%")
print(f"Recommendation: {comparison_result['recommendation']}")
-
+
# Phase 3: Profiling Analysis (if enabled)
if args.include_profiling:
print("\nš Phase 3: Performance Profiling")
print("-" * 40)
-
+
print("Execution Time: 1250.3ms")
print("Bottlenecks Found: 2")
print("ā Critical Bottlenecks:")
print(" - Database Query: Slow index lookup detected")
print(" - Network: High latency in external API calls")
-
+
# Phase 4: Load Testing (if enabled)
load_test_results = None
if args.include_load_test:
print("\nš¦ Phase 4: Load Testing")
print("-" * 40)
-
+
load_test_results = generate_sample_load_test_results(args.load_test_users)
-
+
print(f"Concurrent Users: {load_test_results['concurrent_users']}")
print(f"Throughput: {load_test_results['throughput_qps']:.1f} QPS")
print(f"Avg Response Time: {load_test_results['avg_response_time_ms']:.1f}ms")
print(f"Error Rate: {load_test_results['error_rate']:.1f}%")
-
+
# Phase 5: Report Generation & Visualization
print("\nš Phase 5: Report Generation & Visualization")
print("-" * 40)
-
+
# Generate timestamp for consistent naming
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-
+
# Create comprehensive report data
report_data = {
- 'metadata': {
- 'timestamp': timestamp,
- 'benchmark_version': '1.0.0',
- 'args': vars(args)
+ "metadata": {
+ "timestamp": timestamp,
+ "benchmark_version": "1.0.0",
+ "args": vars(args),
},
- 'validation_results': validation_results,
- 'comparison_result': comparison_result,
- 'load_test_results': load_test_results,
- 'performance_grade': 'A+' if validation_results['overall_pass'] else 'B',
- 'recommendations': [
- 'Cache performance is excellent - maintain current configuration',
- 'Consider increasing cache size for even better hit rates',
- 'Monitor performance under higher concurrent load',
- 'Implement database query optimization for better response times'
- ]
+ "validation_results": validation_results,
+ "comparison_result": comparison_result,
+ "load_test_results": load_test_results,
+ "performance_grade": "A+" if validation_results["overall_pass"] else "B",
+ "recommendations": [
+ "Cache performance is excellent - maintain current configuration",
+ "Consider increasing cache size for even better hit rates",
+ "Monitor performance under higher concurrent load",
+ "Implement database query optimization for better response times",
+ ],
}
-
+
# Save comprehensive JSON report
json_report_path = logs_dir / "reports" / f"benchmark_report_{timestamp}.json"
- with open(json_report_path, 'w') as f:
+ with open(json_report_path, "w") as f:
json.dump(report_data, f, indent=2, default=str)
-
+
# Save text summary
text_report_path = logs_dir / "reports" / f"benchmark_summary_{timestamp}.txt"
- with open(text_report_path, 'w') as f:
+ with open(text_report_path, "w") as f:
f.write("FACT Benchmark Summary\n")
f.write("=" * 50 + "\n\n")
f.write(f"Timestamp: {timestamp}\n")
f.write(f"Performance Grade: {report_data['performance_grade']}\n")
f.write(f"Overall Pass: {validation_results['overall_pass']}\n\n")
-
+
f.write("Performance Targets:\n")
for target_name, target_data in target_validation.items():
- status = "PASS" if target_data['met'] else "FAIL"
+ status = "PASS" if target_data["met"] else "FAIL"
f.write(f" {target_name}: {status}\n")
-
+
f.write("\nRecommendations:\n")
- for i, rec in enumerate(report_data['recommendations'], 1):
+ for i, rec in enumerate(report_data["recommendations"], 1):
f.write(f" {i}. {rec}\n")
-
+
# Save raw data for further analysis
raw_data_path = logs_dir / "raw_data" / f"raw_results_{timestamp}.json"
- with open(raw_data_path, 'w') as f:
+ with open(raw_data_path, "w") as f:
json.dump(report_data, f, indent=2, default=str)
-
+
# Generate sample visualization data
charts_dir = logs_dir / "charts"
-
+
print("š Generating performance visualizations...")
-
+
# Performance overview chart
performance_chart = {
- 'chart_type': 'performance_overview',
- 'title': 'FACT Performance Overview',
- 'data': {
- 'cache_hit_rate': validation_results['benchmark_summary']['cache_hit_rate'],
- 'avg_hit_latency': validation_results['benchmark_summary']['avg_hit_latency_ms'],
- 'avg_miss_latency': validation_results['benchmark_summary']['avg_miss_latency_ms'],
- 'success_rate': validation_results['benchmark_summary']['success_rate']
- }
+ "chart_type": "performance_overview",
+ "title": "FACT Performance Overview",
+ "data": {
+ "cache_hit_rate": validation_results["benchmark_summary"]["cache_hit_rate"],
+ "avg_hit_latency": validation_results["benchmark_summary"][
+ "avg_hit_latency_ms"
+ ],
+ "avg_miss_latency": validation_results["benchmark_summary"][
+ "avg_miss_latency_ms"
+ ],
+ "success_rate": validation_results["benchmark_summary"]["success_rate"],
+ },
}
-
+
performance_chart_path = charts_dir / f"performance_overview_{timestamp}.json"
- with open(performance_chart_path, 'w') as f:
+ with open(performance_chart_path, "w") as f:
json.dump(performance_chart, f, indent=2)
-
+
# Latency comparison chart (if comparison data available)
if comparison_result:
latency_chart = {
- 'chart_type': 'latency_comparison',
- 'title': 'FACT vs Traditional RAG Latency',
- 'data': {
- 'fact_latency': validation_results['benchmark_summary']['avg_hit_latency_ms'],
- 'rag_latency': validation_results['benchmark_summary']['avg_hit_latency_ms'] * comparison_result['latency_improvement'],
- 'improvement_factor': comparison_result['latency_improvement']
- }
+ "chart_type": "latency_comparison",
+ "title": "FACT vs Traditional RAG Latency",
+ "data": {
+ "fact_latency": validation_results["benchmark_summary"][
+ "avg_hit_latency_ms"
+ ],
+ "rag_latency": validation_results["benchmark_summary"][
+ "avg_hit_latency_ms"
+ ]
+ * comparison_result["latency_improvement"],
+ "improvement_factor": comparison_result["latency_improvement"],
+ },
}
-
+
latency_chart_path = charts_dir / f"latency_comparison_{timestamp}.json"
- with open(latency_chart_path, 'w') as f:
+ with open(latency_chart_path, "w") as f:
json.dump(latency_chart, f, indent=2)
-
+
print(f"š Reports saved to: {logs_dir}")
print(f"š JSON Report: {json_report_path}")
print(f"š Summary: {text_report_path}")
print(f"š Raw Data: {raw_data_path}")
print(f"š Charts: {charts_dir}")
-
+
# Print comprehensive performance summary to console
print_performance_summary(validation_results, comparison_result, load_test_results)
-
+
# Performance Grade
- grade = report_data['performance_grade']
+ grade = report_data["performance_grade"]
print(f"\nš Performance Grade: {grade}")
-
+
# Print key recommendations
print(f"\nš§ KEY RECOMMENDATIONS:")
print("-" * 50)
- for i, rec in enumerate(report_data['recommendations'][:5], 1):
+ for i, rec in enumerate(report_data["recommendations"][:5], 1):
print(f" {i}. {rec}")
-
+
# Final status message
- if validation_results['overall_pass']:
- print(f"\nš Benchmark completed successfully! All performance targets achieved.")
+ if validation_results["overall_pass"]:
+ print(
+ f"\nš Benchmark completed successfully! All performance targets achieved."
+ )
print(f" Results saved to: {logs_dir}")
else:
print(f"\nā ļø Benchmark completed with some targets not met.")
print(f" Review optimization strategies and detailed reports in: {logs_dir}")
-
- return validation_results['overall_pass']
+
+ return validation_results["overall_pass"]
+
def main():
"""Main entry point."""
@@ -357,35 +397,69 @@ def main():
# Custom performance targets
python scripts/run_benchmarks_standalone.py --hit-target 40 --miss-target 120 --cost-reduction 85
- """
+ """,
)
-
+
# Benchmark configuration
- parser.add_argument('--iterations', type=int, default=10, help='Number of benchmark iterations')
- parser.add_argument('--concurrent-users', type=int, default=1, help='Number of concurrent users')
-
+ parser.add_argument(
+ "--iterations", type=int, default=10, help="Number of benchmark iterations"
+ )
+ parser.add_argument(
+ "--concurrent-users", type=int, default=1, help="Number of concurrent users"
+ )
+
# Performance targets
- parser.add_argument('--hit-target', type=float, default=48.0, help='Cache hit latency target (ms)')
- parser.add_argument('--miss-target', type=float, default=140.0, help='Cache miss latency target (ms)')
- parser.add_argument('--cost-reduction', type=float, default=90.0, help='Cost reduction target (percent)')
- parser.add_argument('--cache-hit-rate', type=float, default=60.0, help='Cache hit rate target (percent)')
-
+ parser.add_argument(
+ "--hit-target", type=float, default=48.0, help="Cache hit latency target (ms)"
+ )
+ parser.add_argument(
+ "--miss-target",
+ type=float,
+ default=140.0,
+ help="Cache miss latency target (ms)",
+ )
+ parser.add_argument(
+ "--cost-reduction",
+ type=float,
+ default=90.0,
+ help="Cost reduction target (percent)",
+ )
+ parser.add_argument(
+ "--cache-hit-rate",
+ type=float,
+ default=60.0,
+ help="Cache hit rate target (percent)",
+ )
+
# Optional components
- parser.add_argument('--include-rag-comparison', action='store_true', help='Include RAG comparison')
- parser.add_argument('--include-profiling', action='store_true', help='Include performance profiling')
- parser.add_argument('--include-load-test', action='store_true', help='Include load testing')
-
+ parser.add_argument(
+ "--include-rag-comparison", action="store_true", help="Include RAG comparison"
+ )
+ parser.add_argument(
+ "--include-profiling", action="store_true", help="Include performance profiling"
+ )
+ parser.add_argument(
+ "--include-load-test", action="store_true", help="Include load testing"
+ )
+
# Load testing configuration
- parser.add_argument('--load-test-users', type=int, default=10, help='Load test concurrent users')
-
+ parser.add_argument(
+ "--load-test-users", type=int, default=10, help="Load test concurrent users"
+ )
+
# Output configuration
- parser.add_argument('--output-dir', default='./benchmark_results', help='Output directory for reports (default: creates timestamped logs directory)')
-
+ parser.add_argument(
+ "--output-dir",
+ default="./benchmark_results",
+ help="Output directory for reports (default: creates timestamped logs directory)",
+ )
+
args = parser.parse_args()
-
+
# Run standalone benchmark demo
success = asyncio.run(run_standalone_benchmark(args))
sys.exit(0 if success else 1)
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/scripts/run_optimized_benchmarks.py b/scripts/run_optimized_benchmarks.py
index 1651e75..b8c24be 100644
--- a/scripts/run_optimized_benchmarks.py
+++ b/scripts/run_optimized_benchmarks.py
@@ -22,7 +22,10 @@
from cache.manager import CacheManager, get_cache_manager
from cache.warming import get_cache_warmer, warm_cache_startup
from cache.metrics import get_metrics_collector
-from monitoring.performance_optimizer import get_performance_optimizer, start_performance_optimization
+from monitoring.performance_optimizer import (
+ get_performance_optimizer,
+ start_performance_optimization,
+)
from core.config import get_config
@@ -31,13 +34,13 @@ def create_optimized_logs_directory(base_dir: str = "./logs") -> Path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logs_dir = Path(base_dir) / f"optimized_benchmark_{timestamp}"
logs_dir.mkdir(parents=True, exist_ok=True)
-
+
# Create subdirectories for organized output
(logs_dir / "reports").mkdir(exist_ok=True)
(logs_dir / "charts").mkdir(exist_ok=True)
(logs_dir / "raw_data").mkdir(exist_ok=True)
(logs_dir / "optimization_logs").mkdir(exist_ok=True)
-
+
return logs_dir
@@ -59,17 +62,17 @@ async def run_optimized_benchmark_suite(args):
print("š Starting FACT Optimized Benchmark Suite")
print(" Enhanced with intelligent cache warming, latency optimization,")
print(" and real-time performance monitoring\n")
-
+
display_benchmark_targets()
-
+
# Create timestamped logs directory
logs_dir = create_optimized_logs_directory()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-
+
try:
# Initialize optimized configuration
print("\nš§ Initializing optimized FACT system...")
-
+
# Enhanced cache configuration for optimization
cache_config = {
"prefix": "fact_optimized_benchmark",
@@ -77,33 +80,37 @@ async def run_optimized_benchmark_suite(args):
"max_size": "15MB", # Increased cache size
"ttl_seconds": 7200, # 2 hours
"hit_target_ms": 48.0, # Updated target
- "miss_target_ms": 140.0
+ "miss_target_ms": 140.0,
}
-
+
# Initialize cache manager with optimizations
cache_manager = CacheManager(cache_config)
-
+
# Initialize performance optimizer
optimizer = get_performance_optimizer(cache_manager)
-
+
# Start optimization monitoring
print("šÆ Starting performance optimization monitoring...")
await start_performance_optimization(cache_manager)
-
+
# Intelligent cache warming
print("š„ Performing intelligent cache warming...")
cache_warmer = get_cache_warmer(cache_manager)
-
+
# Pre-warm with optimized queries
- warmup_result = await cache_warmer.warm_cache_intelligently(max_queries=args.warmup_queries)
- print(f" ā Warmed {warmup_result.queries_successful}/{warmup_result.queries_attempted} queries")
+ warmup_result = await cache_warmer.warm_cache_intelligently(
+ max_queries=args.warmup_queries
+ )
+ print(
+ f" ā Warmed {warmup_result.queries_successful}/{warmup_result.queries_attempted} queries"
+ )
print(f" ā Cached {warmup_result.total_tokens_cached} tokens")
print(f" ā Warming completed in {warmup_result.total_time_ms:.1f}ms")
-
+
# Wait for optimization to take effect
print("\nā³ Allowing optimization system to analyze and adjust...")
await asyncio.sleep(5)
-
+
# Initialize benchmark framework with optimized settings
benchmark_config = BenchmarkConfig(
iterations=args.iterations,
@@ -113,12 +120,12 @@ async def run_optimized_benchmark_suite(args):
target_cache_hit_rate=0.60,
target_hit_latency_ms=48.0,
target_miss_latency_ms=140.0,
- target_cost_reduction_hit=0.90
+ target_cost_reduction_hit=0.90,
)
-
+
framework = BenchmarkFramework(benchmark_config)
runner = BenchmarkRunner(framework)
-
+
# Enhanced test queries for comprehensive evaluation
enhanced_queries = [
"What was the Q1-2025 revenue?",
@@ -135,30 +142,32 @@ async def run_optimized_benchmark_suite(args):
"Show customer satisfaction scores",
"Analyze market performance by region",
"What is current profit margin?",
- "Show predictive revenue forecasts"
+ "Show predictive revenue forecasts",
]
-
+
# Run performance validation
print("\nš Running optimized performance validation...")
print(f" ⢠Testing {len(enhanced_queries)} diverse queries")
- print(f" ⢠{args.iterations} iterations with {args.concurrent_users} concurrent users")
+ print(
+ f" ⢠{args.iterations} iterations with {args.concurrent_users} concurrent users"
+ )
print(f" ⢠Real-time optimization monitoring active")
-
+
# Execute benchmarks with enhanced monitoring
start_time = time.time()
-
+
# Run enhanced benchmark suite
summary = await framework.run_benchmark_suite(enhanced_queries, cache_manager)
-
+
execution_time = time.time() - start_time
-
+
# Get optimization status
opt_status = optimizer.get_optimization_status()
-
+
# Collect comprehensive metrics
final_metrics = cache_manager.get_metrics()
performance_stats = cache_manager.get_performance_stats()
-
+
# Generate validation results
validation_results = {
"timestamp": time.time(),
@@ -174,105 +183,125 @@ async def run_optimized_benchmark_suite(args):
"avg_response_time_ms": summary.avg_response_time_ms,
"cost_reduction_percentage": summary.cost_reduction_percentage,
"error_rate": summary.error_rate,
- "throughput_qps": summary.throughput_qps
+ "throughput_qps": summary.throughput_qps,
},
"target_validation": {
"cache_hit_latency": {
"target_ms": 48.0,
"actual_ms": summary.avg_hit_latency_ms,
"met": summary.avg_hit_latency_ms <= 48.0,
- "status": "PASS" if summary.avg_hit_latency_ms <= 48.0 else "FAIL"
+ "status": "PASS" if summary.avg_hit_latency_ms <= 48.0 else "FAIL",
},
"cache_miss_latency": {
"target_ms": 140.0,
"actual_ms": summary.avg_miss_latency_ms,
"met": summary.avg_miss_latency_ms <= 140.0,
- "status": "PASS" if summary.avg_miss_latency_ms <= 140.0 else "FAIL"
+ "status": (
+ "PASS" if summary.avg_miss_latency_ms <= 140.0 else "FAIL"
+ ),
},
"cost_reduction": {
"target_percent": 90.0,
"actual_percent": summary.cost_reduction_percentage,
"met": summary.cost_reduction_percentage >= 90.0,
- "status": "PASS" if summary.cost_reduction_percentage >= 90.0 else "FAIL"
+ "status": (
+ "PASS" if summary.cost_reduction_percentage >= 90.0 else "FAIL"
+ ),
},
"cache_hit_rate": {
"target_percent": 60.0,
"actual_percent": summary.cache_hit_rate * 100,
"met": (summary.cache_hit_rate * 100) >= 60.0,
- "status": "PASS" if (summary.cache_hit_rate * 100) >= 60.0 else "FAIL"
- }
+ "status": (
+ "PASS" if (summary.cache_hit_rate * 100) >= 60.0 else "FAIL"
+ ),
+ },
},
"optimization_status": opt_status,
"cache_metrics": {
"total_entries": final_metrics.total_entries,
- "memory_utilization_percent": (final_metrics.total_size / cache_manager.max_size_bytes * 100),
+ "memory_utilization_percent": (
+ final_metrics.total_size / cache_manager.max_size_bytes * 100
+ ),
"hit_rate_percent": final_metrics.hit_rate,
- "token_efficiency": final_metrics.token_efficiency
+ "token_efficiency": final_metrics.token_efficiency,
},
"performance_improvements": {
"warming_result": {
"queries_warmed": warmup_result.queries_successful,
"tokens_cached": warmup_result.total_tokens_cached,
- "warming_time_ms": warmup_result.total_time_ms
+ "warming_time_ms": warmup_result.total_time_ms,
},
- "optimization_actions": len(opt_status.get('recent_actions', [])),
- "strategy": opt_status.get('current_strategy', 'unknown')
- }
+ "optimization_actions": len(opt_status.get("recent_actions", [])),
+ "strategy": opt_status.get("current_strategy", "unknown"),
+ },
}
-
+
# Calculate overall pass status
targets = validation_results["target_validation"]
- validation_results["overall_pass"] = all([
- targets["cache_hit_latency"]["met"],
- targets["cache_miss_latency"]["met"],
- targets["cost_reduction"]["met"],
- targets["cache_hit_rate"]["met"]
- ])
-
+ validation_results["overall_pass"] = all(
+ [
+ targets["cache_hit_latency"]["met"],
+ targets["cache_miss_latency"]["met"],
+ targets["cost_reduction"]["met"],
+ targets["cache_hit_rate"]["met"],
+ ]
+ )
+
# Save comprehensive results
await save_benchmark_results(validation_results, logs_dir, timestamp)
-
+
# Display results
display_optimization_results(validation_results, logs_dir)
-
+
# Stop optimization monitoring
from monitoring.performance_optimizer import stop_performance_optimization
+
await stop_performance_optimization()
-
+
return validation_results["overall_pass"]
-
+
except Exception as e:
print(f"\nā Benchmark execution failed: {str(e)}")
import traceback
+
traceback.print_exc()
return False
async def save_benchmark_results(results: dict, logs_dir: Path, timestamp: str):
"""Save comprehensive benchmark results."""
-
+
# Save JSON report
- json_report_path = logs_dir / "reports" / f"optimized_benchmark_report_{timestamp}.json"
- with open(json_report_path, 'w') as f:
+ json_report_path = (
+ logs_dir / "reports" / f"optimized_benchmark_report_{timestamp}.json"
+ )
+ with open(json_report_path, "w") as f:
json.dump(results, f, indent=2, default=str)
-
+
# Save text summary
- text_report_path = logs_dir / "reports" / f"optimized_benchmark_summary_{timestamp}.txt"
- with open(text_report_path, 'w') as f:
+ text_report_path = (
+ logs_dir / "reports" / f"optimized_benchmark_summary_{timestamp}.txt"
+ )
+ with open(text_report_path, "w") as f:
f.write("FACT Optimized Benchmark Results\n")
f.write("=" * 50 + "\n\n")
-
+
f.write(f"Execution Time: {results['execution_time_seconds']:.2f} seconds\n")
f.write(f"Timestamp: {datetime.fromtimestamp(results['timestamp'])}\n\n")
-
+
f.write("Performance Targets:\n")
for metric, data in results["target_validation"].items():
status = "ā
PASS" if data["met"] else "ā FAIL"
- f.write(f" {metric}: {status} - Target: {data.get('target_ms', data.get('target_percent', 'N/A'))}, "
- f"Actual: {data.get('actual_ms', data.get('actual_percent', 'N/A'))}\n")
-
- f.write(f"\nOverall Result: {'ā
PASS' if results['overall_pass'] else 'ā FAIL'}\n\n")
-
+ f.write(
+ f" {metric}: {status} - Target: {data.get('target_ms', data.get('target_percent', 'N/A'))}, "
+ f"Actual: {data.get('actual_ms', data.get('actual_percent', 'N/A'))}\n"
+ )
+
+ f.write(
+ f"\nOverall Result: {'ā
PASS' if results['overall_pass'] else 'ā FAIL'}\n\n"
+ )
+
summary = results["benchmark_summary"]
f.write("Benchmark Summary:\n")
f.write(f" Total Queries: {summary['total_queries']}\n")
@@ -281,41 +310,52 @@ async def save_benchmark_results(results: dict, logs_dir: Path, timestamp: str):
f.write(f" Average Miss Latency: {summary['avg_miss_latency_ms']:.1f}ms\n")
f.write(f" Cost Reduction: {summary['cost_reduction_percentage']:.1f}%\n")
f.write(f" Throughput: {summary['throughput_qps']:.1f} QPS\n")
-
+
# Save raw data
raw_data_path = logs_dir / "raw_data" / f"benchmark_data_{timestamp}.json"
- with open(raw_data_path, 'w') as f:
- json.dump({
- 'benchmark_summary': results['benchmark_summary'],
- 'cache_metrics': results['cache_metrics'],
- 'optimization_status': results['optimization_status']
- }, f, indent=2, default=str)
+ with open(raw_data_path, "w") as f:
+ json.dump(
+ {
+ "benchmark_summary": results["benchmark_summary"],
+ "cache_metrics": results["cache_metrics"],
+ "optimization_status": results["optimization_status"],
+ },
+ f,
+ indent=2,
+ default=str,
+ )
def display_optimization_results(results: dict, logs_dir: Path):
"""Display comprehensive optimization results."""
-
- print("\n" + "="*80)
+
+ print("\n" + "=" * 80)
print("šÆ FACT OPTIMIZATION BENCHMARK RESULTS")
- print("="*80)
-
+ print("=" * 80)
+
# Performance targets validation
print("\nš Performance Target Validation:")
targets = results["target_validation"]
-
+
for metric_name, data in targets.items():
status_icon = "ā
" if data["met"] else "ā"
metric_display = metric_name.replace("_", " ").title()
-
+
if "latency" in metric_name:
- print(f" {status_icon} {metric_display:<25} {data['actual_ms']:>8.1f}ms Target: ā¤{data['target_ms']}ms")
+ print(
+ f" {status_icon} {metric_display:<25} {data['actual_ms']:>8.1f}ms Target: ā¤{data['target_ms']}ms"
+ )
else:
- print(f" {status_icon} {metric_display:<25} {data['actual_percent']:>8.1f}% Target: ā„{data['target_percent']}%")
-
+ print(
+ f" {status_icon} {metric_display:<25} {data['actual_percent']:>8.1f}% Target: ā„{data['target_percent']}%"
+ )
+
# Overall result
- overall_status = "š SUCCESS" if results["overall_pass"] else "ā ļø NEEDS IMPROVEMENT"
+ overall_status = (
+ "š SUCCESS" if results["overall_pass"] else "ā ļø NEEDS IMPROVEMENT"
+ )
print(f"\nš Overall Result: {overall_status}")
-
+
# Performance summary
summary = results["benchmark_summary"]
print(f"\nš Performance Summary:")
@@ -324,30 +364,40 @@ def display_optimization_results(results: dict, logs_dir: Path):
print(f" ⢠Average Response Time: {summary['avg_response_time_ms']:.1f}ms")
print(f" ⢠System Throughput: {summary['throughput_qps']:.1f} QPS")
print(f" ⢠Error Rate: {summary['error_rate']:.1f}%")
-
+
# Optimization impact
improvements = results["performance_improvements"]
print(f"\nā” Optimization Impact:")
- print(f" ⢠Cache Entries Warmed: {improvements['warming_result']['queries_warmed']}")
- print(f" ⢠Tokens Pre-cached: {improvements['warming_result']['tokens_cached']:,}")
+ print(
+ f" ⢠Cache Entries Warmed: {improvements['warming_result']['queries_warmed']}"
+ )
+ print(
+ f" ⢠Tokens Pre-cached: {improvements['warming_result']['tokens_cached']:,}"
+ )
print(f" ⢠Optimization Actions: {improvements['optimization_actions']}")
print(f" ⢠Strategy: {improvements['strategy']}")
-
+
# Cost efficiency
print(f"\nš° Cost Efficiency:")
print(f" ⢠Cost Reduction Achieved: {summary['cost_reduction_percentage']:.1f}%")
- print(f" ⢠Cache Memory Utilization: {results['cache_metrics']['memory_utilization_percent']:.1f}%")
- print(f" ⢠Token Efficiency: {results['cache_metrics']['token_efficiency']:.1f} tokens/KB")
-
+ print(
+ f" ⢠Cache Memory Utilization: {results['cache_metrics']['memory_utilization_percent']:.1f}%"
+ )
+ print(
+ f" ⢠Token Efficiency: {results['cache_metrics']['token_efficiency']:.1f} tokens/KB"
+ )
+
# Results location
print(f"\nš Detailed Results Saved To:")
print(f" {logs_dir}")
print(f" āāā reports/optimized_benchmark_report_*.json")
print(f" āāā reports/optimized_benchmark_summary_*.txt")
print(f" āāā raw_data/benchmark_data_*.json")
-
+
if results["overall_pass"]:
- print(f"\nš Optimization successful! FACT system is performing within all targets.")
+ print(
+ f"\nš Optimization successful! FACT system is performing within all targets."
+ )
print(f" The enhanced caching system with intelligent warming and real-time")
print(f" optimization has achieved the required performance benchmarks.")
else:
@@ -378,29 +428,45 @@ def main():
# Load testing with optimization
python scripts/run_optimized_benchmarks.py --concurrent-users 10 --iterations 15
- """
+ """,
)
-
+
# Benchmark configuration
- parser.add_argument('--iterations', type=int, default=15,
- help='Number of benchmark iterations (default: 15)')
- parser.add_argument('--warmup', type=int, default=3,
- help='Number of warmup iterations (default: 3)')
- parser.add_argument('--warmup-queries', type=int, default=40,
- help='Number of queries to warm before benchmarking (default: 40)')
- parser.add_argument('--concurrent-users', type=int, default=5,
- help='Number of concurrent users for load testing (default: 5)')
-
+ parser.add_argument(
+ "--iterations",
+ type=int,
+ default=15,
+ help="Number of benchmark iterations (default: 15)",
+ )
+ parser.add_argument(
+ "--warmup", type=int, default=3, help="Number of warmup iterations (default: 3)"
+ )
+ parser.add_argument(
+ "--warmup-queries",
+ type=int,
+ default=40,
+ help="Number of queries to warm before benchmarking (default: 40)",
+ )
+ parser.add_argument(
+ "--concurrent-users",
+ type=int,
+ default=5,
+ help="Number of concurrent users for load testing (default: 5)",
+ )
+
# Output configuration
- parser.add_argument('--output-dir', default='./logs',
- help='Output directory for results (default: ./logs)')
-
+ parser.add_argument(
+ "--output-dir",
+ default="./logs",
+ help="Output directory for results (default: ./logs)",
+ )
+
args = parser.parse_args()
-
+
# Run optimized benchmark
success = asyncio.run(run_optimized_benchmark_suite(args))
sys.exit(0 if success else 1)
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/scripts/setup.py b/scripts/setup.py
index 5b1d197..54e0256 100644
--- a/scripts/setup.py
+++ b/scripts/setup.py
@@ -27,30 +27,30 @@
async def setup_database():
"""Set up the database with schema and sample data."""
print("šļø Setting up database...")
-
+
try:
config = get_config()
-
+
# Create data directory if it doesn't exist
data_dir = Path(config.database_path).parent
data_dir.mkdir(parents=True, exist_ok=True)
print(f"ā
Created data directory: {data_dir}")
-
+
# Initialize database
db_manager = create_database_manager(config.database_path)
await db_manager.initialize_database()
-
+
# Get database info
db_info = await db_manager.get_database_info()
print(f"ā
Database initialized: {config.database_path}")
print(f" ⢠File size: {db_info['file_size_bytes']} bytes")
print(f" ⢠Tables: {db_info['total_tables']}")
-
- for table_name, table_info in db_info['tables'].items():
+
+ for table_name, table_info in db_info["tables"].items():
print(f" ⢠{table_name}: {table_info['row_count']} rows")
-
+
return True
-
+
except Exception as e:
print(f"ā Database setup failed: {e}")
return False
@@ -59,16 +59,16 @@ async def setup_database():
def setup_directories():
"""Create necessary directories."""
print("š Setting up directories...")
-
+
directories = [
"data",
- "logs",
+ "logs",
"output",
"docs/api",
"docs/deployment",
- "docs/development"
+ "docs/development",
]
-
+
for directory in directories:
Path(directory).mkdir(parents=True, exist_ok=True)
print(f"ā
Created directory: {directory}")
@@ -77,34 +77,36 @@ def setup_directories():
def check_environment():
"""Check environment configuration."""
print("āļø Checking environment configuration...")
-
+
env_file = Path(".env")
env_example = Path(".env.example")
-
+
if not env_file.exists():
if env_example.exists():
- print("ā .env file not found. Please copy .env.example to .env and configure it.")
+ print(
+ "ā .env file not found. Please copy .env.example to .env and configure it."
+ )
print(" cp .env.example .env")
return False
else:
print("ā Neither .env nor .env.example found.")
return False
-
+
print("ā
.env file found")
-
+
# Check for required environment variables
required_vars = ["ANTHROPIC_API_KEY", "ARCADE_API_KEY"]
missing_vars = []
-
+
for var in required_vars:
if not os.getenv(var):
missing_vars.append(var)
-
+
if missing_vars:
print(f"ā Missing required environment variables: {', '.join(missing_vars)}")
print(" Please configure these in your .env file")
return False
-
+
print("ā
Required environment variables configured")
return True
@@ -112,16 +114,17 @@ def check_environment():
def install_dependencies():
"""Check if dependencies are installed."""
print("š¦ Checking dependencies...")
-
+
try:
# Try importing key dependencies
import anthropic
import litellm
import aiosqlite
import structlog
+
print("ā
Core dependencies installed")
return True
-
+
except ImportError as e:
print(f"ā Missing dependencies: {e}")
print(" Please install dependencies: pip install -r requirements.txt")
@@ -132,27 +135,27 @@ async def main():
"""Main setup routine."""
print("š FACT System Setup")
print("=" * 50)
-
+
success = True
-
+
# Check dependencies
if not install_dependencies():
success = False
-
+
# Set up directories
setup_directories()
-
+
# Check environment
if not check_environment():
success = False
-
+
# Set up database (only if environment is configured)
if success:
if not await setup_database():
success = False
-
+
print("\n" + "=" * 50)
-
+
if success:
print("ā
FACT System setup completed successfully!")
print("\nNext steps:")
@@ -165,4 +168,4 @@ async def main():
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/scripts/setup_env.py b/scripts/setup_env.py
index b4773e9..3ea30ce 100755
--- a/scripts/setup_env.py
+++ b/scripts/setup_env.py
@@ -19,23 +19,23 @@
from typing import Dict, Optional, List
# Add src to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
class EnvironmentSetup:
"""Interactive environment setup for FACT system."""
-
+
API_KEY_PATTERNS = {
- 'ANTHROPIC_API_KEY': r'^sk-ant-api03-[A-Za-z0-9_-]+$',
- 'ARCADE_API_KEY': r'^arc_[A-Za-z0-9_-]+$',
- 'OPENAI_API_KEY': r'^sk-proj-[A-Za-z0-9_-]+$'
+ "ANTHROPIC_API_KEY": r"^sk-ant-api03-[A-Za-z0-9_-]+$",
+ "ARCADE_API_KEY": r"^arc_[A-Za-z0-9_-]+$",
+ "OPENAI_API_KEY": r"^sk-proj-[A-Za-z0-9_-]+$",
}
-
+
def __init__(self, force: bool = False, minimal: bool = False):
self.force = force
self.minimal = minimal
self.config: Dict[str, str] = {}
-
+
def print_header(self):
"""Print setup script header."""
print("š FACT Environment Configuration Setup")
@@ -43,81 +43,94 @@ def print_header(self):
print("This script will help you configure your FACT system environment.")
print("You'll need API keys from Anthropic and Arcade AI to proceed.")
print()
-
+
def check_existing_env(self) -> bool:
"""Check if .env file already exists."""
- env_path = Path('.env')
-
+ env_path = Path(".env")
+
if env_path.exists() and not self.force:
print("ā ļø A .env file already exists!")
print(f"š Path: {env_path.absolute()}")
print()
-
- choice = input("Do you want to:\n"
- " [u] Update existing file\n"
- " [o] Overwrite completely\n"
- " [c] Cancel setup\n"
- "Choice (u/o/c): ").lower().strip()
-
- if choice == 'c':
+
+ choice = (
+ input(
+ "Do you want to:\n"
+ " [u] Update existing file\n"
+ " [o] Overwrite completely\n"
+ " [c] Cancel setup\n"
+ "Choice (u/o/c): "
+ )
+ .lower()
+ .strip()
+ )
+
+ if choice == "c":
print("ā Setup cancelled.")
return False
- elif choice == 'o':
+ elif choice == "o":
self.force = True
- elif choice == 'u':
+ elif choice == "u":
# Load existing configuration
self.load_existing_config()
else:
print("ā Invalid choice. Setup cancelled.")
return False
-
+
return True
-
+
def load_existing_config(self):
"""Load existing .env configuration."""
try:
- with open('.env', 'r') as f:
+ with open(".env", "r") as f:
for line in f:
line = line.strip()
- if line and not line.startswith('#') and '=' in line:
- key, value = line.split('=', 1)
+ if line and not line.startswith("#") and "=" in line:
+ key, value = line.split("=", 1)
self.config[key] = value
print(f"ā
Loaded {len(self.config)} existing configuration items")
except Exception as e:
print(f"ā ļø Could not load existing config: {e}")
-
- def get_api_key(self, service: str, key_name: str, pattern: str,
- description: str, url: str, required: bool = True) -> Optional[str]:
+
+ def get_api_key(
+ self,
+ service: str,
+ key_name: str,
+ pattern: str,
+ description: str,
+ url: str,
+ required: bool = True,
+ ) -> Optional[str]:
"""Get API key from user with validation."""
-
+
# Check if key already exists
- existing_key = self.config.get(key_name, '')
+ existing_key = self.config.get(key_name, "")
if existing_key and existing_key != f"your_{key_name.lower()}_here":
print(f"š {service} API Key: Already configured")
keep = input(f" Keep existing key? [Y/n]: ").lower().strip()
- if keep != 'n':
+ if keep != "n":
return existing_key
-
+
print(f"\nš {service} API Key Configuration")
print(f" Description: {description}")
print(f" Get your key: {url}")
print(f" Format: {pattern}")
-
+
if not required:
skip = input(f" Skip {service} configuration? [y/N]: ").lower().strip()
- if skip == 'y':
+ if skip == "y":
return None
-
+
while True:
key = input(f" Enter your {service} API key: ").strip()
-
+
if not key:
if required:
print(" ā This key is required. Please enter a valid key.")
continue
else:
return None
-
+
# Validate format
if re.match(self.API_KEY_PATTERNS[key_name], key):
print(f" ā
Valid {service} API key format")
@@ -125,20 +138,20 @@ def get_api_key(self, service: str, key_name: str, pattern: str,
else:
print(f" ā Invalid key format. Expected: {pattern}")
print(f" š” Make sure you copied the full key correctly")
-
+
retry = input(" Try again? [Y/n]: ").lower().strip()
- if retry == 'n':
+ if retry == "n":
if required:
print(" ā This key is required to continue.")
continue
else:
return None
-
+
def configure_required_keys(self) -> bool:
"""Configure required API keys."""
print("\nš REQUIRED API KEYS")
print("-" * 30)
-
+
# Anthropic API Key
anthropic_key = self.get_api_key(
service="Anthropic Claude",
@@ -146,15 +159,15 @@ def configure_required_keys(self) -> bool:
pattern="sk-ant-api03-*",
description="Claude API for LLM capabilities",
url="https://console.anthropic.com/",
- required=True
+ required=True,
)
-
+
if not anthropic_key:
print("ā Anthropic API key is required. Setup cannot continue.")
return False
-
- self.config['ANTHROPIC_API_KEY'] = anthropic_key
-
+
+ self.config["ANTHROPIC_API_KEY"] = anthropic_key
+
# Arcade API Key
arcade_key = self.get_api_key(
service="Arcade AI",
@@ -162,25 +175,25 @@ def configure_required_keys(self) -> bool:
pattern="arc_*",
description="Arcade AI for tool execution",
url="https://arcade-ai.com/dashboard",
- required=True
+ required=True,
)
-
+
if not arcade_key:
print("ā Arcade AI API key is required. Setup cannot continue.")
return False
-
- self.config['ARCADE_API_KEY'] = arcade_key
-
+
+ self.config["ARCADE_API_KEY"] = arcade_key
+
return True
-
+
def configure_optional_keys(self) -> bool:
"""Configure optional API keys."""
if self.minimal:
return True
-
+
print("\nš OPTIONAL API KEYS")
print("-" * 30)
-
+
# OpenAI API Key
openai_key = self.get_api_key(
service="OpenAI",
@@ -188,162 +201,166 @@ def configure_optional_keys(self) -> bool:
pattern="sk-proj-*",
description="OpenAI API for extended LLM capabilities",
url="https://platform.openai.com/api-keys",
- required=False
+ required=False,
)
-
+
if openai_key:
- self.config['OPENAI_API_KEY'] = openai_key
-
+ self.config["OPENAI_API_KEY"] = openai_key
+
return True
-
+
def configure_system_settings(self):
"""Configure system settings."""
if self.minimal:
# Set minimal required settings
defaults = {
- 'ARCADE_BASE_URL': 'https://api.arcade-ai.com',
- 'DATABASE_PATH': 'data/fact_demo.db',
- 'CLAUDE_MODEL': 'claude-3-5-sonnet-20241022',
- 'LOG_LEVEL': 'INFO'
+ "ARCADE_BASE_URL": "https://api.arcade-ai.com",
+ "DATABASE_PATH": "data/fact_demo.db",
+ "CLAUDE_MODEL": "claude-3-5-sonnet-20241022",
+ "LOG_LEVEL": "INFO",
}
-
+
for key, value in defaults.items():
if key not in self.config:
self.config[key] = value
-
+
return
-
+
print("\nš SYSTEM CONFIGURATION")
print("-" * 30)
-
+
# Claude Model Selection
models = [
- 'claude-3-5-sonnet-20241022',
- 'claude-3-haiku-20240307',
- 'claude-3-opus-20240229'
+ "claude-3-5-sonnet-20241022",
+ "claude-3-haiku-20240307",
+ "claude-3-opus-20240229",
]
-
- current_model = self.config.get('CLAUDE_MODEL', models[0])
+
+ current_model = self.config.get("CLAUDE_MODEL", models[0])
print(f"š¤ Claude Model (current: {current_model})")
print(" Available models:")
for i, model in enumerate(models, 1):
marker = " (recommended)" if model == models[0] else ""
print(f" {i}. {model}{marker}")
-
- choice = input(f" Select model [1-{len(models)}] or press Enter for current: ").strip()
+
+ choice = input(
+ f" Select model [1-{len(models)}] or press Enter for current: "
+ ).strip()
if choice.isdigit() and 1 <= int(choice) <= len(models):
- self.config['CLAUDE_MODEL'] = models[int(choice) - 1]
+ self.config["CLAUDE_MODEL"] = models[int(choice) - 1]
elif not choice:
- self.config['CLAUDE_MODEL'] = current_model
-
+ self.config["CLAUDE_MODEL"] = current_model
+
# Log Level
- log_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR']
- current_log = self.config.get('LOG_LEVEL', 'INFO')
+ log_levels = ["DEBUG", "INFO", "WARNING", "ERROR"]
+ current_log = self.config.get("LOG_LEVEL", "INFO")
print(f"\nš Log Level (current: {current_log})")
print(" Available levels:")
for i, level in enumerate(log_levels, 1):
- marker = " (recommended)" if level == 'INFO' else ""
+ marker = " (recommended)" if level == "INFO" else ""
print(f" {i}. {level}{marker}")
-
- choice = input(f" Select level [1-{len(log_levels)}] or press Enter for current: ").strip()
+
+ choice = input(
+ f" Select level [1-{len(log_levels)}] or press Enter for current: "
+ ).strip()
if choice.isdigit() and 1 <= int(choice) <= len(log_levels):
- self.config['LOG_LEVEL'] = log_levels[int(choice) - 1]
+ self.config["LOG_LEVEL"] = log_levels[int(choice) - 1]
elif not choice:
- self.config['LOG_LEVEL'] = current_log
-
+ self.config["LOG_LEVEL"] = current_log
+
# Set other defaults
defaults = {
- 'ARCADE_BASE_URL': 'https://api.arcade-ai.com',
- 'DATABASE_PATH': 'data/fact_demo.db',
- 'SYSTEM_PROMPT': 'You are a deterministic finance assistant. When uncertain, request data via tools.',
- 'MAX_RETRIES': '3',
- 'REQUEST_TIMEOUT': '30'
+ "ARCADE_BASE_URL": "https://api.arcade-ai.com",
+ "DATABASE_PATH": "data/fact_demo.db",
+ "SYSTEM_PROMPT": "You are a deterministic finance assistant. When uncertain, request data via tools.",
+ "MAX_RETRIES": "3",
+ "REQUEST_TIMEOUT": "30",
}
-
+
for key, value in defaults.items():
if key not in self.config:
self.config[key] = value
-
+
def configure_security_settings(self):
"""Configure security settings."""
if self.minimal:
# Set secure defaults for minimal setup
security_defaults = {
- 'STRICT_MODE': 'true',
- 'DEBUG_MODE': 'false',
- 'ENFORCE_HTTPS': 'true',
- 'RATE_LIMITING_ENABLED': 'true'
+ "STRICT_MODE": "true",
+ "DEBUG_MODE": "false",
+ "ENFORCE_HTTPS": "true",
+ "RATE_LIMITING_ENABLED": "true",
}
-
+
for key, value in security_defaults.items():
if key not in self.config:
self.config[key] = value
-
+
return
-
+
print("\nš SECURITY CONFIGURATION")
print("-" * 30)
-
+
# Environment type
env_types = {
- 'production': {
- 'STRICT_MODE': 'true',
- 'DEBUG_MODE': 'false',
- 'ENFORCE_HTTPS': 'true',
- 'RATE_LIMITING_ENABLED': 'true',
- 'LOG_SECURITY_EVENTS': 'true'
+ "production": {
+ "STRICT_MODE": "true",
+ "DEBUG_MODE": "false",
+ "ENFORCE_HTTPS": "true",
+ "RATE_LIMITING_ENABLED": "true",
+ "LOG_SECURITY_EVENTS": "true",
+ },
+ "development": {
+ "STRICT_MODE": "false",
+ "DEBUG_MODE": "true",
+ "ENFORCE_HTTPS": "false",
+ "RATE_LIMITING_ENABLED": "false",
+ "LOG_SECURITY_EVENTS": "false",
},
- 'development': {
- 'STRICT_MODE': 'false',
- 'DEBUG_MODE': 'true',
- 'ENFORCE_HTTPS': 'false',
- 'RATE_LIMITING_ENABLED': 'false',
- 'LOG_SECURITY_EVENTS': 'false'
- }
}
-
+
print("šļø Environment Type:")
print(" 1. Production (strict security)")
print(" 2. Development (relaxed security)")
-
+
choice = input(" Select environment [1-2]: ").strip()
-
- if choice == '1':
- env_config = env_types['production']
+
+ if choice == "1":
+ env_config = env_types["production"]
print(" ā
Production security settings applied")
- elif choice == '2':
- env_config = env_types['development']
+ elif choice == "2":
+ env_config = env_types["development"]
print(" ā
Development security settings applied")
else:
- env_config = env_types['production']
+ env_config = env_types["production"]
print(" ā
Default (production) security settings applied")
-
+
# Apply security settings
for key, value in env_config.items():
if key not in self.config:
self.config[key] = value
-
+
def write_env_file(self) -> bool:
"""Write configuration to .env file."""
try:
# Create data directory if it doesn't exist
- data_dir = Path('data')
+ data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
-
+
# Generate .env content
env_content = self.generate_env_content()
-
+
# Write to file
- with open('.env', 'w') as f:
+ with open(".env", "w") as f:
f.write(env_content)
-
+
print(f"ā
Configuration saved to .env")
return True
-
+
except Exception as e:
print(f"ā Failed to write .env file: {e}")
return False
-
+
def generate_env_content(self) -> str:
"""Generate .env file content with proper formatting."""
content = [
@@ -354,102 +371,118 @@ def generate_env_content(self) -> str:
"# =============================================================================",
"# REQUIRED API KEYS",
"# =============================================================================",
- ""
+ "",
]
-
+
# Required API keys
- required_keys = ['ANTHROPIC_API_KEY', 'ARCADE_API_KEY']
+ required_keys = ["ANTHROPIC_API_KEY", "ARCADE_API_KEY"]
for key in required_keys:
if key in self.config:
content.append(f"{key}={self.config[key]}")
-
- content.extend([
- "",
- "# =============================================================================",
- "# OPTIONAL API KEYS",
- "# =============================================================================",
- ""
- ])
-
+
+ content.extend(
+ [
+ "",
+ "# =============================================================================",
+ "# OPTIONAL API KEYS",
+ "# =============================================================================",
+ "",
+ ]
+ )
+
# Optional API keys
- optional_keys = ['OPENAI_API_KEY', 'ENCRYPTION_KEY', 'CACHE_ENCRYPTION_KEY']
+ optional_keys = ["OPENAI_API_KEY", "ENCRYPTION_KEY", "CACHE_ENCRYPTION_KEY"]
for key in optional_keys:
if key in self.config:
content.append(f"{key}={self.config[key]}")
else:
content.append(f"# {key}=your_{key.lower()}_here")
-
- content.extend([
- "",
- "# =============================================================================",
- "# SYSTEM CONFIGURATION",
- "# =============================================================================",
- ""
- ])
-
+
+ content.extend(
+ [
+ "",
+ "# =============================================================================",
+ "# SYSTEM CONFIGURATION",
+ "# =============================================================================",
+ "",
+ ]
+ )
+
# System configuration
system_keys = [
- 'ARCADE_BASE_URL', 'DATABASE_PATH', 'CLAUDE_MODEL', 'SYSTEM_PROMPT',
- 'MAX_RETRIES', 'REQUEST_TIMEOUT', 'LOG_LEVEL'
+ "ARCADE_BASE_URL",
+ "DATABASE_PATH",
+ "CLAUDE_MODEL",
+ "SYSTEM_PROMPT",
+ "MAX_RETRIES",
+ "REQUEST_TIMEOUT",
+ "LOG_LEVEL",
]
-
+
for key in system_keys:
if key in self.config:
content.append(f"{key}={self.config[key]}")
-
- content.extend([
- "",
- "# =============================================================================",
- "# SECURITY CONFIGURATION",
- "# =============================================================================",
- ""
- ])
-
+
+ content.extend(
+ [
+ "",
+ "# =============================================================================",
+ "# SECURITY CONFIGURATION",
+ "# =============================================================================",
+ "",
+ ]
+ )
+
# Security configuration
security_keys = [
- 'STRICT_MODE', 'DEBUG_MODE', 'ENFORCE_HTTPS', 'RATE_LIMITING_ENABLED',
- 'LOG_SECURITY_EVENTS'
+ "STRICT_MODE",
+ "DEBUG_MODE",
+ "ENFORCE_HTTPS",
+ "RATE_LIMITING_ENABLED",
+ "LOG_SECURITY_EVENTS",
]
-
+
for key in security_keys:
if key in self.config:
content.append(f"{key}={self.config[key]}")
-
- content.extend([
- "",
- "# =============================================================================",
- "# CACHE CONFIGURATION (Defaults will be used if not specified)",
- "# =============================================================================",
- "",
- "# CACHE_PREFIX=fact_v1",
- "# CACHE_MIN_TOKENS=50",
- "# CACHE_MAX_SIZE=100MB",
- "# CACHE_TTL_SECONDS=3600",
- "",
- "# For more configuration options, see:",
- "# - docs/environment-configuration-guide.md",
- "# - .env.template",
- ""
- ])
-
- return '\n'.join(content)
-
+
+ content.extend(
+ [
+ "",
+ "# =============================================================================",
+ "# CACHE CONFIGURATION (Defaults will be used if not specified)",
+ "# =============================================================================",
+ "",
+ "# CACHE_PREFIX=fact_v1",
+ "# CACHE_MIN_TOKENS=50",
+ "# CACHE_MAX_SIZE=100MB",
+ "# CACHE_TTL_SECONDS=3600",
+ "",
+ "# For more configuration options, see:",
+ "# - docs/environment-configuration-guide.md",
+ "# - .env.template",
+ "",
+ ]
+ )
+
+ return "\n".join(content)
+
def run_validation(self) -> bool:
"""Run configuration validation."""
print("\nš VALIDATING CONFIGURATION")
print("-" * 30)
-
+
try:
# Import and run validation
from pathlib import Path
import subprocess
-
+
result = subprocess.run(
- [sys.executable, 'scripts/validate_env.py', '--verbose'],
+ [sys.executable, "scripts/validate_env.py", "--verbose"],
capture_output=True,
- text=True
+ text=True,
)
-
+
if result.returncode == 0:
print("ā
Configuration validation passed!")
return True
@@ -458,13 +491,13 @@ def run_validation(self) -> bool:
print(result.stdout)
print(result.stderr)
return False
-
+
except Exception as e:
print(f"ā ļø Could not run validation: {e}")
print("š” You can manually validate later with:")
print(" python scripts/validate_env.py --verbose")
return True
-
+
def print_next_steps(self):
"""Print next steps for the user."""
print("\nš SETUP COMPLETE!")
@@ -493,39 +526,39 @@ def print_next_steps(self):
print("⢠Never commit .env files to version control")
print("⢠Rotate your API keys regularly")
print("⢠Use different keys for different environments")
-
+
def run_setup(self) -> bool:
"""Run the complete setup process."""
self.print_header()
-
+
# Check existing environment
if not self.check_existing_env():
return False
-
+
# Configure required API keys
if not self.configure_required_keys():
return False
-
+
# Configure optional API keys
if not self.configure_optional_keys():
return False
-
+
# Configure system settings
self.configure_system_settings()
-
+
# Configure security settings
self.configure_security_settings()
-
+
# Write configuration file
if not self.write_env_file():
return False
-
+
# Validate configuration
validation_success = self.run_validation()
-
+
# Print next steps
self.print_next_steps()
-
+
return validation_success
@@ -542,30 +575,32 @@ def main():
This script will guide you through configuring your FACT environment
with the required API keys and optimal settings.
- """
+ """,
)
-
+
parser.add_argument(
- '--force', '-f',
- action='store_true',
- help='Overwrite existing .env file without prompting'
+ "--force",
+ "-f",
+ action="store_true",
+ help="Overwrite existing .env file without prompting",
)
-
+
parser.add_argument(
- '--minimal', '-m',
- action='store_true',
- help='Set up minimal configuration with defaults'
+ "--minimal",
+ "-m",
+ action="store_true",
+ help="Set up minimal configuration with defaults",
)
-
+
args = parser.parse_args()
-
+
# Run setup
setup = EnvironmentSetup(force=args.force, minimal=args.minimal)
success = setup.run_setup()
-
+
# Exit with appropriate code
sys.exit(0 if success else 1)
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/scripts/test_benchmark_runner.py b/scripts/test_benchmark_runner.py
index 39cd020..70091cb 100644
--- a/scripts/test_benchmark_runner.py
+++ b/scripts/test_benchmark_runner.py
@@ -8,13 +8,14 @@
import subprocess
from pathlib import Path
+
def test_benchmark_runner():
"""Test the benchmark runner script."""
script_path = Path(__file__).parent / "run_benchmarks.py"
-
+
print("š§Ŗ Testing FACT Benchmark Runner")
print("=" * 50)
-
+
# Test 1: Check if script can be imported
print("Test 1: Import validation...")
try:
@@ -22,30 +23,35 @@ def test_benchmark_runner():
src_path = str(Path(__file__).parent.parent / "src")
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
# Test basic imports that don't require relative imports
print(" Testing basic module structure...")
-
+
# Check if benchmarking module exists
import benchmarking
+
print(" ā
Benchmarking module found")
-
+
# Check if cache module exists
import cache
+
print(" ā
Cache module found")
-
+
print("ā
Core modules accessible")
except ImportError as e:
print(f"ā Import failed: {e}")
return False
-
+
# Test 2: Check command line help
print("\nTest 2: Command line interface...")
try:
- result = subprocess.run([
- sys.executable, str(script_path), "--help"
- ], capture_output=True, text=True, timeout=10)
-
+ result = subprocess.run(
+ [sys.executable, str(script_path), "--help"],
+ capture_output=True,
+ text=True,
+ timeout=10,
+ )
+
if result.returncode == 0:
print("ā
CLI help works correctly")
else:
@@ -54,21 +60,22 @@ def test_benchmark_runner():
except Exception as e:
print(f"ā CLI test failed: {e}")
return False
-
+
# Test 3: Directory creation functionality
print("\nTest 3: Directory creation...")
try:
# Import our enhanced functions
sys.path.insert(0, str(Path(__file__).parent))
import run_benchmarks
-
+
# Test directory creation
test_dir = run_benchmarks.create_logs_directory("test_logs")
if test_dir.exists():
print(f"ā
Directory created: {test_dir}")
-
+
# Clean up
import shutil
+
shutil.rmtree(test_dir.parent)
print("ā
Cleanup successful")
else:
@@ -77,15 +84,16 @@ def test_benchmark_runner():
except Exception as e:
print(f"ā Directory test failed: {e}")
return False
-
+
print("\nš All tests passed! Benchmark runner is ready to use.")
print("\nTo run benchmarks:")
print(f" python {script_path}")
print(f" python {script_path} --include-rag-comparison")
print(f" python {script_path} --include-profiling --include-load-test")
-
+
return True
+
if __name__ == "__main__":
success = test_benchmark_runner()
- sys.exit(0 if success else 1)
\ No newline at end of file
+ sys.exit(0 if success else 1)
diff --git a/scripts/test_cache_fix.py b/scripts/test_cache_fix.py
index 0c7ceda..aa91d8a 100644
--- a/scripts/test_cache_fix.py
+++ b/scripts/test_cache_fix.py
@@ -20,36 +20,36 @@
def test_cache_validation_fix():
"""Test that cache validation now uses configured minimum tokens."""
-
+
print("=== FACT Cache Validation Fix Verification ===\n")
-
+
# Test 1: Default configuration
print("1. Testing default configuration (500 tokens)...")
default_config = get_default_cache_config()
manager_default = CacheManager(default_config)
print(f" Default min_tokens: {manager_default.min_tokens}")
-
- # Test 2: Custom configuration
+
+ # Test 2: Custom configuration
print("\n2. Testing custom configuration (100 tokens)...")
custom_config = default_config.copy()
- custom_config['min_tokens'] = 100
+ custom_config["min_tokens"] = 100
manager_custom = CacheManager(custom_config)
print(f" Custom min_tokens: {manager_custom.min_tokens}")
-
+
# Test 3: Environment override
print("\n3. Testing environment override...")
- os.environ['CACHE_MIN_TOKENS'] = '75'
+ os.environ["CACHE_MIN_TOKENS"] = "75"
try:
env_config = load_cache_config_from_env()
manager_env = CacheManager(env_config.to_dict())
print(f" Environment min_tokens: {manager_env.min_tokens}")
except Exception as e:
print(f" Environment test failed: {e}")
-
+
# Test 4: Validation consistency
print("\n4. Testing validation consistency...")
test_content = "This is a test response. " * 15 # ~375 tokens
-
+
try:
# Should fail with default manager (requires 500 tokens)
hash_key = manager_default.generate_hash("test_query_1")
@@ -57,7 +57,7 @@ def test_cache_validation_fix():
print(" ā ERROR: Default manager should have rejected content")
except Exception:
print(" ā
Default manager correctly rejected content (<500 tokens)")
-
+
try:
# Should succeed with custom manager (requires 100 tokens)
hash_key = manager_custom.generate_hash("test_query_2")
@@ -65,7 +65,7 @@ def test_cache_validation_fix():
print(f" ā
Custom manager accepted content ({entry.token_count} tokens)")
except Exception as e:
print(f" ā Custom manager failed: {e}")
-
+
print("\n=== Fix Summary ===")
print("ā
Hard-coded 500 token minimum removed from CacheEntry._validate()")
print("ā
CacheEntry now accepts configurable min_tokens parameter")
@@ -75,4 +75,4 @@ def test_cache_validation_fix():
if __name__ == "__main__":
- test_cache_validation_fix()
\ No newline at end of file
+ test_cache_validation_fix()
diff --git a/scripts/test_cache_resilience.py b/scripts/test_cache_resilience.py
index 534f5bd..0681045 100644
--- a/scripts/test_cache_resilience.py
+++ b/scripts/test_cache_resilience.py
@@ -20,8 +20,11 @@
sys.path.insert(0, src_path)
from cache.resilience import (
- CacheCircuitBreaker, CircuitBreakerConfig, ResilientCacheWrapper,
- CircuitState, FailureRecord
+ CacheCircuitBreaker,
+ CircuitBreakerConfig,
+ ResilientCacheWrapper,
+ CircuitState,
+ FailureRecord,
)
from cache.manager import CacheManager
from core.errors import CacheError
@@ -32,7 +35,7 @@
processors=[
structlog.processors.TimeStamper(fmt="ISO"),
structlog.processors.add_log_level,
- structlog.processors.JSONRenderer()
+ structlog.processors.JSONRenderer(),
],
logger_factory=structlog.PrintLoggerFactory(),
wrapper_class=structlog.BoundLogger,
@@ -44,11 +47,11 @@
class FailingCacheManager:
"""Mock cache manager that can be configured to fail."""
-
+
def __init__(self, failure_rate: float = 0.0, failure_after: int = 0):
"""
Initialize failing cache manager.
-
+
Args:
failure_rate: Probability of failure (0.0 = never fail, 1.0 = always fail)
failure_after: Fail after this many successful operations
@@ -57,69 +60,79 @@ def __init__(self, failure_rate: float = 0.0, failure_after: int = 0):
self.failure_after = failure_after
self.operation_count = 0
self.cache = {}
-
+
def set_failure_rate(self, rate: float):
"""Set new failure rate."""
self.failure_rate = rate
logger.info("Cache failure rate changed", rate=rate)
-
+
def set_failure_after(self, count: int):
"""Set to fail after specific number of operations."""
self.failure_after = count
self.operation_count = 0
logger.info("Cache set to fail after operations", count=count)
-
+
def _should_fail(self) -> bool:
"""Determine if this operation should fail."""
self.operation_count += 1
-
+
if self.failure_after > 0 and self.operation_count > self.failure_after:
return True
-
+
import random
+
return random.random() < self.failure_rate
-
+
async def store(self, query_hash: str, content: str):
"""Store with potential failure."""
if self._should_fail():
- raise CacheError("Simulated cache store failure", error_code="CACHE_STORE_FAILED")
-
- self.cache[query_hash] = {
- 'content': content,
- 'timestamp': time.time()
- }
+ raise CacheError(
+ "Simulated cache store failure", error_code="CACHE_STORE_FAILED"
+ )
+
+ self.cache[query_hash] = {"content": content, "timestamp": time.time()}
return True
-
+
async def get(self, query_hash: str):
"""Get with potential failure."""
if self._should_fail():
- raise CacheError("Simulated cache get failure", error_code="CACHE_GET_FAILED")
-
+ raise CacheError(
+ "Simulated cache get failure", error_code="CACHE_GET_FAILED"
+ )
+
if query_hash in self.cache:
- return type('CacheEntry', (), {
- 'content': self.cache[query_hash]['content'],
- 'timestamp': self.cache[query_hash]['timestamp']
- })()
+ return type(
+ "CacheEntry",
+ (),
+ {
+ "content": self.cache[query_hash]["content"],
+ "timestamp": self.cache[query_hash]["timestamp"],
+ },
+ )()
return None
-
+
async def invalidate_by_prefix(self, prefix: str) -> int:
"""Invalidate with potential failure."""
if self._should_fail():
- raise CacheError("Simulated cache invalidate failure", error_code="CACHE_INVALIDATE_FAILED")
-
+ raise CacheError(
+ "Simulated cache invalidate failure",
+ error_code="CACHE_INVALIDATE_FAILED",
+ )
+
# Simulate invalidation
return len(self.cache)
-
+
def generate_hash(self, query: str) -> str:
"""Generate hash (no failure simulation for local operation)."""
import hashlib
+
return hashlib.sha256(query.encode()).hexdigest()
async def test_circuit_breaker_states():
"""Test circuit breaker state transitions."""
logger.info("=== Testing Circuit Breaker State Transitions ===")
-
+
# Configure circuit breaker for fast testing
config = CircuitBreakerConfig(
failure_threshold=3, # Open after 3 failures
@@ -127,31 +140,33 @@ async def test_circuit_breaker_states():
timeout_seconds=2.0, # Fast timeout for testing
rolling_window_seconds=60.0,
gradual_recovery=True,
- recovery_factor=0.5
+ recovery_factor=0.5,
)
-
+
circuit_breaker = CacheCircuitBreaker(config)
failing_cache = FailingCacheManager(failure_rate=0.0) # Start with no failures
resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker)
-
+
# Test normal operation (CLOSED state)
logger.info("Testing CLOSED state - normal operations")
assert circuit_breaker.get_state() == CircuitState.CLOSED
-
+
for i in range(3):
result = await resilient_cache.store(f"test_key_{i}", f"test_content_{i}")
assert result == True
-
+
metrics = circuit_breaker.get_metrics()
- logger.info("CLOSED state metrics",
- state=metrics.state.value,
- success_count=metrics.success_count,
- failure_count=metrics.failure_count)
-
+ logger.info(
+ "CLOSED state metrics",
+ state=metrics.state.value,
+ success_count=metrics.success_count,
+ failure_count=metrics.failure_count,
+ )
+
# Introduce failures to trigger OPEN state
logger.info("Introducing failures to trigger OPEN state")
failing_cache.set_failure_rate(1.0) # 100% failure rate
-
+
failure_count = 0
for i in range(5):
try:
@@ -161,144 +176,148 @@ async def test_circuit_breaker_states():
if "CIRCUIT_BREAKER_OPEN" in str(e.error_code):
logger.info("Circuit breaker opened successfully")
break
-
+
assert circuit_breaker.get_state() == CircuitState.OPEN
-
+
# Test that operations use graceful degradation when OPEN
logger.info("Testing OPEN state - operations should use graceful degradation")
-
+
# With graceful degradation enabled, operations should return fallback values
result = await resilient_cache.get("test_key")
assert result is None, "Expected fallback response (None) from graceful degradation"
-
+
# Test store operation also returns fallback
result = await resilient_cache.store("test_key", "test_content")
assert result == True, "Expected fallback response (True) from graceful degradation"
-
+
# Disable graceful degradation to test actual circuit breaker exceptions
resilient_cache.enable_graceful_degradation = False
-
+
try:
await resilient_cache.get("test_key")
- assert False, "Expected circuit breaker to raise exception when graceful degradation disabled"
+ assert (
+ False
+ ), "Expected circuit breaker to raise exception when graceful degradation disabled"
except CacheError as e:
assert "CIRCUIT_BREAKER_OPEN" in str(e.error_code)
- logger.info("Circuit breaker correctly raised exception when graceful degradation disabled")
-
+ logger.info(
+ "Circuit breaker correctly raised exception when graceful degradation disabled"
+ )
+
# Re-enable graceful degradation for remaining tests
resilient_cache.enable_graceful_degradation = True
-
+
# Wait for timeout and test HALF_OPEN transition
logger.info("Waiting for timeout to test HALF_OPEN transition")
await asyncio.sleep(2.5) # Wait longer than timeout
-
+
# Fix the cache and try operation to trigger HALF_OPEN
failing_cache.set_failure_rate(0.0) # Fix the cache
-
+
# This should transition to HALF_OPEN and succeed
result = await resilient_cache.store("recovery_test", "recovery_content")
assert result == True
-
+
# Circuit should be HALF_OPEN or CLOSED now
state = circuit_breaker.get_state()
logger.info("State after recovery attempt", state=state.value)
-
+
# Continue with successful operations to close circuit
for i in range(3):
await resilient_cache.store(f"recovery_key_{i}", f"recovery_content_{i}")
-
+
final_state = circuit_breaker.get_state()
logger.info("Final state after recovery", state=final_state.value)
-
+
metrics = circuit_breaker.get_metrics()
- logger.info("Final metrics",
- state=metrics.state.value,
- success_count=metrics.success_count,
- failure_count=metrics.failure_count,
- failure_rate=metrics.failure_rate,
- state_changes=metrics.state_changes)
-
+ logger.info(
+ "Final metrics",
+ state=metrics.state.value,
+ success_count=metrics.success_count,
+ failure_count=metrics.failure_count,
+ failure_rate=metrics.failure_rate,
+ state_changes=metrics.state_changes,
+ )
+
logger.info("ā
Circuit breaker state transitions test passed")
async def test_graceful_degradation():
"""Test graceful degradation when cache fails."""
logger.info("=== Testing Graceful Degradation ===")
-
+
# Configure circuit breaker
config = CircuitBreakerConfig(
failure_threshold=2,
success_threshold=2,
timeout_seconds=1.0,
- gradual_recovery=True
+ gradual_recovery=True,
)
-
+
circuit_breaker = CacheCircuitBreaker(config)
failing_cache = FailingCacheManager(failure_rate=1.0) # Always fail
resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker)
-
+
# Enable graceful degradation
resilient_cache.enable_graceful_degradation = True
-
+
# First, trigger circuit breaker to open
logger.info("Triggering circuit breaker to open state")
-
+
for i in range(3): # Exceed failure threshold
try:
await resilient_cache.store(f"fail_key_{i}", "content")
except CacheError:
pass # Expected failures
-
+
# Verify circuit is open
assert circuit_breaker.get_state() == CircuitState.OPEN
logger.info("Circuit breaker is now OPEN")
-
+
logger.info("Testing cache operations with graceful degradation")
-
+
# Test store operation - should return graceful response
result = await resilient_cache.store("test_key", "test_content")
logger.info("Store operation result", result=result)
assert result == True # Graceful fallback
-
+
# Test get operation - should return None (cache miss)
result = await resilient_cache.get("test_key")
logger.info("Get operation result", result=result)
assert result is None
-
+
# Test invalidate operation - should return 0
result = await resilient_cache.invalidate_by_prefix("test_")
logger.info("Invalidate operation result", result=result)
assert result == 0
-
+
# Check circuit breaker state
state = circuit_breaker.get_state()
logger.info("Circuit breaker state", state=state.value)
-
+
logger.info("ā
Graceful degradation test passed")
async def test_performance_under_failures():
"""Test performance characteristics under various failure scenarios."""
logger.info("=== Testing Performance Under Failures ===")
-
+
config = CircuitBreakerConfig(
- failure_threshold=5,
- success_threshold=3,
- timeout_seconds=1.0
+ failure_threshold=5, success_threshold=3, timeout_seconds=1.0
)
-
+
circuit_breaker = CacheCircuitBreaker(config)
failing_cache = FailingCacheManager(failure_rate=0.3) # 30% failure rate
resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker)
-
+
# Test with mixed success/failure pattern
operation_count = 50
success_count = 0
failure_count = 0
total_time = 0
-
+
logger.info("Running performance test with 30% failure rate")
-
+
for i in range(operation_count):
start_time = time.time()
try:
@@ -306,47 +325,47 @@ async def test_performance_under_failures():
success_count += 1
except CacheError:
failure_count += 1
-
+
end_time = time.time()
- total_time += (end_time - start_time)
-
+ total_time += end_time - start_time
+
avg_latency = (total_time / operation_count) * 1000 # Convert to ms
-
+
metrics = circuit_breaker.get_metrics()
-
- logger.info("Performance test results",
- total_operations=operation_count,
- successes=success_count,
- failures=failure_count,
- avg_latency_ms=avg_latency,
- circuit_state=metrics.state.value,
- circuit_failure_rate=metrics.failure_rate,
- state_changes=metrics.state_changes)
-
+
+ logger.info(
+ "Performance test results",
+ total_operations=operation_count,
+ successes=success_count,
+ failures=failure_count,
+ avg_latency_ms=avg_latency,
+ circuit_state=metrics.state.value,
+ circuit_failure_rate=metrics.failure_rate,
+ state_changes=metrics.state_changes,
+ )
+
logger.info("ā
Performance test completed")
async def test_circuit_breaker_metrics():
"""Test circuit breaker metrics collection."""
logger.info("=== Testing Circuit Breaker Metrics ===")
-
+
config = CircuitBreakerConfig(
- failure_threshold=3,
- success_threshold=2,
- timeout_seconds=1.0
+ failure_threshold=3, success_threshold=2, timeout_seconds=1.0
)
-
+
circuit_breaker = CacheCircuitBreaker(config)
failing_cache = FailingCacheManager()
resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker)
-
+
# Generate some operations
logger.info("Generating operations for metrics")
-
+
# Successful operations
for i in range(5):
await resilient_cache.store(f"metrics_key_{i}", f"metrics_content_{i}")
-
+
# Failed operations
failing_cache.set_failure_rate(1.0)
for i in range(3):
@@ -354,45 +373,48 @@ async def test_circuit_breaker_metrics():
await resilient_cache.store(f"fail_key_{i}", f"fail_content_{i}")
except CacheError:
pass
-
+
# Get comprehensive metrics
metrics = circuit_breaker.get_metrics()
cache_metrics = resilient_cache.get_metrics()
-
- logger.info("Circuit breaker metrics",
- state=metrics.state.value,
- total_operations=metrics.total_operations,
- success_count=metrics.success_count,
- failure_count=metrics.failure_count,
- failure_rate=metrics.failure_rate,
- state_changes=metrics.state_changes,
- recent_failures_count=len(metrics.recent_failures))
-
- logger.info("Combined metrics structure",
- cache_metrics_keys=list(cache_metrics.keys()))
-
+
+ logger.info(
+ "Circuit breaker metrics",
+ state=metrics.state.value,
+ total_operations=metrics.total_operations,
+ success_count=metrics.success_count,
+ failure_count=metrics.failure_count,
+ failure_rate=metrics.failure_rate,
+ state_changes=metrics.state_changes,
+ recent_failures_count=len(metrics.recent_failures),
+ )
+
+ logger.info(
+ "Combined metrics structure", cache_metrics_keys=list(cache_metrics.keys())
+ )
+
logger.info("ā
Metrics test completed")
async def test_recovery_scenarios():
"""Test various recovery scenarios."""
logger.info("=== Testing Recovery Scenarios ===")
-
+
config = CircuitBreakerConfig(
failure_threshold=2,
success_threshold=3,
timeout_seconds=1.0,
gradual_recovery=True,
- recovery_factor=0.5
+ recovery_factor=0.5,
)
-
+
circuit_breaker = CacheCircuitBreaker(config)
failing_cache = FailingCacheManager()
resilient_cache = ResilientCacheWrapper(failing_cache, circuit_breaker)
-
+
# Test scenario 1: Complete failure and recovery
logger.info("Scenario 1: Complete failure and recovery")
-
+
# Cause circuit to open
failing_cache.set_failure_rate(1.0)
for i in range(3):
@@ -400,16 +422,16 @@ async def test_recovery_scenarios():
await resilient_cache.store(f"fail_{i}", "content")
except CacheError:
pass
-
+
assert circuit_breaker.get_state() == CircuitState.OPEN
logger.info("Circuit opened successfully")
-
+
# Wait for timeout
await asyncio.sleep(1.5)
-
+
# Fix cache and recover
failing_cache.set_failure_rate(0.0)
-
+
# Gradual recovery
recovery_attempts = 10
success_count = 0
@@ -420,32 +442,34 @@ async def test_recovery_scenarios():
except CacheError as e:
if "THROTTLING" in str(e.error_code):
logger.debug("Request throttled during recovery")
-
+
final_state = circuit_breaker.get_state()
- logger.info("Recovery scenario completed",
- final_state=final_state.value,
- recovery_success_rate=success_count / recovery_attempts)
-
+ logger.info(
+ "Recovery scenario completed",
+ final_state=final_state.value,
+ recovery_success_rate=success_count / recovery_attempts,
+ )
+
logger.info("ā
Recovery scenarios test completed")
async def run_all_tests():
"""Run all circuit breaker and resilience tests."""
logger.info("š§Ŗ Starting Cache Resilience Tests")
-
+
try:
await test_circuit_breaker_states()
await test_graceful_degradation()
await test_performance_under_failures()
await test_circuit_breaker_metrics()
await test_recovery_scenarios()
-
+
logger.info("š All cache resilience tests passed!")
-
+
except Exception as e:
logger.error("ā Test failed", error=str(e), exc_info=True)
raise
if __name__ == "__main__":
- asyncio.run(run_all_tests())
\ No newline at end of file
+ asyncio.run(run_all_tests())
diff --git a/scripts/test_fact_cache_integration.py b/scripts/test_fact_cache_integration.py
index 397eff7..75e1d8a 100644
--- a/scripts/test_fact_cache_integration.py
+++ b/scripts/test_fact_cache_integration.py
@@ -35,7 +35,9 @@
import structlog
except ImportError as e:
print(f"ā Import error: {e}")
- print("Make sure you're running this script from the project root and all dependencies are installed.")
+ print(
+ "Make sure you're running this script from the project root and all dependencies are installed."
+ )
sys.exit(1)
# Configure logging
@@ -43,7 +45,7 @@
processors=[
structlog.processors.TimeStamper(fmt="ISO"),
structlog.processors.add_log_level,
- structlog.processors.JSONRenderer()
+ structlog.processors.JSONRenderer(),
],
logger_factory=structlog.PrintLoggerFactory(),
wrapper_class=structlog.BoundLogger,
@@ -56,14 +58,14 @@
def setup_test_environment() -> str:
"""
Set up test environment variables and create a temporary database.
-
+
Returns:
Path to temporary database file
"""
# Create temporary database
- temp_db = tempfile.NamedTemporaryFile(suffix='.db', delete=False)
+ temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
temp_db.close()
-
+
# Set required environment variables for testing
test_env = {
"DATABASE_PATH": temp_db.name,
@@ -76,27 +78,32 @@ def setup_test_environment() -> str:
"LOG_LEVEL": "INFO",
"MAX_RETRIES": "2",
"REQUEST_TIMEOUT": "30",
- "SKIP_API_VALIDATION": "true" # Skip API validation for testing
+ "SKIP_API_VALIDATION": "true", # Skip API validation for testing
}
-
+
# Check if API keys are set, if not set test values
if not os.getenv("ANTHROPIC_API_KEY"):
test_env["ANTHROPIC_API_KEY"] = "test-key-for-integration-testing"
- logger.warning("Using test API key for Anthropic - some features may be limited")
-
+ logger.warning(
+ "Using test API key for Anthropic - some features may be limited"
+ )
+
if not os.getenv("ARCADE_API_KEY"):
test_env["ARCADE_API_KEY"] = "test-key-for-integration-testing"
logger.warning("Using test API key for Arcade - some features may be limited")
-
+
# Apply test environment
for key, value in test_env.items():
os.environ[key] = value
-
- logger.info("Test environment configured",
- database_path=temp_db.name,
- api_keys_configured=bool(os.getenv("ANTHROPIC_API_KEY")) and bool(os.getenv("ARCADE_API_KEY")),
- skip_api_validation=True)
-
+
+ logger.info(
+ "Test environment configured",
+ database_path=temp_db.name,
+ api_keys_configured=bool(os.getenv("ANTHROPIC_API_KEY"))
+ and bool(os.getenv("ARCADE_API_KEY")),
+ skip_api_validation=True,
+ )
+
return temp_db.name
@@ -113,103 +120,115 @@ def cleanup_test_environment(temp_db_path: str) -> None:
async def test_fact_driver_initialization():
"""Test FACT driver initialization with real components."""
logger.info("=== Testing FACT Driver Initialization ===")
-
+
try:
# Create configuration
config = Config()
logger.info("Configuration created successfully")
-
+
# Initialize FACT driver
driver = FACTDriver(config)
logger.info("FACT driver instance created")
-
+
# Test initialization
await driver.initialize()
logger.info("ā
FACT driver initialized successfully")
-
+
# Verify components are available
assert driver.config is not None, "Configuration should be available"
- assert hasattr(driver, 'cache_circuit_breaker'), "Circuit breaker should be available"
-
+ assert hasattr(
+ driver, "cache_circuit_breaker"
+ ), "Circuit breaker should be available"
+
# Test metrics collection
metrics = driver.get_metrics()
logger.info("Initial metrics collected", metrics=metrics)
-
+
assert "initialized" in metrics, "Metrics should include initialization status"
assert metrics["initialized"] == True, "Driver should report as initialized"
-
+
# Clean shutdown
await driver.shutdown()
logger.info("ā
FACT driver initialization test passed")
-
+
except ConfigurationError as e:
logger.error("ā Configuration error during initialization", error=str(e))
logger.info("š” This is expected if API keys are not properly configured")
raise
except Exception as e:
- logger.error("ā FACT driver initialization test failed", error=str(e), exc_info=True)
+ logger.error(
+ "ā FACT driver initialization test failed", error=str(e), exc_info=True
+ )
raise
async def test_cache_resilience_features():
"""Test cache resilience features without external API calls."""
logger.info("=== Testing Cache Resilience Features ===")
-
+
try:
config = Config()
driver = FACTDriver(config)
-
+
await driver.initialize()
logger.info("Driver initialized for cache resilience testing")
-
+
# Test 1: Circuit breaker manipulation
- if hasattr(driver, 'cache_circuit_breaker') and driver.cache_circuit_breaker:
+ if hasattr(driver, "cache_circuit_breaker") and driver.cache_circuit_breaker:
logger.info("Testing circuit breaker state changes")
-
+
# Test initial state
initial_state = driver.cache_circuit_breaker.get_state()
logger.info("Initial circuit breaker state", state=initial_state.value)
-
+
# Force open state
driver.cache_circuit_breaker.force_open()
open_state = driver.cache_circuit_breaker.get_state()
- assert open_state == CircuitState.OPEN, "Circuit breaker should be in OPEN state"
+ assert (
+ open_state == CircuitState.OPEN
+ ), "Circuit breaker should be in OPEN state"
logger.info("ā
Circuit breaker forced to OPEN state")
-
+
# Force closed state
driver.cache_circuit_breaker.force_closed()
closed_state = driver.cache_circuit_breaker.get_state()
- assert closed_state == CircuitState.CLOSED, "Circuit breaker should be in CLOSED state"
+ assert (
+ closed_state == CircuitState.CLOSED
+ ), "Circuit breaker should be in CLOSED state"
logger.info("ā
Circuit breaker forced to CLOSED state")
-
+
else:
- logger.warning("Circuit breaker not available - cache resilience may be limited")
-
+ logger.warning(
+ "Circuit breaker not available - cache resilience may be limited"
+ )
+
# Test 2: Cache degradation mode
logger.info("Testing cache degradation mode")
-
+
# Simulate cache degraded mode
- if hasattr(driver, '_cache_degraded'):
- original_degraded = getattr(driver, '_cache_degraded', False)
-
+ if hasattr(driver, "_cache_degraded"):
+ original_degraded = getattr(driver, "_cache_degraded", False)
+
# Set degraded mode
driver._cache_degraded = True
logger.info("Cache set to degraded mode")
-
+
# Check metrics reflect degraded state
metrics = driver.get_metrics()
if "cache_degraded" in metrics:
- assert metrics["cache_degraded"] == True, "Metrics should show cache degraded"
+ assert (
+ metrics["cache_degraded"] == True
+ ), "Metrics should show cache degraded"
logger.info("ā
Metrics correctly show cache degraded state")
-
+
# Restore original state
driver._cache_degraded = original_degraded
-
+
# Test 3: System metrics during different states
logger.info("Testing metrics collection")
-
+
final_metrics = driver.get_metrics()
-
+
# Verify essential metrics are present
essential_metrics = ["initialized", "total_queries"]
for metric in essential_metrics:
@@ -217,16 +236,17 @@ async def test_cache_resilience_features():
logger.info(f"ā
Metric '{metric}' available: {final_metrics[metric]}")
else:
logger.warning(f"ā ļø Metric '{metric}' not available")
-
+
# Log all cache and circuit breaker related metrics
- cache_metrics = {k: v for k, v in final_metrics.items()
- if k.startswith(('cache', 'circuit'))}
+ cache_metrics = {
+ k: v for k, v in final_metrics.items() if k.startswith(("cache", "circuit"))
+ }
if cache_metrics:
logger.info("Cache and circuit breaker metrics", metrics=cache_metrics)
-
+
await driver.shutdown()
logger.info("ā
Cache resilience features test passed")
-
+
except Exception as e:
logger.error("ā Cache resilience test failed", error=str(e), exc_info=True)
raise
@@ -235,21 +255,21 @@ async def test_cache_resilience_features():
async def test_database_integration():
"""Test database integration and connection handling."""
logger.info("=== Testing Database Integration ===")
-
+
try:
config = Config()
driver = FACTDriver(config)
-
+
# Test database initialization
await driver.initialize()
logger.info("Driver initialized for database testing")
-
+
# Verify database file exists
db_path = config.database_path
if db_path != ":memory:":
assert os.path.exists(db_path), f"Database file should exist at {db_path}"
logger.info("ā
Database file created", path=db_path)
-
+
# Test basic database connectivity
try:
conn = sqlite3.connect(db_path)
@@ -257,22 +277,26 @@ async def test_database_integration():
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
conn.close()
-
- logger.info("ā
Database connectivity verified", table_count=len(tables))
-
+
+ logger.info(
+ "ā
Database connectivity verified", table_count=len(tables)
+ )
+
except Exception as e:
logger.warning("Database connectivity test failed", error=str(e))
else:
logger.info("ā
Using in-memory database")
-
+
# Test metrics include database information
metrics = driver.get_metrics()
- logger.info("Database-related metrics collected",
- has_db_metrics=any(k.startswith('db') for k in metrics.keys()))
-
+ logger.info(
+ "Database-related metrics collected",
+ has_db_metrics=any(k.startswith("db") for k in metrics.keys()),
+ )
+
await driver.shutdown()
logger.info("ā
Database integration test passed")
-
+
except Exception as e:
logger.error("ā Database integration test failed", error=str(e), exc_info=True)
raise
@@ -281,90 +305,104 @@ async def test_database_integration():
async def test_performance_monitoring():
"""Test performance monitoring capabilities."""
logger.info("=== Testing Performance Monitoring ===")
-
+
try:
config = Config()
driver = FACTDriver(config)
-
+
await driver.initialize()
logger.info("Driver initialized for performance testing")
-
+
# Test timing measurements
start_time = time.time()
-
+
# Simulate some operations by getting metrics multiple times
for i in range(5):
metrics = driver.get_metrics()
await asyncio.sleep(0.01) # Small delay to simulate work
-
+
elapsed_time = time.time() - start_time
- logger.info("Performance test completed",
- operations=5,
- total_time_ms=elapsed_time * 1000,
- avg_time_ms=(elapsed_time / 5) * 1000)
-
+ logger.info(
+ "Performance test completed",
+ operations=5,
+ total_time_ms=elapsed_time * 1000,
+ avg_time_ms=(elapsed_time / 5) * 1000,
+ )
+
# Verify metrics are collected efficiently
final_metrics = driver.get_metrics()
- performance_keys = [k for k in final_metrics.keys()
- if 'time' in k.lower() or 'latency' in k.lower() or 'performance' in k.lower()]
-
+ performance_keys = [
+ k
+ for k in final_metrics.keys()
+ if "time" in k.lower()
+ or "latency" in k.lower()
+ or "performance" in k.lower()
+ ]
+
if performance_keys:
logger.info("ā
Performance metrics available", keys=performance_keys)
else:
logger.info("ā¹ļø No explicit performance metrics found")
-
+
await driver.shutdown()
logger.info("ā
Performance monitoring test passed")
-
+
except Exception as e:
- logger.error("ā Performance monitoring test failed", error=str(e), exc_info=True)
+ logger.error(
+ "ā Performance monitoring test failed", error=str(e), exc_info=True
+ )
raise
async def test_error_handling():
"""Test error handling and recovery mechanisms."""
logger.info("=== Testing Error Handling ===")
-
+
try:
# Test with invalid configuration
logger.info("Testing configuration error handling")
-
+
# Save original environment
original_db_path = os.getenv("DATABASE_PATH")
-
+
# Set invalid database path to test error handling
os.environ["DATABASE_PATH"] = "/invalid/path/that/does/not/exist.db"
-
+
try:
config = Config()
driver = FACTDriver(config)
await driver.initialize()
-
+
# If we get here, the system handled the invalid path gracefully
logger.info("ā
System handled invalid database path gracefully")
await driver.shutdown()
-
+
except Exception as e:
# This is expected for truly invalid paths
- logger.info("ā
System correctly raised error for invalid database path", error=str(e))
-
+ logger.info(
+ "ā
System correctly raised error for invalid database path",
+ error=str(e),
+ )
+
finally:
# Restore original database path
if original_db_path:
os.environ["DATABASE_PATH"] = original_db_path
-
+
# Test normal operation after error recovery
logger.info("Testing recovery after error")
config = Config()
driver = FACTDriver(config)
await driver.initialize()
-
+
metrics = driver.get_metrics()
- assert metrics["initialized"] == True, "Driver should initialize successfully after error recovery"
-
+ assert (
+ metrics["initialized"] == True
+ ), "Driver should initialize successfully after error recovery"
+
await driver.shutdown()
logger.info("ā
Error handling and recovery test passed")
-
+
except Exception as e:
logger.error("ā Error handling test failed", error=str(e), exc_info=True)
raise
@@ -373,33 +411,35 @@ async def test_error_handling():
async def run_all_integration_tests():
"""Run all real integration tests without mocking."""
logger.info("š§Ŗ Starting FACT Real Integration Tests (No Mocking)")
-
+
temp_db_path = None
-
+
try:
# Set up test environment
temp_db_path = setup_test_environment()
logger.info("Test environment set up successfully")
-
+
# Run all test suites
await test_fact_driver_initialization()
await test_cache_resilience_features()
await test_database_integration()
await test_performance_monitoring()
await test_error_handling()
-
+
logger.info("š All FACT real integration tests passed!")
-
+
except ConfigurationError as e:
logger.error("ā Configuration error in integration tests", error=str(e))
- logger.info("š” Ensure API keys are properly configured in environment variables")
+ logger.info(
+ "š” Ensure API keys are properly configured in environment variables"
+ )
logger.info("š” For testing, you can use placeholder values like 'test-key'")
raise
-
+
except Exception as e:
logger.error("ā Integration tests failed", error=str(e), exc_info=True)
raise
-
+
finally:
# Clean up test environment
if temp_db_path:
@@ -414,4 +454,4 @@ async def run_all_integration_tests():
sys.exit(1)
except Exception as e:
logger.error("Integration tests failed to run", error=str(e))
- sys.exit(1)
\ No newline at end of file
+ sys.exit(1)
diff --git a/scripts/test_response_padding.py b/scripts/test_response_padding.py
index 712e17d..0332665 100644
--- a/scripts/test_response_padding.py
+++ b/scripts/test_response_padding.py
@@ -18,7 +18,7 @@
pad_response_for_caching,
enhance_sql_tool_response,
validate_enhanced_response,
- _estimate_tokens
+ _estimate_tokens,
)
from cache.manager import CacheManager
from cache.config import get_default_cache_config
@@ -26,9 +26,9 @@
def test_sql_response_padding():
"""Test SQL response padding functionality."""
-
+
print("=== FACT Response Padding Test ===\n")
-
+
# Simulate typical SQL tool responses that are too small for caching
test_responses = {
"simple_select": """
@@ -42,7 +42,6 @@ def test_sql_response_padding():
Results: 8 rows returned
Execution time: 0.023 seconds
""".strip(),
-
"insert_operation": """
INSERT INTO products (name, price, category_id, description)
VALUES
@@ -53,7 +52,6 @@ def test_sql_response_padding():
Query executed successfully.
3 rows inserted.
""".strip(),
-
"table_creation": """
CREATE TABLE analytics_events (
id SERIAL PRIMARY KEY,
@@ -66,106 +64,106 @@ def test_sql_response_padding():
);
Table created successfully.
-""".strip()
+""".strip(),
}
-
+
print("1. Testing original response token counts:")
for name, response in test_responses.items():
tokens = _estimate_tokens(response)
print(f" {name}: {tokens} tokens")
-
+
print(f"\n2. Testing response padding (target: 500 tokens):")
enhanced_responses = {}
-
+
for name, response in test_responses.items():
print(f"\n Processing {name}...")
-
+
try:
# Test basic padding
enhanced = pad_response_for_caching(
- content=response,
- content_type="sql",
- target_tokens=500
+ content=response, content_type="sql", target_tokens=500
)
-
+
enhanced_responses[name] = enhanced
-
+
# Validate enhancement
validation = validate_enhanced_response(response, enhanced, 500)
-
+
print(f" ā
Original: {validation['original_tokens']} tokens")
print(f" ā
Enhanced: {validation['enhanced_tokens']} tokens")
print(f" ā
Meets requirement: {validation['meets_requirement']}")
print(f" ā
Original preserved: {validation['original_preserved']}")
-
+
except Exception as e:
print(f" ā Failed to enhance {name}: {e}")
-
+
print(f"\n3. Testing SQL-specific enhancement function:")
-
+
# Test with query context
context = {
"query_type": "SELECT",
"tables_accessed": ["users", "profiles"],
"execution_time_ms": 23,
- "rows_returned": 8
+ "rows_returned": 8,
}
-
+
enhanced_with_context = enhance_sql_tool_response(
sql_response=test_responses["simple_select"],
query_context=context,
- min_tokens=500
+ min_tokens=500,
)
-
+
context_validation = validate_enhanced_response(
- test_responses["simple_select"],
- enhanced_with_context,
- 500
+ test_responses["simple_select"], enhanced_with_context, 500
+ )
+
+ print(
+ f" ā
Enhanced with context: {context_validation['enhanced_tokens']} tokens"
)
-
- print(f" ā
Enhanced with context: {context_validation['enhanced_tokens']} tokens")
print(f" ā
Context preserved: {'query_type' in enhanced_with_context}")
-
+
print(f"\n4. Testing cache integration:")
-
+
# Test that enhanced responses can be cached
config = get_default_cache_config()
cache_manager = CacheManager(config)
-
+
successful_caches = 0
-
+
for name, enhanced_response in enhanced_responses.items():
try:
query_hash = cache_manager.generate_hash(f"test_query_{name}")
entry = cache_manager.store(query_hash, enhanced_response)
-
+
print(f" ā
Cached {name}: {entry.token_count} tokens")
successful_caches += 1
-
+
# Verify retrieval
retrieved = cache_manager.get(query_hash)
if retrieved and retrieved.content == enhanced_response:
print(f" ā
Retrieved {name} successfully")
else:
print(f" ā Failed to retrieve {name}")
-
+
except Exception as e:
print(f" ā Failed to cache {name}: {e}")
-
+
print(f"\n5. Performance and quality metrics:")
print(f" ā
Responses enhanced: {len(enhanced_responses)}/{len(test_responses)}")
print(f" ā
Successfully cached: {successful_caches}/{len(enhanced_responses)}")
-
+
# Calculate average enhancement ratio
total_ratio = 0
for name, response in test_responses.items():
if name in enhanced_responses:
- validation = validate_enhanced_response(response, enhanced_responses[name], 500)
- total_ratio += validation['enhancement_ratio']
-
+ validation = validate_enhanced_response(
+ response, enhanced_responses[name], 500
+ )
+ total_ratio += validation["enhancement_ratio"]
+
avg_ratio = total_ratio / len(enhanced_responses) if enhanced_responses else 0
print(f" ā
Average enhancement ratio: {avg_ratio:.2f}x")
-
+
print(f"\n=== Test Summary ===")
print("ā
SQL response padding successfully implemented")
print("ā
Small responses (320-368 tokens) enhanced to meet 500+ token requirement")
@@ -173,55 +171,57 @@ def test_sql_response_padding():
print("ā
Enhanced responses successfully cached and retrieved")
print("ā
Context-aware enhancement provides additional value")
print("ā
Performance metrics within acceptable ranges")
-
+
return True
def test_edge_cases():
"""Test edge cases and error handling."""
-
+
print("\n=== Edge Case Testing ===\n")
-
+
# Test empty content
try:
pad_response_for_caching("", "sql", 500)
print("ā Should have failed on empty content")
except ValueError:
print("ā
Correctly rejected empty content")
-
+
# Test invalid target tokens
try:
pad_response_for_caching("test", "sql", 50)
print("ā Should have failed on low target tokens")
except ValueError:
print("ā
Correctly rejected invalid target tokens")
-
+
# Test content that already meets requirements
- long_content = "This is a long response with detailed information. " * 120 # ~600+ tokens
+ long_content = (
+ "This is a long response with detailed information. " * 120
+ ) # ~600+ tokens
result = pad_response_for_caching(long_content, "sql", 500)
-
+
if result == long_content:
print("ā
Correctly returned unmodified content when requirements already met")
else:
print("ā Unnecessarily modified content that already met requirements")
-
+
# Test different content types
json_content = '{"result": "success", "data": [1, 2, 3]}'
json_enhanced = pad_response_for_caching(json_content, "json", 500)
-
+
if _estimate_tokens(json_enhanced) >= 500:
print("ā
JSON content type padding works correctly")
else:
print("ā JSON content type padding failed")
-
+
print("ā
Edge case testing completed")
def demonstrate_caching_improvement():
"""Demonstrate the caching improvement for SQL responses."""
-
+
print("\n=== Caching Improvement Demonstration ===\n")
-
+
# Simulate the original problem: SQL response too small for caching
original_sql_response = """
SELECT COUNT(*) as total_users,
@@ -233,52 +233,52 @@ def demonstrate_caching_improvement():
Result: total_users=1247, avg_age=32.4, last_signup=2024-01-15
Execution time: 0.045 seconds
""".strip()
-
+
print("Before enhancement:")
original_tokens = _estimate_tokens(original_sql_response)
print(f" Token count: {original_tokens}")
print(f" Meets caching requirement (500+): {original_tokens >= 500}")
-
+
# Try to cache original response
config = get_default_cache_config()
cache_manager = CacheManager(config)
-
+
try:
query_hash = cache_manager.generate_hash("demo_query")
cache_manager.store(query_hash, original_sql_response)
print(" ❌ Original response should not have been cacheable")
except Exception:
print(" ✅ Original response correctly rejected for caching")
-
+
print("\nAfter enhancement:")
enhanced_response = enhance_sql_tool_response(
sql_response=original_sql_response,
query_context={
"operation": "aggregation",
"performance": "optimized",
- "tables": ["users"]
+ "tables": ["users"],
},
- min_tokens=500
+ min_tokens=500,
)
-
+
enhanced_tokens = _estimate_tokens(enhanced_response)
print(f" Token count: {enhanced_tokens}")
print(f" Meets caching requirement (500+): {enhanced_tokens >= 500}")
-
+
try:
query_hash = cache_manager.generate_hash("demo_query_enhanced")
entry = cache_manager.store(query_hash, enhanced_response)
print(f" ✅ Enhanced response successfully cached")
-
+
# Verify retrieval
retrieved = cache_manager.get(query_hash)
if retrieved:
print(f" ✅ Enhanced response successfully retrieved from cache")
print(f" ✅ Cache hit preserves enhanced content and context")
-
+
except Exception as e:
print(f" ❌ Failed to cache enhanced response: {e}")
-
+
print(f"\nImprovement summary:")
print(f" Token increase: {enhanced_tokens - original_tokens} tokens")
print(f" Enhancement ratio: {enhanced_tokens / original_tokens:.2f}x")
@@ -288,15 +288,15 @@ def demonstrate_caching_improvement():
if __name__ == "__main__":
print("Testing FACT system response padding utilities...\n")
-
+
try:
test_sql_response_padding()
test_edge_cases()
demonstrate_caching_improvement()
-
+
print(f"\nš All tests completed successfully!")
print(f"Response padding is ready for production use.")
-
+
except Exception as e:
print(f"\n❌ Test failed with error: {e}")
- sys.exit(1)
\ No newline at end of file
+ sys.exit(1)
diff --git a/scripts/validate_complete_system.py b/scripts/validate_complete_system.py
index 6f55e52..9cdff32 100755
--- a/scripts/validate_complete_system.py
+++ b/scripts/validate_complete_system.py
@@ -25,7 +25,7 @@
from datetime import datetime
# Add src to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from core.config import Config, ConfigurationError
from core.driver import get_driver
@@ -39,8 +39,15 @@
class ValidationResult:
"""Represents the result of a validation check."""
-
- def __init__(self, component: str, test: str, success: bool, message: str, details: Dict[str, Any] = None):
+
+ def __init__(
+ self,
+ component: str,
+ test: str,
+ success: bool,
+ message: str,
+ details: Dict[str, Any] = None,
+ ):
self.component = component
self.test = test
self.success = success
@@ -51,49 +58,55 @@ def __init__(self, component: str, test: str, success: bool, message: str, detai
class SystemValidator:
"""Comprehensive FACT system validator."""
-
+
def __init__(self, verbose: bool = False, fix_issues: bool = False):
self.verbose = verbose
self.fix_issues = fix_issues
self.results: List[ValidationResult] = []
self.setup_logging()
-
+
def setup_logging(self):
"""Set up logging configuration."""
level = logging.DEBUG if self.verbose else logging.INFO
logging.basicConfig(
- level=level,
- format='%(asctime)s - %(levelname)s - %(message)s'
+ level=level, format="%(asctime)s - %(levelname)s - %(message)s"
)
self.logger = logging.getLogger(__name__)
-
- def add_result(self, component: str, test: str, success: bool, message: str, details: Dict[str, Any] = None):
+
+ def add_result(
+ self,
+ component: str,
+ test: str,
+ success: bool,
+ message: str,
+ details: Dict[str, Any] = None,
+ ):
"""Add a validation result."""
result = ValidationResult(component, test, success, message, details)
self.results.append(result)
-
+
status = "✅" if success else "❌"
print(f"{status} {component}: {test} - {message}")
-
+
if self.verbose and details:
for key, value in details.items():
print(f" {key}: {value}")
-
+
def validate_environment_configuration(self) -> bool:
"""Validate environment configuration."""
print("\n🔧 Validating Environment Configuration...")
all_passed = True
-
+
try:
# Test .env file existence
- env_path = Path('.env')
+ env_path = Path(".env")
if env_path.exists():
self.add_result(
- "Environment",
- ".env file",
- True,
+ "Environment",
+ ".env file",
+ True,
f"Found at {env_path.absolute()}",
- {"file_size": env_path.stat().st_size}
+ {"file_size": env_path.stat().st_size},
)
else:
self.add_result("Environment", ".env file", False, "Missing .env file")
@@ -101,256 +114,337 @@ def validate_environment_configuration(self) -> bool:
print(" 🔧 Creating .env file with defaults...")
self.create_default_env_file()
all_passed = False
-
+
# Test configuration loading
config = Config()
self.add_result(
- "Environment",
- "Configuration loading",
- True,
+ "Environment",
+ "Configuration loading",
+ True,
"Configuration loaded successfully",
- config.to_dict()
+ config.to_dict(),
)
-
+
# Validate API keys
- if config.anthropic_api_key and not config.anthropic_api_key.startswith('your_'):
- self.add_result("Environment", "Anthropic API key", True, "Valid API key configured")
+ if config.anthropic_api_key and not config.anthropic_api_key.startswith(
+ "your_"
+ ):
+ self.add_result(
+ "Environment", "Anthropic API key", True, "Valid API key configured"
+ )
else:
- self.add_result("Environment", "Anthropic API key", False, "API key not configured or placeholder")
+ self.add_result(
+ "Environment",
+ "Anthropic API key",
+ False,
+ "API key not configured or placeholder",
+ )
all_passed = False
-
- if config.arcade_api_key and not config.arcade_api_key.startswith('your_'):
- self.add_result("Environment", "Arcade API key", True, "Valid API key configured")
+
+ if config.arcade_api_key and not config.arcade_api_key.startswith("your_"):
+ self.add_result(
+ "Environment", "Arcade API key", True, "Valid API key configured"
+ )
else:
- self.add_result("Environment", "Arcade API key", False, "API key not configured or placeholder")
+ self.add_result(
+ "Environment",
+ "Arcade API key",
+ False,
+ "API key not configured or placeholder",
+ )
all_passed = False
-
+
return all_passed
-
+
except ConfigurationError as e:
self.add_result("Environment", "Configuration loading", False, str(e))
return False
except Exception as e:
- self.add_result("Environment", "Configuration loading", False, f"Unexpected error: {e}")
+ self.add_result(
+ "Environment", "Configuration loading", False, f"Unexpected error: {e}"
+ )
return False
-
+
async def validate_database_integration(self) -> bool:
"""Validate database integration."""
print("\n🗄️ Validating Database Integration...")
all_passed = True
-
+
try:
config = Config()
db_manager = DatabaseManager(config.database_path)
-
+
# Test database initialization
await db_manager.initialize_database()
- self.add_result("Database", "Initialization", True, "Database initialized successfully")
-
+ self.add_result(
+ "Database", "Initialization", True, "Database initialized successfully"
+ )
+
# Test database info retrieval
db_info = await db_manager.get_database_info()
self.add_result(
- "Database",
- "Schema validation",
- True,
+ "Database",
+ "Schema validation",
+ True,
f"Found {db_info['total_tables']} tables",
- db_info
+ db_info,
)
-
+
# Validate required tables
required_tables = ["companies", "financial_data", "benchmarks"]
for table in required_tables:
- if table in db_info['tables']:
- row_count = db_info['tables'][table]['row_count']
+ if table in db_info["tables"]:
+ row_count = db_info["tables"][table]["row_count"]
self.add_result(
- "Database",
- f"Table {table}",
- True,
+ "Database",
+ f"Table {table}",
+ True,
f"{row_count} rows",
- {"row_count": row_count}
+ {"row_count": row_count},
)
else:
- self.add_result("Database", f"Table {table}", False, "Table missing")
+ self.add_result(
+ "Database", f"Table {table}", False, "Table missing"
+ )
all_passed = False
-
+
return all_passed
-
+
except Exception as e:
self.add_result("Database", "Integration", False, f"Database error: {e}")
return False
-
+
async def validate_api_connectivity(self) -> bool:
"""Validate API connectivity."""
print("\nš Validating API Connectivity...")
all_passed = True
-
+
try:
config = Config()
-
+
# Skip API tests if keys are not configured
- if (not config.anthropic_api_key or config.anthropic_api_key.startswith('your_') or
- not config.arcade_api_key or config.arcade_api_key.startswith('your_')):
- self.add_result("API", "Connectivity test", False, "API keys not configured - skipping connectivity tests")
+ if (
+ not config.anthropic_api_key
+ or config.anthropic_api_key.startswith("your_")
+ or not config.arcade_api_key
+ or config.arcade_api_key.startswith("your_")
+ ):
+ self.add_result(
+ "API",
+ "Connectivity test",
+ False,
+ "API keys not configured - skipping connectivity tests",
+ )
return False
-
+
# Test driver initialization (includes API connectivity)
try:
driver = await get_driver()
- self.add_result("API", "Driver initialization", True, "Driver initialized successfully")
-
+ self.add_result(
+ "API",
+ "Driver initialization",
+ True,
+ "Driver initialized successfully",
+ )
+
# Test basic functionality
metrics = driver.get_metrics()
self.add_result(
- "API",
- "System metrics",
- True,
+ "API",
+ "System metrics",
+ True,
f"System operational - {len(driver.tool_registry.list_tools())} tools available",
- metrics
+ metrics,
)
-
+
await driver.shutdown()
return True
-
+
except Exception as e:
- self.add_result("API", "Driver initialization", False, f"Driver error: {e}")
+ self.add_result(
+ "API", "Driver initialization", False, f"Driver error: {e}"
+ )
return False
-
+
except Exception as e:
self.add_result("API", "Connectivity", False, f"API error: {e}")
return False
-
+
async def validate_cache_system(self) -> bool:
"""Validate cache system."""
print("\n💾 Validating Cache System...")
all_passed = True
-
+
try:
config = Config()
cache_config = config.cache_config
-
+
self.add_result(
- "Cache",
- "Configuration",
- True,
+ "Cache",
+ "Configuration",
+ True,
"Cache configuration loaded",
- cache_config
+ cache_config,
)
-
+
# Test cache validator
validator = CacheValidator()
- self.add_result("Cache", "Validator initialization", True, "Cache validator initialized")
-
+ self.add_result(
+ "Cache", "Validator initialization", True, "Cache validator initialized"
+ )
+
# Note: Actual cache performance testing requires live system
# This validates the cache system is properly configured
-
+
return all_passed
-
+
except Exception as e:
self.add_result("Cache", "System", False, f"Cache error: {e}")
return False
-
+
def validate_security_layer(self) -> bool:
"""Validate security layer."""
print("\n🛡️ Validating Security Layer...")
all_passed = True
-
+
try:
# Test authorization manager initialization
auth_manager = AuthorizationManager()
- self.add_result("Security", "Authorization manager", True, "Authorization manager initialized")
-
+ self.add_result(
+ "Security",
+ "Authorization manager",
+ True,
+ "Authorization manager initialized",
+ )
+
# Test input sanitizer initialization
input_sanitizer = InputSanitizer()
- self.add_result("Security", "Input sanitizer", True, "Input sanitizer initialized")
-
+ self.add_result(
+ "Security", "Input sanitizer", True, "Input sanitizer initialized"
+ )
+
# Test input sanitization
test_inputs = [
("Normal query", "What is TechCorp's revenue?", True),
("XSS attempt", "", False),
("SQL injection", "'; DROP TABLE users; --", False),
- ("Valid financial query", "Show me Q1 2025 financial data", True)
+ ("Valid financial query", "Show me Q1 2025 financial data", True),
]
-
+
for test_name, test_input, should_pass in test_inputs:
try:
result = input_sanitizer.sanitize_string(test_input)
if should_pass:
- self.add_result("Security", f"Input validation: {test_name}", True, "Input accepted")
+ self.add_result(
+ "Security",
+ f"Input validation: {test_name}",
+ True,
+ "Input accepted",
+ )
else:
- self.add_result("Security", f"Input validation: {test_name}", False, "Malicious input should be blocked")
+ self.add_result(
+ "Security",
+ f"Input validation: {test_name}",
+ False,
+ "Malicious input should be blocked",
+ )
all_passed = False
except Exception:
if not should_pass:
- self.add_result("Security", f"Input validation: {test_name}", True, "Malicious input blocked")
+ self.add_result(
+ "Security",
+ f"Input validation: {test_name}",
+ True,
+ "Malicious input blocked",
+ )
else:
- self.add_result("Security", f"Input validation: {test_name}", False, "Valid input rejected")
+ self.add_result(
+ "Security",
+ f"Input validation: {test_name}",
+ False,
+ "Valid input rejected",
+ )
all_passed = False
-
+
return all_passed
-
+
except Exception as e:
self.add_result("Security", "Layer", False, f"Security error: {e}")
return False
-
+
async def validate_benchmark_framework(self) -> bool:
"""Validate benchmark framework."""
print("\nš Validating Benchmark Framework...")
all_passed = True
-
+
try:
# Test framework initialization
framework = BenchmarkFramework()
- self.add_result("Benchmark", "Framework initialization", True, "Framework initialized")
-
+ self.add_result(
+ "Benchmark", "Framework initialization", True, "Framework initialized"
+ )
+
# Test configuration
config_details = {
"iterations": framework.config.iterations,
"warmup_iterations": framework.config.warmup_iterations,
"concurrent_users": framework.config.concurrent_users,
- "timeout_seconds": framework.config.timeout_seconds
+ "timeout_seconds": framework.config.timeout_seconds,
}
-
+
self.add_result(
- "Benchmark",
- "Configuration",
- True,
+ "Benchmark",
+ "Configuration",
+ True,
"Benchmark configuration loaded",
- config_details
+ config_details,
)
-
+
# Test empty benchmark suite (should handle gracefully)
summary = await framework.run_benchmark_suite([])
if summary.total_queries == 0:
- self.add_result("Benchmark", "Empty suite handling", True, "Empty benchmark suite handled correctly")
+ self.add_result(
+ "Benchmark",
+ "Empty suite handling",
+ True,
+ "Empty benchmark suite handled correctly",
+ )
else:
- self.add_result("Benchmark", "Empty suite handling", False, "Empty benchmark suite not handled correctly")
+ self.add_result(
+ "Benchmark",
+ "Empty suite handling",
+ False,
+ "Empty benchmark suite not handled correctly",
+ )
all_passed = False
-
+
return all_passed
-
+
except Exception as e:
self.add_result("Benchmark", "Framework", False, f"Benchmark error: {e}")
return False
-
+
def validate_monitoring_system(self) -> bool:
"""Validate monitoring system."""
print("\nš Validating Monitoring System...")
all_passed = True
-
+
try:
# Test metrics collector
metrics_collector = MetricsCollector()
- self.add_result("Monitoring", "Metrics collector", True, "Metrics collector initialized")
-
+ self.add_result(
+ "Monitoring", "Metrics collector", True, "Metrics collector initialized"
+ )
+
return all_passed
-
+
except Exception as e:
self.add_result("Monitoring", "System", False, f"Monitoring error: {e}")
return False
-
+
def create_default_env_file(self):
"""Create a default .env file."""
- env_content = '''# FACT System Configuration
+ env_content = """# FACT System Configuration
# Update the API keys below with your actual credentials
# Required API Keys - MUST be configured for system to work
@@ -374,80 +468,84 @@ def create_default_env_file(self):
CACHE_HIT_TARGET_MS=30
CACHE_MISS_TARGET_MS=120
MAX_CONCURRENT_QUERIES=50
-'''
-
- with open('.env', 'w') as f:
+"""
+
+ with open(".env", "w") as f:
f.write(env_content)
print(" š Created .env file with default configuration")
-
+
def print_summary(self):
"""Print validation summary."""
- print("\n" + "="*60)
+ print("\n" + "=" * 60)
print("🎯 FACT SYSTEM VALIDATION SUMMARY")
- print("="*60)
-
+ print("=" * 60)
+
# Count results by component
component_results = {}
for result in self.results:
if result.component not in component_results:
component_results[result.component] = {"passed": 0, "failed": 0}
-
+
if result.success:
component_results[result.component]["passed"] += 1
else:
component_results[result.component]["failed"] += 1
-
+
# Print component summary
total_passed = 0
total_failed = 0
-
+
for component, counts in component_results.items():
passed = counts["passed"]
failed = counts["failed"]
total = passed + failed
-
+
status = "✅" if failed == 0 else "❌"
print(f"{status} {component}: {passed}/{total} tests passed")
-
+
total_passed += passed
total_failed += failed
-
+
print("-" * 60)
-
+
overall_total = total_passed + total_failed
success_rate = (total_passed / overall_total * 100) if overall_total > 0 else 0
-
+
if total_failed == 0:
- print(f"š ALL SYSTEMS OPERATIONAL: {total_passed}/{overall_total} tests passed ({success_rate:.1f}%)")
+ print(
+ f"š ALL SYSTEMS OPERATIONAL: {total_passed}/{overall_total} tests passed ({success_rate:.1f}%)"
+ )
print("\n✅ Your FACT system is fully integrated and ready for use!")
print("\nNext steps:")
print("1. Start the CLI: python main.py cli")
print("2. Run benchmarks: python scripts/run_benchmarks.py")
print("3. Try sample queries in the interactive CLI")
else:
- print(f"⚠️ ISSUES FOUND: {total_passed}/{overall_total} tests passed ({success_rate:.1f}%)")
+ print(
+ f"⚠️ ISSUES FOUND: {total_passed}/{overall_total} tests passed ({success_rate:.1f}%)"
+ )
print(f"\n❌ {total_failed} issue(s) need attention:")
-
+
for result in self.results:
if not result.success:
print(f" • {result.component}: {result.test} - {result.message}")
-
+
print("\n💡 Recommendations:")
if any("API key" in r.test for r in self.results if not r.success):
print(" 1. Update API keys in .env file with actual credentials")
print(" 2. Get Anthropic API key: https://console.anthropic.com/")
print(" 3. Get Arcade API key: https://www.arcade-ai.com/")
-
+
if any("Database" in r.component for r in self.results if not r.success):
print(" 4. Run: python main.py init")
-
+
print(" 5. Re-run validation: python scripts/validate_complete_system.py")
-
+
async def run_all_validations(self) -> bool:
"""Run all validation checks."""
print("š FACT System Complete Validation")
- print("="*50)
-
+ print("=" * 50)
+
validations = [
("Environment Configuration", self.validate_environment_configuration()),
("Database Integration", self.validate_database_integration()),
@@ -455,25 +553,25 @@ async def run_all_validations(self) -> bool:
("Cache System", self.validate_cache_system()),
("Security Layer", self.validate_security_layer()),
("Benchmark Framework", self.validate_benchmark_framework()),
- ("Monitoring System", self.validate_monitoring_system())
+ ("Monitoring System", self.validate_monitoring_system()),
]
-
+
all_passed = True
-
+
for name, validation in validations:
try:
if asyncio.iscoroutine(validation):
result = await validation
else:
result = validation
-
+
if not result:
all_passed = False
-
+
except Exception as e:
self.add_result("System", name, False, f"Validation failed: {e}")
all_passed = False
-
+
self.print_summary()
return all_passed
@@ -482,29 +580,30 @@ async def main():
"""Main validation routine."""
parser = argparse.ArgumentParser(
description="FACT System Complete Validation",
- formatter_class=argparse.RawDescriptionHelpFormatter
+ formatter_class=argparse.RawDescriptionHelpFormatter,
)
-
+
parser.add_argument(
- "--verbose", "-v",
+ "--verbose",
+ "-v",
action="store_true",
- help="Enable verbose output with detailed information"
+ help="Enable verbose output with detailed information",
)
-
+
parser.add_argument(
"--fix-issues",
- action="store_true",
- help="Attempt to automatically fix common issues"
+ action="store_true",
+ help="Attempt to automatically fix common issues",
)
-
+
args = parser.parse_args()
-
+
validator = SystemValidator(verbose=args.verbose, fix_issues=args.fix_issues)
success = await validator.run_all_validations()
-
+
# Exit with appropriate code
sys.exit(0 if success else 1)
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/scripts/validate_env.py b/scripts/validate_env.py
index 90ac091..e620571 100755
--- a/scripts/validate_env.py
+++ b/scripts/validate_env.py
@@ -19,336 +19,401 @@
from typing import Dict, List, Tuple, Optional, Any
# Add src to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
try:
from dotenv import load_dotenv
except ImportError:
- print("Warning: python-dotenv not installed. Install with: pip install python-dotenv")
+ print(
+ "Warning: python-dotenv not installed. Install with: pip install python-dotenv"
+ )
load_dotenv = None
class ConfigurationValidator:
"""Validates FACT system configuration against requirements."""
-
+
# API Key validation patterns from specification
API_KEY_PATTERNS = {
- 'ANTHROPIC_API_KEY': r'^sk-ant-api03-[A-Za-z0-9_-]+$',
- 'ARCADE_API_KEY': r'^arc_[A-Za-z0-9_-]+$',
- 'OPENAI_API_KEY': r'^sk-proj-[A-Za-z0-9_-]+$'
+ "ANTHROPIC_API_KEY": r"^sk-ant-api03-[A-Za-z0-9_-]+$",
+ "ARCADE_API_KEY": r"^arc_[A-Za-z0-9_-]+$",
+ "OPENAI_API_KEY": r"^sk-proj-[A-Za-z0-9_-]+$",
}
-
+
# Required configuration keys
- REQUIRED_KEYS = ['ANTHROPIC_API_KEY', 'ARCADE_API_KEY']
-
+ REQUIRED_KEYS = ["ANTHROPIC_API_KEY", "ARCADE_API_KEY"]
+
# Parameter validation rules
PARAMETER_RULES = {
- 'MAX_RETRIES': {'type': int, 'min': 1, 'max': 10, 'default': 3},
- 'REQUEST_TIMEOUT': {'type': int, 'min': 5, 'max': 300, 'default': 30},
- 'CACHE_TTL_SECONDS': {'type': int, 'min': 60, 'max': 86400, 'default': 3600},
- 'CACHE_MIN_TOKENS': {'type': int, 'min': 1, 'max': 1000, 'default': 50},
- 'VALIDATION_MAX_STRING_LENGTH': {'type': int, 'min': 1, 'max': 100000, 'default': 10000},
- 'LOG_LEVEL': {'type': str, 'choices': ['DEBUG', 'INFO', 'WARNING', 'ERROR'], 'default': 'INFO'},
- 'CLAUDE_MODEL': {'type': str, 'pattern': r'^claude-3-', 'default': 'claude-3-5-sonnet-20241022'},
- 'CACHE_PREFIX': {'type': str, 'pattern': r'^[A-Za-z0-9_]+$', 'min_length': 1, 'max_length': 50, 'default': 'fact_v1'}
+ "MAX_RETRIES": {"type": int, "min": 1, "max": 10, "default": 3},
+ "REQUEST_TIMEOUT": {"type": int, "min": 5, "max": 300, "default": 30},
+ "CACHE_TTL_SECONDS": {"type": int, "min": 60, "max": 86400, "default": 3600},
+ "CACHE_MIN_TOKENS": {"type": int, "min": 1, "max": 1000, "default": 50},
+ "VALIDATION_MAX_STRING_LENGTH": {
+ "type": int,
+ "min": 1,
+ "max": 100000,
+ "default": 10000,
+ },
+ "LOG_LEVEL": {
+ "type": str,
+ "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
+ "default": "INFO",
+ },
+ "CLAUDE_MODEL": {
+ "type": str,
+ "pattern": r"^claude-3-",
+ "default": "claude-3-5-sonnet-20241022",
+ },
+ "CACHE_PREFIX": {
+ "type": str,
+ "pattern": r"^[A-Za-z0-9_]+$",
+ "min_length": 1,
+ "max_length": 50,
+ "default": "fact_v1",
+ },
}
-
+
def __init__(self, verbose: bool = False):
self.verbose = verbose
self.errors: List[str] = []
self.warnings: List[str] = []
self.config: Dict[str, Any] = {}
-
+
def load_environment(self) -> bool:
"""Load environment configuration from .env file and system environment."""
try:
# Load .env file if it exists
- env_path = Path('.env')
+ env_path = Path(".env")
if env_path.exists():
if load_dotenv:
load_dotenv(env_path)
self.log_info(f"✅ Loaded configuration from {env_path}")
else:
- self.log_warning("⚠️ python-dotenv not available, using system environment only")
+ self.log_warning(
+ "⚠️ python-dotenv not available, using system environment only"
+ )
else:
- self.log_warning("⚠️ No .env file found, using system environment variables")
-
+ self.log_warning(
+ "⚠️ No .env file found, using system environment variables"
+ )
+
# Load all environment variables
self.config = dict(os.environ)
return True
-
+
except Exception as e:
self.errors.append(f"Failed to load environment: {e}")
return False
-
+
def validate_required_keys(self) -> bool:
"""Validate that all required API keys are present."""
missing_keys = []
invalid_keys = []
-
+
for key in self.REQUIRED_KEYS:
- value = self.config.get(key, '').strip()
-
+ value = self.config.get(key, "").strip()
+
if not value or value == f"your_{key.lower()}_here":
missing_keys.append(key)
continue
-
+
# Validate format
pattern = self.API_KEY_PATTERNS.get(key)
if pattern and not re.match(pattern, value):
invalid_keys.append(f"{key} (format: {pattern})")
-
+
if missing_keys:
- self.errors.append(f"Missing required configuration keys: {', '.join(missing_keys)}")
-
+ self.errors.append(
+ f"Missing required configuration keys: {', '.join(missing_keys)}"
+ )
+
if invalid_keys:
self.errors.append(f"Invalid API key format: {', '.join(invalid_keys)}")
-
+
return len(missing_keys) == 0 and len(invalid_keys) == 0
-
+
def validate_optional_keys(self) -> bool:
"""Validate optional API keys if present."""
invalid_keys = []
-
+
for key, pattern in self.API_KEY_PATTERNS.items():
if key in self.REQUIRED_KEYS:
continue
-
- value = self.config.get(key, '').strip()
+
+ value = self.config.get(key, "").strip()
if value and value != f"your_{key.lower()}_here":
if not re.match(pattern, value):
invalid_keys.append(f"{key} (format: {pattern})")
-
+
if invalid_keys:
- self.errors.append(f"Invalid optional API key format: {', '.join(invalid_keys)}")
-
+ self.errors.append(
+ f"Invalid optional API key format: {', '.join(invalid_keys)}"
+ )
+
return len(invalid_keys) == 0
-
+
def validate_parameters(self) -> bool:
"""Validate configuration parameters against rules."""
parameter_errors = []
-
+
for param, rules in self.PARAMETER_RULES.items():
value = self.config.get(param)
-
+
if value is None:
if self.verbose:
- self.log_info(f"š {param}: Using default value ({rules.get('default')})")
+ self.log_info(
+ f"š {param}: Using default value ({rules.get('default')})"
+ )
continue
-
+
# Type validation
- if rules['type'] == int:
+ if rules["type"] == int:
try:
int_value = int(value)
- if 'min' in rules and int_value < rules['min']:
- parameter_errors.append(f"{param}={value} below minimum {rules['min']}")
- elif 'max' in rules and int_value > rules['max']:
- parameter_errors.append(f"{param}={value} above maximum {rules['max']}")
+ if "min" in rules and int_value < rules["min"]:
+ parameter_errors.append(
+ f"{param}={value} below minimum {rules['min']}"
+ )
+ elif "max" in rules and int_value > rules["max"]:
+ parameter_errors.append(
+ f"{param}={value} above maximum {rules['max']}"
+ )
except ValueError:
parameter_errors.append(f"{param}={value} is not a valid integer")
-
- elif rules['type'] == str:
- if 'choices' in rules and value not in rules['choices']:
- parameter_errors.append(f"{param}={value} not in {rules['choices']}")
- elif 'pattern' in rules and not re.match(rules['pattern'], value):
- parameter_errors.append(f"{param}={value} doesn't match pattern {rules['pattern']}")
- elif 'min_length' in rules and len(value) < rules['min_length']:
- parameter_errors.append(f"{param}={value} too short (min: {rules['min_length']})")
- elif 'max_length' in rules and len(value) > rules['max_length']:
- parameter_errors.append(f"{param}={value} too long (max: {rules['max_length']})")
-
+
+ elif rules["type"] == str:
+ if "choices" in rules and value not in rules["choices"]:
+ parameter_errors.append(
+ f"{param}={value} not in {rules['choices']}"
+ )
+ elif "pattern" in rules and not re.match(rules["pattern"], value):
+ parameter_errors.append(
+ f"{param}={value} doesn't match pattern {rules['pattern']}"
+ )
+ elif "min_length" in rules and len(value) < rules["min_length"]:
+ parameter_errors.append(
+ f"{param}={value} too short (min: {rules['min_length']})"
+ )
+ elif "max_length" in rules and len(value) > rules["max_length"]:
+ parameter_errors.append(
+ f"{param}={value} too long (max: {rules['max_length']})"
+ )
+
if parameter_errors:
self.errors.extend(parameter_errors)
-
+
return len(parameter_errors) == 0
-
+
def validate_cache_size(self) -> bool:
"""Validate cache size parameter format."""
- cache_size = self.config.get('CACHE_MAX_SIZE', '').strip()
-
+ cache_size = self.config.get("CACHE_MAX_SIZE", "").strip()
+
if not cache_size:
if self.verbose:
self.log_info("š CACHE_MAX_SIZE: Using default value (100MB)")
return True
-
+
# Validate cache size format
- cache_pattern = r'^(\d+)(K|M|G|T)?B$'
+ cache_pattern = r"^(\d+)(K|M|G|T)?B$"
match = re.match(cache_pattern, cache_size.upper())
-
+
if not match:
- self.errors.append(f"CACHE_MAX_SIZE={cache_size} invalid format (expected: [number][KMGT]B)")
+ self.errors.append(
+ f"CACHE_MAX_SIZE={cache_size} invalid format (expected: [number][KMGT]B)"
+ )
return False
-
+
# Convert to bytes for range validation
size_value = int(match.group(1))
- unit = match.group(2) or ''
-
- multipliers = {'': 1, 'K': 1024, 'M': 1024**2, 'G': 1024**3, 'T': 1024**4}
+ unit = match.group(2) or ""
+
+ multipliers = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}
bytes_value = size_value * multipliers[unit]
-
+
min_bytes = 1024**2 # 1MB
max_bytes = 10 * 1024**3 # 10GB
-
+
if bytes_value < min_bytes:
self.errors.append(f"CACHE_MAX_SIZE={cache_size} below minimum 1MB")
return False
elif bytes_value > max_bytes:
self.errors.append(f"CACHE_MAX_SIZE={cache_size} above maximum 10GB")
return False
-
+
return True
-
+
def validate_file_paths(self) -> bool:
"""Validate file path configurations."""
path_errors = []
-
+
# Check database path
- db_path = self.config.get('DATABASE_PATH', 'data/fact_demo.db')
+ db_path = self.config.get("DATABASE_PATH", "data/fact_demo.db")
db_dir = Path(db_path).parent
-
+
if not db_dir.exists():
try:
db_dir.mkdir(parents=True, exist_ok=True)
self.log_info(f"✅ Created database directory: {db_dir}")
except Exception as e:
path_errors.append(f"Cannot create database directory {db_dir}: {e}")
-
+
if path_errors:
self.errors.extend(path_errors)
-
+
return len(path_errors) == 0
-
+
def check_connectivity(self) -> bool:
"""Test connectivity to configured services (optional)."""
- if not self.config.get('ANTHROPIC_API_KEY') or not self.config.get('ARCADE_API_KEY'):
- self.warnings.append("⚠️ Skipping connectivity tests - API keys not configured")
+ if not self.config.get("ANTHROPIC_API_KEY") or not self.config.get(
+ "ARCADE_API_KEY"
+ ):
+ self.warnings.append(
+ "⚠️ Skipping connectivity tests - API keys not configured"
+ )
return True
-
+
connectivity_results = []
-
+
# Test Anthropic API
try:
import requests
- response = requests.get('https://api.anthropic.com', timeout=10)
- connectivity_results.append(f"✅ Anthropic API reachable (status: {response.status_code})")
+
+ response = requests.get("https://api.anthropic.com", timeout=10)
+ connectivity_results.append(
+ f"✅ Anthropic API reachable (status: {response.status_code})"
+ )
except Exception as e:
connectivity_results.append(f"❌ Anthropic API unreachable: {e}")
-
+
# Test Arcade API
- arcade_url = self.config.get('ARCADE_BASE_URL', 'https://api.arcade-ai.com')
+ arcade_url = self.config.get("ARCADE_BASE_URL", "https://api.arcade-ai.com")
try:
import requests
+
response = requests.get(arcade_url, timeout=10)
- connectivity_results.append(f"✅ Arcade API reachable (status: {response.status_code})")
+ connectivity_results.append(
+ f"✅ Arcade API reachable (status: {response.status_code})"
+ )
except Exception as e:
connectivity_results.append(f"❌ Arcade API unreachable: {e}")
-
+
if self.verbose:
for result in connectivity_results:
print(result)
-
+
return True
-
+
def generate_recommendations(self) -> List[str]:
"""Generate configuration recommendations."""
recommendations = []
-
+
# Check for optimal settings
- if not self.config.get('STRICT_MODE') or self.config.get('STRICT_MODE').lower() != 'true':
+ if (
+ not self.config.get("STRICT_MODE")
+ or self.config.get("STRICT_MODE").lower() != "true"
+ ):
recommendations.append("Enable STRICT_MODE=true for production security")
-
- if not self.config.get('RATE_LIMITING_ENABLED') or self.config.get('RATE_LIMITING_ENABLED').lower() != 'true':
+
+ if (
+ not self.config.get("RATE_LIMITING_ENABLED")
+ or self.config.get("RATE_LIMITING_ENABLED").lower() != "true"
+ ):
recommendations.append("Enable RATE_LIMITING_ENABLED=true for production")
-
- if self.config.get('DEBUG_MODE', '').lower() == 'true':
+
+ if self.config.get("DEBUG_MODE", "").lower() == "true":
recommendations.append("Disable DEBUG_MODE=false for production")
-
- if self.config.get('LOG_LEVEL', '').upper() == 'DEBUG':
+
+ if self.config.get("LOG_LEVEL", "").upper() == "DEBUG":
recommendations.append("Use LOG_LEVEL=INFO or higher for production")
-
+
return recommendations
-
+
def run_validation(self, check_connectivity: bool = False) -> bool:
"""Run complete validation suite."""
self.log_info("š FACT Environment Configuration Validation")
self.log_info("=" * 50)
-
+
# Step 1: Load environment
if not self.load_environment():
return False
-
+
# Step 2: Validate required keys
self.log_info("\nš Validating required API keys...")
self.validate_required_keys()
-
+
# Step 3: Validate optional keys
self.log_info("š Validating optional API keys...")
self.validate_optional_keys()
-
+
# Step 4: Validate parameters
self.log_info("š Validating configuration parameters...")
self.validate_parameters()
-
+
# Step 5: Validate cache size
self.log_info("š Validating cache configuration...")
self.validate_cache_size()
-
+
# Step 6: Validate file paths
self.log_info("š Validating file paths...")
self.validate_file_paths()
-
+
# Step 7: Connectivity check (optional)
if check_connectivity:
self.log_info("š Testing service connectivity...")
self.check_connectivity()
-
+
# Step 8: Generate recommendations
recommendations = self.generate_recommendations()
-
+
# Report results
self.print_results(recommendations)
-
+
return len(self.errors) == 0
-
+
def print_results(self, recommendations: List[str]):
"""Print validation results."""
print(f"\n{'=' * 50}")
print("š VALIDATION RESULTS")
print(f"{'=' * 50}")
-
+
if self.errors:
print(f"\nā ERRORS ({len(self.errors)}):")
for error in self.errors:
print(f" ⢠{error}")
-
+
if self.warnings:
print(f"\nā ļø WARNINGS ({len(self.warnings)}):")
for warning in self.warnings:
print(f" ⢠{warning}")
-
+
if recommendations:
print(f"\nš” RECOMMENDATIONS ({len(recommendations)}):")
for rec in recommendations:
print(f" ⢠{rec}")
-
+
if not self.errors and not self.warnings:
print("\nā
Configuration validation passed!")
print("š Your FACT system is ready to use!")
elif not self.errors:
- print(f"\nā
Configuration validation passed with {len(self.warnings)} warnings")
+ print(
+ f"\nā
Configuration validation passed with {len(self.warnings)} warnings"
+ )
else:
- print(f"\nā Configuration validation failed with {len(self.errors)} errors")
+ print(
+ f"\nā Configuration validation failed with {len(self.errors)} errors"
+ )
print("\nš§ RECOVERY SUGGESTIONS:")
print(" 1. Update your .env file with correct API keys")
print(" 2. Verify parameter values are within allowed ranges")
print(" 3. Check file paths and permissions")
print(" 4. Run: python scripts/validate_env.py --verbose")
-
+
def log_info(self, message: str):
"""Log info message if verbose mode is enabled."""
if self.verbose:
print(message)
-
+
def log_warning(self, message: str):
"""Log warning message."""
self.warnings.append(message)
@@ -370,30 +435,29 @@ def main():
Exit codes:
0 = Validation passed
1 = Validation failed
- """
+ """,
)
-
+
parser.add_argument(
- '--verbose', '-v',
- action='store_true',
- help='Enable verbose output'
+ "--verbose", "-v", action="store_true", help="Enable verbose output"
)
-
+
parser.add_argument(
- '--check-connectivity', '-c',
- action='store_true',
- help='Test connectivity to external services'
+ "--check-connectivity",
+ "-c",
+ action="store_true",
+ help="Test connectivity to external services",
)
-
+
args = parser.parse_args()
-
+
# Run validation
validator = ConfigurationValidator(verbose=args.verbose)
success = validator.run_validation(check_connectivity=args.check_connectivity)
-
+
# Exit with appropriate code
sys.exit(0 if success else 1)
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/src/__init__.py b/src/__init__.py
index a19409b..2894815 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -5,4 +5,4 @@
"""
__version__ = "1.0.0"
-__author__ = "FACT Development Team"
\ No newline at end of file
+__author__ = "FACT Development Team"
diff --git a/src/arcade/__init__.py b/src/arcade/__init__.py
index c7d200f..95c1884 100644
--- a/src/arcade/__init__.py
+++ b/src/arcade/__init__.py
@@ -17,9 +17,9 @@
from arcade.errors import ArcadeError, ArcadeConnectionError, ArcadeExecutionError
__all__ = [
- 'ArcadeClient',
- 'ArcadeGateway',
- 'ArcadeError',
- 'ArcadeConnectionError',
- 'ArcadeExecutionError'
-]
\ No newline at end of file
+ "ArcadeClient",
+ "ArcadeGateway",
+ "ArcadeError",
+ "ArcadeConnectionError",
+ "ArcadeExecutionError",
+]
diff --git a/src/arcade/client.py b/src/arcade/client.py
index 677ec5d..a7bf426 100644
--- a/src/arcade/client.py
+++ b/src/arcade/client.py
@@ -14,6 +14,7 @@
# Import arcadepy when available
try:
import arcade
+
ARCADE_AVAILABLE = True
except ImportError:
ARCADE_AVAILABLE = False
@@ -27,25 +28,26 @@
ArcadeExecutionError,
ArcadeTimeoutError,
ArcadeRegistrationError,
- ArcadeSerializationError
+ ArcadeSerializationError,
)
from ..core.errors import ToolExecutionError
except ImportError:
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from arcade.errors import (
ArcadeConnectionError,
ArcadeAuthenticationError,
ArcadeExecutionError,
ArcadeTimeoutError,
ArcadeRegistrationError,
- ArcadeSerializationError
+ ArcadeSerializationError,
)
from core.errors import ToolExecutionError
@@ -56,19 +58,21 @@
class ArcadeClient:
"""
Client for interacting with Arcade.dev tool hosting platform.
-
+
Provides secure tool execution, registration, and management
through the Arcade.dev gateway infrastructure.
"""
-
- def __init__(self,
- api_key: Optional[str] = None,
- base_url: Optional[str] = None,
- timeout: int = 30,
- max_retries: int = 3):
+
+ def __init__(
+ self,
+ api_key: Optional[str] = None,
+ base_url: Optional[str] = None,
+ timeout: int = 30,
+ max_retries: int = 3,
+ ):
"""
Initialize Arcade client.
-
+
Args:
api_key: Arcade.dev API key
base_url: Base URL for Arcade.dev API
@@ -76,25 +80,25 @@ def __init__(self,
max_retries: Maximum retry attempts
"""
if not ARCADE_AVAILABLE:
- raise ImportError("arcadepy library not available. Install with: pip install arcadepy")
-
+ raise ImportError(
+ "arcadepy library not available. Install with: pip install arcadepy"
+ )
+
self.api_key = api_key
self.base_url = base_url or "https://api.arcade.dev"
self.timeout = timeout
self.max_retries = max_retries
-
+
# Initialize arcade client
self._client = None
self._connected = False
-
- logger.info("ArcadeClient initialized",
- base_url=self.base_url,
- timeout=timeout)
-
+
+ logger.info("ArcadeClient initialized", base_url=self.base_url, timeout=timeout)
+
async def connect(self) -> None:
"""
Establish connection to Arcade.dev platform.
-
+
Raises:
ArcadeConnectionError: If connection fails
ArcadeAuthenticationError: If authentication fails
@@ -103,278 +107,290 @@ async def connect(self) -> None:
# Initialize arcade client with API key
if self.api_key:
self._client = arcade.Client(
- api_key=self.api_key,
- base_url=self.base_url,
- timeout=self.timeout
+ api_key=self.api_key, base_url=self.base_url, timeout=self.timeout
)
else:
self._client = arcade.Client(
- base_url=self.base_url,
- timeout=self.timeout
+ base_url=self.base_url, timeout=self.timeout
)
-
+
# Test connection with a simple API call
await self._test_connection()
-
+
self._connected = True
logger.info("Connected to Arcade.dev successfully")
-
+
except Exception as e:
logger.error("Failed to connect to Arcade.dev", error=str(e))
if "authentication" in str(e).lower() or "api key" in str(e).lower():
raise ArcadeAuthenticationError(f"Authentication failed: {str(e)}")
else:
raise ArcadeConnectionError(f"Connection failed: {str(e)}")
-
- async def execute_tool(self,
- tool_name: str,
- arguments: Dict[str, Any],
- timeout: Optional[int] = None,
- user_id: Optional[str] = None) -> Dict[str, Any]:
+
+ async def execute_tool(
+ self,
+ tool_name: str,
+ arguments: Dict[str, Any],
+ timeout: Optional[int] = None,
+ user_id: Optional[str] = None,
+ ) -> Dict[str, Any]:
"""
Execute a tool on Arcade.dev platform.
-
+
Args:
tool_name: Name of the tool to execute
arguments: Tool arguments
timeout: Execution timeout in seconds
user_id: Optional user identifier for logging
-
+
Returns:
Tool execution result
-
+
Raises:
ArcadeExecutionError: If execution fails
ArcadeTimeoutError: If execution times out
"""
if not self._connected or not self._client:
await self.connect()
-
+
execution_timeout = timeout or self.timeout
start_time = time.time()
-
+
try:
- logger.info("Executing tool on Arcade.dev",
- tool_name=tool_name,
- user_id=user_id,
- timeout=execution_timeout)
-
+ logger.info(
+ "Executing tool on Arcade.dev",
+ tool_name=tool_name,
+ user_id=user_id,
+ timeout=execution_timeout,
+ )
+
# Prepare execution request
execution_request = {
"tool": tool_name,
"arguments": arguments,
- "timeout": execution_timeout
+ "timeout": execution_timeout,
}
-
+
if user_id:
execution_request["user_id"] = user_id
-
+
# Execute tool with timeout
result = await asyncio.wait_for(
- self._execute_with_retry(execution_request),
- timeout=execution_timeout
+ self._execute_with_retry(execution_request), timeout=execution_timeout
)
-
+
execution_time = (time.time() - start_time) * 1000
-
- logger.info("Tool executed successfully on Arcade.dev",
- tool_name=tool_name,
- execution_time_ms=execution_time,
- user_id=user_id)
-
+
+ logger.info(
+ "Tool executed successfully on Arcade.dev",
+ tool_name=tool_name,
+ execution_time_ms=execution_time,
+ user_id=user_id,
+ )
+
return self._process_execution_result(result, execution_time)
-
+
except asyncio.TimeoutError:
execution_time = (time.time() - start_time) * 1000
- logger.error("Tool execution timed out",
- tool_name=tool_name,
- timeout=execution_timeout,
- execution_time_ms=execution_time)
- raise ArcadeTimeoutError(f"Tool execution timed out after {execution_timeout} seconds")
-
+ logger.error(
+ "Tool execution timed out",
+ tool_name=tool_name,
+ timeout=execution_timeout,
+ execution_time_ms=execution_time,
+ )
+ raise ArcadeTimeoutError(
+ f"Tool execution timed out after {execution_timeout} seconds"
+ )
+
except Exception as e:
execution_time = (time.time() - start_time) * 1000
- logger.error("Tool execution failed",
- tool_name=tool_name,
- error=str(e),
- execution_time_ms=execution_time)
+ logger.error(
+ "Tool execution failed",
+ tool_name=tool_name,
+ error=str(e),
+ execution_time_ms=execution_time,
+ )
raise ArcadeExecutionError(f"Tool execution failed: {str(e)}")
-
- async def register_tool(self,
- tool_definition: Dict[str, Any],
- source_code: Optional[str] = None) -> Dict[str, Any]:
+
+ async def register_tool(
+ self, tool_definition: Dict[str, Any], source_code: Optional[str] = None
+ ) -> Dict[str, Any]:
"""
Register a tool with Arcade.dev platform.
-
+
Args:
tool_definition: Tool definition including name, description, parameters
source_code: Optional source code for the tool
-
+
Returns:
Registration result
-
+
Raises:
ArcadeRegistrationError: If registration fails
"""
if not self._connected or not self._client:
await self.connect()
-
+
try:
- logger.info("Registering tool with Arcade.dev",
- tool_name=tool_definition.get("name"))
-
+ logger.info(
+ "Registering tool with Arcade.dev",
+ tool_name=tool_definition.get("name"),
+ )
+
# Prepare registration request
registration_request = {
"name": tool_definition["name"],
"description": tool_definition["description"],
- "parameters": tool_definition["parameters"]
+ "parameters": tool_definition["parameters"],
}
-
+
if source_code:
registration_request["source_code"] = source_code
-
+
# Add metadata
registration_request["metadata"] = {
"version": tool_definition.get("version", "1.0.0"),
"requires_auth": tool_definition.get("requires_auth", False),
"timeout_seconds": tool_definition.get("timeout_seconds", 30),
- "created_at": time.time()
+ "created_at": time.time(),
}
-
+
# Register tool
result = await self._client.tools.register(registration_request)
-
- logger.info("Tool registered successfully",
- tool_name=tool_definition["name"],
- tool_id=result.get("id"))
-
+
+ logger.info(
+ "Tool registered successfully",
+ tool_name=tool_definition["name"],
+ tool_id=result.get("id"),
+ )
+
return result
-
+
except Exception as e:
- logger.error("Tool registration failed",
- tool_name=tool_definition.get("name"),
- error=str(e))
+ logger.error(
+ "Tool registration failed",
+ tool_name=tool_definition.get("name"),
+ error=str(e),
+ )
raise ArcadeRegistrationError(f"Tool registration failed: {str(e)}")
-
+
async def list_tools(self) -> List[Dict[str, Any]]:
"""
List all registered tools on Arcade.dev platform.
-
+
Returns:
List of registered tools
-
+
Raises:
ArcadeConnectionError: If request fails
"""
if not self._connected or not self._client:
await self.connect()
-
+
try:
result = await self._client.tools.list()
-
- logger.debug("Retrieved tool list from Arcade.dev",
- tool_count=len(result.get("tools", [])))
-
+
+ logger.debug(
+ "Retrieved tool list from Arcade.dev",
+ tool_count=len(result.get("tools", [])),
+ )
+
return result.get("tools", [])
-
+
except Exception as e:
logger.error("Failed to list tools", error=str(e))
raise ArcadeConnectionError(f"Failed to list tools: {str(e)}")
-
+
async def get_tool_info(self, tool_name: str) -> Dict[str, Any]:
"""
Get detailed information about a specific tool.
-
+
Args:
tool_name: Name of the tool
-
+
Returns:
Tool information
-
+
Raises:
ArcadeConnectionError: If request fails
"""
if not self._connected or not self._client:
await self.connect()
-
+
try:
result = await self._client.tools.get(tool_name)
-
- logger.debug("Retrieved tool info from Arcade.dev",
- tool_name=tool_name)
-
+
+ logger.debug("Retrieved tool info from Arcade.dev", tool_name=tool_name)
+
return result
-
+
except Exception as e:
- logger.error("Failed to get tool info",
- tool_name=tool_name,
- error=str(e))
+ logger.error("Failed to get tool info", tool_name=tool_name, error=str(e))
raise ArcadeConnectionError(f"Failed to get tool info: {str(e)}")
-
+
async def delete_tool(self, tool_name: str) -> bool:
"""
Delete a tool from Arcade.dev platform.
-
+
Args:
tool_name: Name of the tool to delete
-
+
Returns:
True if deletion was successful
-
+
Raises:
ArcadeConnectionError: If request fails
"""
if not self._connected or not self._client:
await self.connect()
-
+
try:
await self._client.tools.delete(tool_name)
-
- logger.info("Tool deleted successfully",
- tool_name=tool_name)
-
+
+ logger.info("Tool deleted successfully", tool_name=tool_name)
+
return True
-
+
except Exception as e:
- logger.error("Failed to delete tool",
- tool_name=tool_name,
- error=str(e))
+ logger.error("Failed to delete tool", tool_name=tool_name, error=str(e))
raise ArcadeConnectionError(f"Failed to delete tool: {str(e)}")
-
- async def get_execution_logs(self,
- execution_id: str,
- limit: int = 100) -> List[Dict[str, Any]]:
+
+ async def get_execution_logs(
+ self, execution_id: str, limit: int = 100
+ ) -> List[Dict[str, Any]]:
"""
Get execution logs for a specific execution.
-
+
Args:
execution_id: Execution identifier
limit: Maximum number of log entries
-
+
Returns:
List of log entries
-
+
Raises:
ArcadeConnectionError: If request fails
"""
if not self._connected or not self._client:
await self.connect()
-
+
try:
result = await self._client.executions.get_logs(execution_id, limit=limit)
-
- logger.debug("Retrieved execution logs",
- execution_id=execution_id,
- log_count=len(result.get("logs", [])))
-
+
+ logger.debug(
+ "Retrieved execution logs",
+ execution_id=execution_id,
+ log_count=len(result.get("logs", [])),
+ )
+
return result.get("logs", [])
-
+
except Exception as e:
- logger.error("Failed to get execution logs",
- execution_id=execution_id,
- error=str(e))
+ logger.error(
+ "Failed to get execution logs", execution_id=execution_id, error=str(e)
+ )
raise ArcadeConnectionError(f"Failed to get execution logs: {str(e)}")
-
+
async def close(self) -> None:
"""Close the Arcade client connection."""
if self._client:
@@ -386,12 +402,12 @@ async def close(self) -> None:
self._client = None
self._connected = False
logger.info("Arcade client connection closed")
-
+
async def _test_connection(self) -> None:
"""Test connection to Arcade.dev platform."""
try:
# Simple health check or authentication test
- if hasattr(self._client, 'health'):
+ if hasattr(self._client, "health"):
await self._client.health.check()
else:
# Fallback to listing tools as a connection test
@@ -399,68 +415,80 @@ async def _test_connection(self) -> None:
except Exception as e:
logger.error("Connection test failed", error=str(e))
raise
-
- async def _execute_with_retry(self, execution_request: Dict[str, Any]) -> Dict[str, Any]:
+
+ async def _execute_with_retry(
+ self, execution_request: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Execute tool with retry logic."""
last_error = None
-
+
for attempt in range(self.max_retries + 1):
try:
return await self._client.tools.execute(execution_request)
except Exception as e:
last_error = e
if attempt < self.max_retries:
- wait_time = 2 ** attempt # Exponential backoff
- logger.warning("Tool execution attempt failed, retrying",
- attempt=attempt + 1,
- max_retries=self.max_retries,
- wait_time=wait_time,
- error=str(e))
+ wait_time = 2**attempt # Exponential backoff
+ logger.warning(
+ "Tool execution attempt failed, retrying",
+ attempt=attempt + 1,
+ max_retries=self.max_retries,
+ wait_time=wait_time,
+ error=str(e),
+ )
await asyncio.sleep(wait_time)
else:
- logger.error("All retry attempts exhausted",
- attempts=self.max_retries + 1,
- error=str(e))
-
+ logger.error(
+ "All retry attempts exhausted",
+ attempts=self.max_retries + 1,
+ error=str(e),
+ )
+
raise last_error
-
- def _process_execution_result(self, result: Dict[str, Any], execution_time: float) -> Dict[str, Any]:
+
+ def _process_execution_result(
+ self, result: Dict[str, Any], execution_time: float
+ ) -> Dict[str, Any]:
"""Process and validate execution result."""
try:
# Ensure result has required fields
processed_result = {
"success": result.get("success", True),
- "execution_time_ms": execution_time
+ "execution_time_ms": execution_time,
}
-
+
if result.get("success", True):
processed_result["data"] = result.get("data", result)
else:
processed_result["error"] = result.get("error", "Unknown error")
processed_result["success"] = False
-
+
# Add metadata if available
if "metadata" in result:
processed_result["metadata"] = result["metadata"]
-
+
return processed_result
-
+
except Exception as e:
logger.error("Failed to process execution result", error=str(e))
- raise ArcadeSerializationError(f"Failed to process execution result: {str(e)}")
+ raise ArcadeSerializationError(
+ f"Failed to process execution result: {str(e)}"
+ )
# Utility functions for Arcade integration
-def create_arcade_client(api_key: Optional[str] = None,
- base_url: Optional[str] = None) -> ArcadeClient:
+
+def create_arcade_client(
+ api_key: Optional[str] = None, base_url: Optional[str] = None
+) -> ArcadeClient:
"""
Create and configure an Arcade client.
-
+
Args:
api_key: Arcade.dev API key
base_url: Base URL for Arcade.dev API
-
+
Returns:
Configured ArcadeClient instance
"""
@@ -470,10 +498,10 @@ def create_arcade_client(api_key: Optional[str] = None,
async def test_arcade_connection(client: ArcadeClient) -> bool:
"""
Test connection to Arcade.dev platform.
-
+
Args:
client: ArcadeClient instance
-
+
Returns:
True if connection is successful
"""
@@ -482,4 +510,4 @@ async def test_arcade_connection(client: ArcadeClient) -> bool:
return True
except Exception as e:
logger.error("Arcade connection test failed", error=str(e))
- return False
\ No newline at end of file
+ return False
diff --git a/src/arcade/errors.py b/src/arcade/errors.py
index 4639c0f..ae9ffee 100644
--- a/src/arcade/errors.py
+++ b/src/arcade/errors.py
@@ -6,6 +6,7 @@
"""
from typing import Optional, Dict, Any
+
try:
# Try relative imports first (when used as package)
from ..core.errors import FACTError
@@ -13,44 +14,52 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from core.errors import FACTError
class ArcadeError(FACTError):
"""Base exception class for all Arcade.dev integration errors."""
+
pass
class ArcadeConnectionError(ArcadeError):
"""Raised when connection to Arcade.dev fails."""
+
pass
class ArcadeAuthenticationError(ArcadeError):
"""Raised when Arcade.dev authentication fails."""
+
pass
class ArcadeExecutionError(ArcadeError):
"""Raised when tool execution on Arcade.dev fails."""
+
pass
class ArcadeTimeoutError(ArcadeError):
"""Raised when Arcade.dev operations timeout."""
+
pass
class ArcadeRegistrationError(ArcadeError):
"""Raised when tool registration with Arcade.dev fails."""
+
pass
class ArcadeSerializationError(ArcadeError):
"""Raised when request/response serialization fails."""
- pass
\ No newline at end of file
+
+ pass
diff --git a/src/arcade/gateway.py b/src/arcade/gateway.py
index f3890e9..bda72d5 100644
--- a/src/arcade/gateway.py
+++ b/src/arcade/gateway.py
@@ -18,11 +18,12 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from arcade.client import ArcadeClient
from arcade.errors import ArcadeError, ArcadeExecutionError
from core.errors import ToolExecutionError
@@ -34,18 +35,20 @@
class ArcadeGateway:
"""
Gateway for routing tool execution between local and Arcade.dev execution.
-
+
Provides intelligent routing, fallback mechanisms, and execution
orchestration for tools that can run locally or on Arcade.dev.
"""
-
- def __init__(self,
- arcade_client: Optional[ArcadeClient] = None,
- enable_fallback: bool = True,
- prefer_arcade: bool = False):
+
+ def __init__(
+ self,
+ arcade_client: Optional[ArcadeClient] = None,
+ enable_fallback: bool = True,
+ prefer_arcade: bool = False,
+ ):
"""
Initialize Arcade gateway.
-
+
Args:
arcade_client: Optional Arcade client for remote execution
enable_fallback: Whether to fallback to local execution on Arcade failure
@@ -54,199 +57,216 @@ def __init__(self,
self.arcade_client = arcade_client
self.enable_fallback = enable_fallback
self.prefer_arcade = prefer_arcade
-
- logger.info("ArcadeGateway initialized",
- has_arcade_client=bool(arcade_client),
- enable_fallback=enable_fallback,
- prefer_arcade=prefer_arcade)
-
- async def execute_tool(self,
- tool_name: str,
- arguments: Dict[str, Any],
- local_function: Optional[callable] = None,
- user_id: Optional[str] = None,
- timeout: Optional[int] = None) -> Dict[str, Any]:
+
+ logger.info(
+ "ArcadeGateway initialized",
+ has_arcade_client=bool(arcade_client),
+ enable_fallback=enable_fallback,
+ prefer_arcade=prefer_arcade,
+ )
+
+ async def execute_tool(
+ self,
+ tool_name: str,
+ arguments: Dict[str, Any],
+ local_function: Optional[callable] = None,
+ user_id: Optional[str] = None,
+ timeout: Optional[int] = None,
+ ) -> Dict[str, Any]:
"""
Execute a tool using the best available method.
-
+
Args:
tool_name: Name of the tool to execute
arguments: Tool arguments
local_function: Local function to execute if available
user_id: Optional user identifier
timeout: Execution timeout
-
+
Returns:
Tool execution result
-
+
Raises:
ToolExecutionError: If execution fails on all available methods
"""
execution_methods = self._determine_execution_order(tool_name, local_function)
-
+
last_error = None
-
+
for method in execution_methods:
try:
if method == "arcade":
- logger.debug("Attempting Arcade execution",
- tool_name=tool_name,
- user_id=user_id)
-
+ logger.debug(
+ "Attempting Arcade execution",
+ tool_name=tool_name,
+ user_id=user_id,
+ )
+
result = await self._execute_via_arcade(
tool_name, arguments, user_id, timeout
)
-
- logger.info("Tool executed successfully via Arcade",
- tool_name=tool_name,
- user_id=user_id)
-
+
+ logger.info(
+ "Tool executed successfully via Arcade",
+ tool_name=tool_name,
+ user_id=user_id,
+ )
+
return result
-
+
elif method == "local":
- logger.debug("Attempting local execution",
- tool_name=tool_name,
- user_id=user_id)
-
+ logger.debug(
+ "Attempting local execution",
+ tool_name=tool_name,
+ user_id=user_id,
+ )
+
result = await self._execute_locally(
tool_name, local_function, arguments
)
-
- logger.info("Tool executed successfully locally",
- tool_name=tool_name,
- user_id=user_id)
-
+
+ logger.info(
+ "Tool executed successfully locally",
+ tool_name=tool_name,
+ user_id=user_id,
+ )
+
return result
-
+
except Exception as e:
last_error = e
- logger.warning("Execution method failed",
- method=method,
- tool_name=tool_name,
- error=str(e))
-
+ logger.warning(
+ "Execution method failed",
+ method=method,
+ tool_name=tool_name,
+ error=str(e),
+ )
+
# If this is the last method and fallback is disabled, raise immediately
if not self.enable_fallback and method == execution_methods[0]:
raise
-
+
# All methods failed
error_msg = f"Tool execution failed on all available methods. Last error: {str(last_error)}"
- logger.error("All execution methods failed",
- tool_name=tool_name,
- methods_tried=execution_methods,
- last_error=str(last_error))
-
+ logger.error(
+ "All execution methods failed",
+ tool_name=tool_name,
+ methods_tried=execution_methods,
+ last_error=str(last_error),
+ )
+
raise ToolExecutionError(error_msg)
-
- def _determine_execution_order(self,
- tool_name: str,
- local_function: Optional[callable]) -> list:
+
+ def _determine_execution_order(
+ self, tool_name: str, local_function: Optional[callable]
+ ) -> list:
"""
Determine the order of execution methods to try.
-
+
Args:
tool_name: Name of the tool
local_function: Local function if available
-
+
Returns:
List of execution methods in order of preference
"""
methods = []
-
+
# If we prefer Arcade and have a client, try Arcade first
if self.prefer_arcade and self.arcade_client:
methods.append("arcade")
-
+
# Add local as fallback if available and fallback is enabled
if local_function and self.enable_fallback:
methods.append("local")
-
+
# If we prefer local or don't have Arcade client
else:
# Try local first if available
if local_function:
methods.append("local")
-
+
# Add Arcade as fallback if available and fallback is enabled
if self.arcade_client and self.enable_fallback:
methods.append("arcade")
-
+
# If no methods are available, at least try what we have
if not methods:
if self.arcade_client:
methods.append("arcade")
elif local_function:
methods.append("local")
-
+
return methods
-
- async def _execute_via_arcade(self,
- tool_name: str,
- arguments: Dict[str, Any],
- user_id: Optional[str],
- timeout: Optional[int]) -> Dict[str, Any]:
+
+ async def _execute_via_arcade(
+ self,
+ tool_name: str,
+ arguments: Dict[str, Any],
+ user_id: Optional[str],
+ timeout: Optional[int],
+ ) -> Dict[str, Any]:
"""
Execute tool via Arcade.dev platform.
-
+
Args:
tool_name: Name of the tool
arguments: Tool arguments
user_id: Optional user identifier
timeout: Execution timeout
-
+
Returns:
Arcade execution result
-
+
Raises:
ArcadeExecutionError: If Arcade execution fails
"""
if not self.arcade_client:
raise ArcadeExecutionError("No Arcade client available")
-
+
try:
result = await self.arcade_client.execute_tool(
tool_name=tool_name,
arguments=arguments,
timeout=timeout,
- user_id=user_id
+ user_id=user_id,
)
-
+
# Ensure result has expected structure
if not isinstance(result, dict):
raise ArcadeExecutionError("Invalid result format from Arcade")
-
+
# Add execution metadata
result["execution_method"] = "arcade"
result["platform"] = "arcade.dev"
-
+
return result
-
+
except ArcadeError:
raise
except Exception as e:
raise ArcadeExecutionError(f"Arcade execution failed: {str(e)}")
-
- async def _execute_locally(self,
- tool_name: str,
- local_function: callable,
- arguments: Dict[str, Any]) -> Dict[str, Any]:
+
+ async def _execute_locally(
+ self, tool_name: str, local_function: callable, arguments: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""
Execute tool locally using provided function.
-
+
Args:
tool_name: Name of the tool
local_function: Local function to execute
arguments: Tool arguments
-
+
Returns:
Local execution result
-
+
Raises:
ToolExecutionError: If local execution fails
"""
if not local_function:
raise ToolExecutionError("No local function available")
-
+
try:
# Execute function (handle both sync and async functions)
if asyncio.iscoroutinefunction(local_function):
@@ -254,25 +274,27 @@ async def _execute_locally(self,
else:
# Run sync function in thread pool to avoid blocking
loop = asyncio.get_event_loop()
- result = await loop.run_in_executor(None, lambda: local_function(**arguments))
-
+ result = await loop.run_in_executor(
+ None, lambda: local_function(**arguments)
+ )
+
# Ensure result is a dictionary
if not isinstance(result, dict):
result = {"result": result}
-
+
# Add execution metadata
result["execution_method"] = "local"
result["platform"] = "local"
-
+
return result
-
+
except Exception as e:
raise ToolExecutionError(f"Local execution failed: {str(e)}")
-
+
async def health_check(self) -> Dict[str, Any]:
"""
Perform health check on available execution methods.
-
+
Returns:
Health check results
"""
@@ -280,9 +302,9 @@ async def health_check(self) -> Dict[str, Any]:
"gateway_healthy": True,
"arcade_available": False,
"local_available": True, # Local execution is always available
- "methods": []
+ "methods": [],
}
-
+
# Check Arcade availability
if self.arcade_client:
try:
@@ -290,25 +312,25 @@ async def health_check(self) -> Dict[str, Any]:
await self.arcade_client.connect()
health_status["arcade_available"] = True
health_status["methods"].append("arcade")
-
+
logger.debug("Arcade health check passed")
-
+
except Exception as e:
logger.warning("Arcade health check failed", error=str(e))
health_status["arcade_error"] = str(e)
-
+
# Local execution is always available
health_status["methods"].append("local")
-
+
# Overall health
health_status["gateway_healthy"] = len(health_status["methods"]) > 0
-
+
return health_status
-
+
def get_execution_stats(self) -> Dict[str, Any]:
"""
Get execution statistics and configuration.
-
+
Returns:
Execution statistics
"""
@@ -316,26 +338,28 @@ def get_execution_stats(self) -> Dict[str, Any]:
"arcade_client_configured": bool(self.arcade_client),
"fallback_enabled": self.enable_fallback,
"prefer_arcade": self.prefer_arcade,
- "available_methods": ["arcade"] if self.arcade_client else [] + ["local"]
+ "available_methods": ["arcade"] if self.arcade_client else [] + ["local"],
}
-def create_arcade_gateway(arcade_client: Optional[ArcadeClient] = None,
- enable_fallback: bool = True,
- prefer_arcade: bool = False) -> ArcadeGateway:
+def create_arcade_gateway(
+ arcade_client: Optional[ArcadeClient] = None,
+ enable_fallback: bool = True,
+ prefer_arcade: bool = False,
+) -> ArcadeGateway:
"""
Create and configure an Arcade gateway.
-
+
Args:
arcade_client: Optional Arcade client
enable_fallback: Whether to enable fallback between methods
prefer_arcade: Whether to prefer Arcade over local execution
-
+
Returns:
Configured ArcadeGateway instance
"""
return ArcadeGateway(
arcade_client=arcade_client,
enable_fallback=enable_fallback,
- prefer_arcade=prefer_arcade
- )
\ No newline at end of file
+ prefer_arcade=prefer_arcade,
+ )
diff --git a/src/benchmarking/__init__.py b/src/benchmarking/__init__.py
index 8146a0d..77152bf 100644
--- a/src/benchmarking/__init__.py
+++ b/src/benchmarking/__init__.py
@@ -10,7 +10,7 @@
BenchmarkRunner,
BenchmarkConfig,
BenchmarkResult,
- BenchmarkSummary
+ BenchmarkSummary,
)
from .comparisons import (
@@ -18,7 +18,7 @@
PerformanceComparison,
ComparisonResult,
ComparisonMetrics,
- SystemType
+ SystemType,
)
from .profiler import (
@@ -26,7 +26,7 @@
BottleneckAnalyzer,
ProfileResult,
BottleneckAnalysis,
- ProfilePoint
+ ProfilePoint,
)
from .monitoring import (
@@ -34,7 +34,7 @@
PerformanceTracker,
MonitoringConfig,
PerformanceAlert,
- PerformanceTrend
+ PerformanceTrend,
)
from .visualization import (
@@ -42,47 +42,43 @@
ReportGenerator,
BenchmarkReport,
ChartData,
- ReportSection
+ ReportSection,
)
__all__ = [
# Framework
"BenchmarkFramework",
- "BenchmarkRunner",
+ "BenchmarkRunner",
"BenchmarkConfig",
"BenchmarkResult",
"BenchmarkSummary",
-
# Comparisons
"RAGComparison",
- "PerformanceComparison",
+ "PerformanceComparison",
"ComparisonResult",
"ComparisonMetrics",
"SystemType",
-
# Profiling
"SystemProfiler",
"BottleneckAnalyzer",
"ProfileResult",
"BottleneckAnalysis",
"ProfilePoint",
-
# Monitoring
"ContinuousMonitor",
"PerformanceTracker",
- "MonitoringConfig",
+ "MonitoringConfig",
"PerformanceAlert",
"PerformanceTrend",
-
# Visualization
"BenchmarkVisualizer",
"ReportGenerator",
"BenchmarkReport",
"ChartData",
- "ReportSection"
+ "ReportSection",
]
# Version info
__version__ = "1.0.0"
__author__ = "FACT Team"
-__description__ = "Comprehensive benchmarking system for FACT performance validation"
\ No newline at end of file
+__description__ = "Comprehensive benchmarking system for FACT performance validation"
diff --git a/src/benchmarking/comparisons.py b/src/benchmarking/comparisons.py
index 8281315..5db1189 100644
--- a/src/benchmarking/comparisons.py
+++ b/src/benchmarking/comparisons.py
@@ -21,12 +21,17 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
- from benchmarking.framework import BenchmarkResult, BenchmarkSummary, BenchmarkFramework
+
+ from benchmarking.framework import (
+ BenchmarkResult,
+ BenchmarkSummary,
+ BenchmarkFramework,
+ )
from cache.manager import CacheManager
logger = structlog.get_logger(__name__)
@@ -34,6 +39,7 @@
class SystemType(Enum):
"""System types for comparison."""
+
FACT = "fact"
TRADITIONAL_RAG = "traditional_rag"
HYBRID = "hybrid"
@@ -42,6 +48,7 @@ class SystemType(Enum):
@dataclass
class ComparisonMetrics:
"""Metrics for system comparison."""
+
system_type: SystemType
avg_latency_ms: float
p95_latency_ms: float
@@ -58,6 +65,7 @@ class ComparisonMetrics:
@dataclass
class ComparisonResult:
"""Result of system comparison."""
+
fact_metrics: ComparisonMetrics
rag_metrics: ComparisonMetrics
improvement_factors: Dict[str, float]
@@ -70,127 +78,146 @@ class ComparisonResult:
class RAGComparison:
"""
Comprehensive comparison between FACT and traditional RAG systems.
-
+
Simulates traditional RAG behavior and measures performance differences
across latency, cost, and accuracy dimensions.
"""
-
+
def __init__(self, benchmark_framework: Optional[BenchmarkFramework] = None):
"""
Initialize RAG comparison.
-
+
Args:
benchmark_framework: Benchmark framework for measurements
"""
self.framework = benchmark_framework or BenchmarkFramework()
-
+
# Traditional RAG simulation parameters
self.rag_base_latency_ms = 400 # Typical RAG latency
self.rag_vector_search_ms = 50 # Vector database search time
self.rag_llm_processing_ms = 300 # LLM processing time
self.rag_context_preparation_ms = 50 # Context preparation time
-
+
# Token usage patterns
self.rag_avg_input_tokens = 1500 # Large context with retrieved docs
self.rag_avg_output_tokens = 300 # Detailed responses
self.fact_hit_input_tokens = 100 # Minimal prompt for cache
- self.fact_hit_output_tokens = 50 # Cached response
+ self.fact_hit_output_tokens = 50 # Cached response
self.fact_miss_input_tokens = 200 # Tool-enhanced prompt
- self.fact_miss_output_tokens = 150 # Response with tool results
-
+ self.fact_miss_output_tokens = 150 # Response with tool results
+
logger.info("RAG comparison initialized")
-
- async def run_comparison_benchmark(self,
- queries: List[str],
- cache_manager: Optional[CacheManager] = None,
- iterations: int = 10) -> ComparisonResult:
+
+ async def run_comparison_benchmark(
+ self,
+ queries: List[str],
+ cache_manager: Optional[CacheManager] = None,
+ iterations: int = 10,
+ ) -> ComparisonResult:
"""
Run comprehensive comparison benchmark.
-
+
Args:
queries: Test queries for comparison
cache_manager: Cache manager for FACT system
iterations: Number of iterations per query
-
+
Returns:
Comparison results
"""
- logger.info("Starting RAG comparison benchmark",
- queries_count=len(queries),
- iterations=iterations)
-
+ logger.info(
+ "Starting RAG comparison benchmark",
+ queries_count=len(queries),
+ iterations=iterations,
+ )
+
# Run FACT benchmarks
fact_summary = await self.framework.run_benchmark_suite(queries, cache_manager)
fact_metrics = self._extract_comparison_metrics(fact_summary, SystemType.FACT)
-
+
# Simulate traditional RAG performance
rag_results = await self._simulate_rag_performance(queries, iterations)
- rag_summary = self.framework._generate_summary(rag_results, fact_summary.execution_time_seconds)
- rag_metrics = self._extract_comparison_metrics(rag_summary, SystemType.TRADITIONAL_RAG)
-
+ rag_summary = self.framework._generate_summary(
+ rag_results, fact_summary.execution_time_seconds
+ )
+ rag_metrics = self._extract_comparison_metrics(
+ rag_summary, SystemType.TRADITIONAL_RAG
+ )
+
# Calculate improvement factors
- improvement_factors = self._calculate_improvement_factors(fact_metrics, rag_metrics)
-
+ improvement_factors = self._calculate_improvement_factors(
+ fact_metrics, rag_metrics
+ )
+
# Calculate cost savings
cost_savings = self._calculate_cost_savings(fact_metrics, rag_metrics)
-
+
# Generate performance analysis
- performance_analysis = self._analyze_performance_differences(fact_metrics, rag_metrics)
-
+ performance_analysis = self._analyze_performance_differences(
+ fact_metrics, rag_metrics
+ )
+
# Generate recommendation
- recommendation = self._generate_recommendation(improvement_factors, cost_savings)
-
+ recommendation = self._generate_recommendation(
+ improvement_factors, cost_savings
+ )
+
result = ComparisonResult(
fact_metrics=fact_metrics,
rag_metrics=rag_metrics,
improvement_factors=improvement_factors,
cost_savings=cost_savings,
performance_analysis=performance_analysis,
- recommendation=recommendation
+ recommendation=recommendation,
)
-
- logger.info("RAG comparison completed",
- latency_improvement=improvement_factors.get("latency", 0),
- cost_reduction=cost_savings.get("percentage", 0))
-
+
+ logger.info(
+ "RAG comparison completed",
+ latency_improvement=improvement_factors.get("latency", 0),
+ cost_reduction=cost_savings.get("percentage", 0),
+ )
+
return result
-
- async def _simulate_rag_performance(self,
- queries: List[str],
- iterations: int) -> List[BenchmarkResult]:
+
+ async def _simulate_rag_performance(
+ self, queries: List[str], iterations: int
+ ) -> List[BenchmarkResult]:
"""
Simulate traditional RAG system performance.
-
+
Args:
queries: Test queries
iterations: Number of iterations
-
+
Returns:
Simulated RAG benchmark results
"""
results = []
-
+
for iteration in range(iterations):
for query in queries:
# Simulate RAG processing time
base_latency = self.rag_base_latency_ms
-
+
# Add variability based on query complexity
- complexity_factor = min(len(query) / 100, 2.0) # Up to 2x for complex queries
+ complexity_factor = min(
+ len(query) / 100, 2.0
+ ) # Up to 2x for complex queries
variable_latency = base_latency * (0.8 + complexity_factor * 0.4)
-
+
# Add random variation (±20%)
import random
+
variation = random.uniform(0.8, 1.2)
total_latency = variable_latency * variation
-
+
# Simulate processing delay
await asyncio.sleep(total_latency / 1000) # Convert to seconds
-
+
# Calculate token costs for RAG
token_count = self.rag_avg_input_tokens + self.rag_avg_output_tokens
token_cost = self._calculate_rag_token_cost(token_count)
-
+
result = BenchmarkResult(
query=query,
response_time_ms=total_latency,
@@ -198,16 +225,16 @@ async def _simulate_rag_performance(self,
cache_hit=False, # Traditional RAG doesn't use cache
token_count=token_count,
token_cost=token_cost,
- metadata={"system_type": "traditional_rag"}
+ metadata={"system_type": "traditional_rag"},
)
-
+
results.append(result)
-
+
return results
-
- def _extract_comparison_metrics(self,
- summary: BenchmarkSummary,
- system_type: SystemType) -> ComparisonMetrics:
+
+ def _extract_comparison_metrics(
+ self, summary: BenchmarkSummary, system_type: SystemType
+ ) -> ComparisonMetrics:
"""Extract comparison metrics from benchmark summary."""
return ComparisonMetrics(
system_type=system_type,
@@ -220,45 +247,55 @@ def _extract_comparison_metrics(self,
error_rate=summary.error_rate,
throughput_qps=summary.throughput_qps,
memory_usage_mb=0.0, # TODO: Implement memory tracking
- cpu_usage_percent=0.0 # TODO: Implement CPU tracking
+ cpu_usage_percent=0.0, # TODO: Implement CPU tracking
)
-
- def _calculate_improvement_factors(self,
- fact_metrics: ComparisonMetrics,
- rag_metrics: ComparisonMetrics) -> Dict[str, float]:
+
+ def _calculate_improvement_factors(
+ self, fact_metrics: ComparisonMetrics, rag_metrics: ComparisonMetrics
+ ) -> Dict[str, float]:
"""Calculate improvement factors for FACT vs RAG."""
improvements = {}
-
+
# Latency improvements
if rag_metrics.avg_latency_ms > 0:
- improvements["latency"] = rag_metrics.avg_latency_ms / fact_metrics.avg_latency_ms
-
+ improvements["latency"] = (
+ rag_metrics.avg_latency_ms / fact_metrics.avg_latency_ms
+ )
+
if rag_metrics.p95_latency_ms > 0:
- improvements["p95_latency"] = rag_metrics.p95_latency_ms / fact_metrics.p95_latency_ms
-
+ improvements["p95_latency"] = (
+ rag_metrics.p95_latency_ms / fact_metrics.p95_latency_ms
+ )
+
# Throughput improvements
if fact_metrics.throughput_qps > 0 and rag_metrics.throughput_qps > 0:
- improvements["throughput"] = fact_metrics.throughput_qps / rag_metrics.throughput_qps
-
+ improvements["throughput"] = (
+ fact_metrics.throughput_qps / rag_metrics.throughput_qps
+ )
+
# Cost improvements
if rag_metrics.avg_token_cost > 0:
- improvements["cost_efficiency"] = rag_metrics.avg_token_cost / fact_metrics.avg_token_cost
-
+ improvements["cost_efficiency"] = (
+ rag_metrics.avg_token_cost / fact_metrics.avg_token_cost
+ )
+
return improvements
-
- def _calculate_cost_savings(self,
- fact_metrics: ComparisonMetrics,
- rag_metrics: ComparisonMetrics) -> Dict[str, float]:
+
+ def _calculate_cost_savings(
+ self, fact_metrics: ComparisonMetrics, rag_metrics: ComparisonMetrics
+ ) -> Dict[str, float]:
"""Calculate cost savings from using FACT."""
savings = {}
-
+
if rag_metrics.total_token_cost > 0:
- absolute_savings = rag_metrics.total_token_cost - fact_metrics.total_token_cost
+ absolute_savings = (
+ rag_metrics.total_token_cost - fact_metrics.total_token_cost
+ )
percentage_savings = (absolute_savings / rag_metrics.total_token_cost) * 100
-
+
savings["absolute_usd"] = absolute_savings
savings["percentage"] = percentage_savings
-
+
# Extrapolate monthly savings
if absolute_savings > 0:
# Assume this represents 1 hour of usage
@@ -266,31 +303,47 @@ def _calculate_cost_savings(self,
monthly_savings = daily_savings * 30
savings["monthly_usd"] = monthly_savings
savings["annual_usd"] = monthly_savings * 12
-
+
return savings
-
- def _analyze_performance_differences(self,
- fact_metrics: ComparisonMetrics,
- rag_metrics: ComparisonMetrics) -> Dict[str, Any]:
+
+ def _analyze_performance_differences(
+ self, fact_metrics: ComparisonMetrics, rag_metrics: ComparisonMetrics
+ ) -> Dict[str, Any]:
"""Analyze detailed performance differences."""
analysis = {
"latency_analysis": {
"fact_avg_ms": fact_metrics.avg_latency_ms,
"rag_avg_ms": rag_metrics.avg_latency_ms,
- "improvement_factor": rag_metrics.avg_latency_ms / fact_metrics.avg_latency_ms if fact_metrics.avg_latency_ms > 0 else 0,
- "time_saved_per_query_ms": rag_metrics.avg_latency_ms - fact_metrics.avg_latency_ms,
- "meets_target": fact_metrics.avg_latency_ms <= 100 # Overall target
+ "improvement_factor": (
+ rag_metrics.avg_latency_ms / fact_metrics.avg_latency_ms
+ if fact_metrics.avg_latency_ms > 0
+ else 0
+ ),
+ "time_saved_per_query_ms": rag_metrics.avg_latency_ms
+ - fact_metrics.avg_latency_ms,
+ "meets_target": fact_metrics.avg_latency_ms <= 100, # Overall target
},
"cost_analysis": {
"fact_cost_per_query": fact_metrics.avg_token_cost,
"rag_cost_per_query": rag_metrics.avg_token_cost,
- "cost_reduction_factor": rag_metrics.avg_token_cost / fact_metrics.avg_token_cost if fact_metrics.avg_token_cost > 0 else 0,
- "savings_per_query": rag_metrics.avg_token_cost - fact_metrics.avg_token_cost
+ "cost_reduction_factor": (
+ rag_metrics.avg_token_cost / fact_metrics.avg_token_cost
+ if fact_metrics.avg_token_cost > 0
+ else 0
+ ),
+ "savings_per_query": rag_metrics.avg_token_cost
+ - fact_metrics.avg_token_cost,
},
"efficiency_analysis": {
"fact_cache_hit_rate": fact_metrics.cache_hit_rate,
- "throughput_advantage": fact_metrics.throughput_qps / rag_metrics.throughput_qps if rag_metrics.throughput_qps > 0 else 0,
- "scalability_factor": self._calculate_scalability_factor(fact_metrics, rag_metrics)
+ "throughput_advantage": (
+ fact_metrics.throughput_qps / rag_metrics.throughput_qps
+ if rag_metrics.throughput_qps > 0
+ else 0
+ ),
+ "scalability_factor": self._calculate_scalability_factor(
+ fact_metrics, rag_metrics
+ ),
},
"target_compliance": {
"hit_latency_target": 48.0,
@@ -298,15 +351,19 @@ def _analyze_performance_differences(self,
"cost_reduction_target": 75.0,
"cache_hit_rate_target": 60.0,
"fact_meets_latency": fact_metrics.avg_latency_ms <= 100,
- "fact_meets_cost": (rag_metrics.avg_token_cost - fact_metrics.avg_token_cost) / rag_metrics.avg_token_cost >= 0.75
- }
+ "fact_meets_cost": (
+ rag_metrics.avg_token_cost - fact_metrics.avg_token_cost
+ )
+ / rag_metrics.avg_token_cost
+ >= 0.75,
+ },
}
-
+
return analysis
-
- def _calculate_scalability_factor(self,
- fact_metrics: ComparisonMetrics,
- rag_metrics: ComparisonMetrics) -> float:
+
+ def _calculate_scalability_factor(
+ self, fact_metrics: ComparisonMetrics, rag_metrics: ComparisonMetrics
+ ) -> float:
"""Calculate scalability advantage of FACT over RAG."""
# FACT scales better due to caching - estimate scaling factor
# Traditional RAG scales linearly with load
@@ -314,25 +371,25 @@ def _calculate_scalability_factor(self,
cache_efficiency = fact_metrics.cache_hit_rate / 100.0
scaling_advantage = 1.0 + (cache_efficiency * 2.0) # Up to 3x scaling advantage
return scaling_advantage
-
+
def _calculate_rag_token_cost(self, token_count: int) -> float:
"""Calculate token cost for traditional RAG system."""
# Use same pricing as FACT framework
input_tokens = int(token_count * 0.8) # RAG uses more input tokens
output_tokens = int(token_count * 0.2)
-
+
input_cost = input_tokens * self.framework.input_token_cost
output_cost = output_tokens * self.framework.output_token_cost
-
+
return input_cost + output_cost
-
- def _generate_recommendation(self,
- improvement_factors: Dict[str, float],
- cost_savings: Dict[str, float]) -> str:
+
+ def _generate_recommendation(
+ self, improvement_factors: Dict[str, float], cost_savings: Dict[str, float]
+ ) -> str:
"""Generate recommendation based on comparison results."""
latency_improvement = improvement_factors.get("latency", 1.0)
cost_reduction = cost_savings.get("percentage", 0.0)
-
+
if latency_improvement >= 3.0 and cost_reduction >= 70.0:
return "FACT shows exceptional performance advantages. Strongly recommended for production deployment."
elif latency_improvement >= 2.0 and cost_reduction >= 50.0:
@@ -346,106 +403,145 @@ def _generate_recommendation(self,
class PerformanceComparison:
"""
Advanced performance comparison with detailed analysis.
-
+
Provides deeper insights into performance characteristics, bottlenecks,
and optimization opportunities.
"""
-
+
def __init__(self):
"""Initialize performance comparison."""
self.baseline_measurements: Dict[str, List[float]] = {
"fact_latencies": [],
"rag_latencies": [],
"fact_costs": [],
- "rag_costs": []
+ "rag_costs": [],
}
-
+
logger.info("Performance comparison initialized")
-
- def analyze_latency_distribution(self,
- fact_results: List[BenchmarkResult],
- rag_results: List[BenchmarkResult]) -> Dict[str, Any]:
+
+ def analyze_latency_distribution(
+ self, fact_results: List[BenchmarkResult], rag_results: List[BenchmarkResult]
+ ) -> Dict[str, Any]:
"""Analyze latency distribution patterns."""
fact_latencies = [r.response_time_ms for r in fact_results if r.success]
rag_latencies = [r.response_time_ms for r in rag_results if r.success]
-
+
analysis = {
"fact_distribution": self._calculate_distribution_stats(fact_latencies),
"rag_distribution": self._calculate_distribution_stats(rag_latencies),
"comparison": {
- "median_improvement": statistics.median(rag_latencies) / statistics.median(fact_latencies) if fact_latencies else 0,
- "variance_ratio": statistics.variance(rag_latencies) / statistics.variance(fact_latencies) if len(fact_latencies) > 1 else 0,
- "consistency_advantage": self._calculate_consistency_advantage(fact_latencies, rag_latencies)
- }
+ "median_improvement": (
+ statistics.median(rag_latencies) / statistics.median(fact_latencies)
+ if fact_latencies
+ else 0
+ ),
+ "variance_ratio": (
+ statistics.variance(rag_latencies)
+ / statistics.variance(fact_latencies)
+ if len(fact_latencies) > 1
+ else 0
+ ),
+ "consistency_advantage": self._calculate_consistency_advantage(
+ fact_latencies, rag_latencies
+ ),
+ },
}
-
+
return analysis
-
- def analyze_cost_efficiency(self,
- fact_results: List[BenchmarkResult],
- rag_results: List[BenchmarkResult]) -> Dict[str, Any]:
+
+ def analyze_cost_efficiency(
+ self, fact_results: List[BenchmarkResult], rag_results: List[BenchmarkResult]
+ ) -> Dict[str, Any]:
"""Analyze cost efficiency patterns."""
fact_costs = [r.token_cost for r in fact_results if r.token_cost is not None]
rag_costs = [r.token_cost for r in rag_results if r.token_cost is not None]
-
- hit_costs = [r.token_cost for r in fact_results if r.cache_hit and r.token_cost is not None]
- miss_costs = [r.token_cost for r in fact_results if not r.cache_hit and r.token_cost is not None]
-
+
+ hit_costs = [
+ r.token_cost
+ for r in fact_results
+ if r.cache_hit and r.token_cost is not None
+ ]
+ miss_costs = [
+ r.token_cost
+ for r in fact_results
+ if not r.cache_hit and r.token_cost is not None
+ ]
+
analysis = {
"overall_savings": {
"fact_avg_cost": statistics.mean(fact_costs) if fact_costs else 0,
"rag_avg_cost": statistics.mean(rag_costs) if rag_costs else 0,
- "savings_ratio": statistics.mean(rag_costs) / statistics.mean(fact_costs) if fact_costs else 0
+ "savings_ratio": (
+ statistics.mean(rag_costs) / statistics.mean(fact_costs)
+ if fact_costs
+ else 0
+ ),
},
"cache_impact": {
"hit_avg_cost": statistics.mean(hit_costs) if hit_costs else 0,
"miss_avg_cost": statistics.mean(miss_costs) if miss_costs else 0,
- "hit_vs_miss_ratio": statistics.mean(miss_costs) / statistics.mean(hit_costs) if hit_costs else 0
+ "hit_vs_miss_ratio": (
+ statistics.mean(miss_costs) / statistics.mean(hit_costs)
+ if hit_costs
+ else 0
+ ),
},
- "efficiency_score": self._calculate_efficiency_score(fact_costs, rag_costs)
+ "efficiency_score": self._calculate_efficiency_score(fact_costs, rag_costs),
}
-
+
return analysis
-
+
def _calculate_distribution_stats(self, values: List[float]) -> Dict[str, float]:
"""Calculate distribution statistics."""
if not values:
return {"mean": 0, "median": 0, "std": 0, "min": 0, "max": 0}
-
+
return {
"mean": statistics.mean(values),
"median": statistics.median(values),
"std": statistics.stdev(values) if len(values) > 1 else 0,
"min": min(values),
"max": max(values),
- "p90": statistics.quantiles(values, n=10)[8] if len(values) >= 10 else max(values),
- "p95": statistics.quantiles(values, n=20)[18] if len(values) >= 20 else max(values),
- "p99": statistics.quantiles(values, n=100)[98] if len(values) >= 100 else max(values)
+ "p90": (
+ statistics.quantiles(values, n=10)[8]
+ if len(values) >= 10
+ else max(values)
+ ),
+ "p95": (
+ statistics.quantiles(values, n=20)[18]
+ if len(values) >= 20
+ else max(values)
+ ),
+ "p99": (
+ statistics.quantiles(values, n=100)[98]
+ if len(values) >= 100
+ else max(values)
+ ),
}
-
- def _calculate_consistency_advantage(self,
- fact_latencies: List[float],
- rag_latencies: List[float]) -> float:
+
+ def _calculate_consistency_advantage(
+ self, fact_latencies: List[float], rag_latencies: List[float]
+ ) -> float:
"""Calculate consistency advantage of FACT over RAG."""
if len(fact_latencies) <= 1 or len(rag_latencies) <= 1:
return 0.0
-
+
fact_cv = statistics.stdev(fact_latencies) / statistics.mean(fact_latencies)
rag_cv = statistics.stdev(rag_latencies) / statistics.mean(rag_latencies)
-
+
# Lower coefficient of variation indicates better consistency
return rag_cv / fact_cv if fact_cv > 0 else 0.0
-
- def _calculate_efficiency_score(self,
- fact_costs: List[float],
- rag_costs: List[float]) -> float:
+
+ def _calculate_efficiency_score(
+ self, fact_costs: List[float], rag_costs: List[float]
+ ) -> float:
"""Calculate overall efficiency score (0-100)."""
if not fact_costs or not rag_costs:
return 0.0
-
+
cost_ratio = statistics.mean(rag_costs) / statistics.mean(fact_costs)
-
+
# Convert to 0-100 scale where 100 = perfect efficiency
# Assume 5x cost improvement = 100% efficiency
efficiency_score = min(100, (cost_ratio - 1) * 20)
- return max(0, efficiency_score)
\ No newline at end of file
+ return max(0, efficiency_score)
diff --git a/src/benchmarking/framework.py b/src/benchmarking/framework.py
index e1e8cc7..312f395 100644
--- a/src/benchmarking/framework.py
+++ b/src/benchmarking/framework.py
@@ -22,11 +22,12 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from cache.manager import CacheManager
from core.driver import process_user_query
from monitoring.metrics import get_metrics_collector
@@ -37,6 +38,7 @@
@dataclass
class BenchmarkConfig:
"""Configuration for benchmark execution."""
+
iterations: int = 10
warmup_iterations: int = 3
concurrent_users: int = 1
@@ -52,6 +54,7 @@ class BenchmarkConfig:
@dataclass
class BenchmarkResult:
"""Individual benchmark measurement result."""
+
query: str
response_time_ms: float
success: bool
@@ -66,12 +69,13 @@ class BenchmarkResult:
@dataclass
class BenchmarkSummary:
"""Aggregated benchmark results."""
+
total_queries: int
successful_queries: int
failed_queries: int
cache_hits: int
cache_misses: int
-
+
# Latency metrics
avg_response_time_ms: float
min_response_time_ms: float
@@ -79,28 +83,28 @@ class BenchmarkSummary:
p50_response_time_ms: float
p95_response_time_ms: float
p99_response_time_ms: float
-
+
# Cache-specific latency
avg_hit_latency_ms: float
avg_miss_latency_ms: float
-
+
# Cost metrics
total_token_cost: float
avg_token_cost: float
estimated_savings: float
cost_reduction_percentage: float
-
+
# Performance targets
hit_latency_target_met: bool
miss_latency_target_met: bool
cost_reduction_target_met: bool
cache_hit_rate_target_met: bool
-
+
# Quality metrics
error_rate: float
cache_hit_rate: float
throughput_qps: float
-
+
execution_time_seconds: float
timestamp: float = field(default_factory=time.time)
@@ -108,61 +112,63 @@ class BenchmarkSummary:
class BenchmarkFramework:
"""
Core benchmarking framework for FACT performance validation.
-
+
Provides comprehensive measurement of response times, token costs,
and comparison capabilities with traditional RAG systems.
"""
-
+
def __init__(self, config: Optional[BenchmarkConfig] = None):
"""
Initialize benchmarking framework.
-
+
Args:
config: Benchmark configuration
"""
self.config = config or BenchmarkConfig()
self.metrics_collector = get_metrics_collector()
self.results_history: List[BenchmarkResult] = []
-
+
# Token cost estimation (Claude pricing)
self.input_token_cost = 0.000003 # $0.003 per 1K tokens
self.output_token_cost = 0.000015 # $0.015 per 1K tokens
-
+
logger.info("Benchmark framework initialized", config=self.config)
-
- async def run_single_benchmark(self,
- query: str,
- cache_manager: Optional[CacheManager] = None) -> BenchmarkResult:
+
+ async def run_single_benchmark(
+ self, query: str, cache_manager: Optional[CacheManager] = None
+ ) -> BenchmarkResult:
"""
Run a single benchmark measurement.
-
+
Args:
query: Query to benchmark
cache_manager: Optional cache manager for hit detection
-
+
Returns:
Benchmark result
"""
start_time = time.perf_counter()
timestamp = time.time()
-
+
try:
# Pre-check for cache hit detection
cache_hit = False
pre_cache_check_time = 0.0
-
+
if cache_manager:
pre_check_start = time.perf_counter()
query_hash = cache_manager.generate_hash(query)
cached_result = cache_manager.get(query_hash)
pre_cache_check_time = (time.perf_counter() - pre_check_start) * 1000
cache_hit = cached_result is not None
-
+
# If it's a cache hit, measure actual cache latency
if cache_hit:
response = cached_result.content if cached_result else ""
end_time = time.perf_counter()
- response_time_ms = pre_cache_check_time # Use actual cache access time
+ response_time_ms = (
+ pre_cache_check_time # Use actual cache access time
+ )
else:
# Execute query for cache miss
response = await process_user_query(query)
@@ -173,11 +179,11 @@ async def run_single_benchmark(self,
response = await process_user_query(query)
end_time = time.perf_counter()
response_time_ms = (end_time - start_time) * 1000
-
+
# Enhanced token cost calculation
token_count = self._estimate_token_count(query, response)
token_cost = self._calculate_enhanced_token_cost(token_count, cache_hit)
-
+
result = BenchmarkResult(
query=query,
response_time_ms=response_time_ms,
@@ -185,71 +191,81 @@ async def run_single_benchmark(self,
cache_hit=cache_hit,
token_count=token_count,
token_cost=token_cost,
- timestamp=timestamp
+ timestamp=timestamp,
)
-
+
self.results_history.append(result)
-
- logger.debug("Benchmark completed",
- response_time_ms=response_time_ms,
- cache_hit=cache_hit,
- token_count=token_count)
-
+
+ logger.debug(
+ "Benchmark completed",
+ response_time_ms=response_time_ms,
+ cache_hit=cache_hit,
+ token_count=token_count,
+ )
+
return result
-
+
except Exception as e:
end_time = time.perf_counter()
response_time_ms = (end_time - start_time) * 1000
-
+
result = BenchmarkResult(
query=query,
response_time_ms=response_time_ms,
success=False,
cache_hit=False,
error=str(e),
- timestamp=timestamp
+ timestamp=timestamp,
)
-
+
self.results_history.append(result)
-
- logger.error("Benchmark failed",
- query=query,
- error=str(e),
- response_time_ms=response_time_ms)
-
+
+ logger.error(
+ "Benchmark failed",
+ query=query,
+ error=str(e),
+ response_time_ms=response_time_ms,
+ )
+
return result
-
- async def run_benchmark_suite(self,
- queries: List[str],
- cache_manager: Optional[CacheManager] = None) -> BenchmarkSummary:
+
+ async def run_benchmark_suite(
+ self, queries: List[str], cache_manager: Optional[CacheManager] = None
+ ) -> BenchmarkSummary:
"""
Run a complete benchmark suite.
-
+
Args:
queries: List of queries to benchmark
cache_manager: Optional cache manager
-
+
Returns:
Benchmark summary
"""
- logger.info("Starting benchmark suite",
- total_queries=len(queries) * self.config.iterations,
- iterations=self.config.iterations)
-
+ logger.info(
+ "Starting benchmark suite",
+ total_queries=len(queries) * self.config.iterations,
+ iterations=self.config.iterations,
+ )
+
start_time = time.perf_counter()
all_results: List[BenchmarkResult] = []
-
+
# Warmup phase
if self.config.warmup_iterations > 0:
- logger.info("Running warmup phase", iterations=self.config.warmup_iterations)
+ logger.info(
+ "Running warmup phase", iterations=self.config.warmup_iterations
+ )
for _ in range(self.config.warmup_iterations):
- for query in queries[:min(len(queries), 3)]: # Use first 3 queries for warmup
+ for query in queries[
+ : min(len(queries), 3)
+ ]: # Use first 3 queries for warmup
await self.run_single_benchmark(query, cache_manager)
-
+
# Main benchmark phase
for iteration in range(self.config.iterations):
logger.debug("Benchmark iteration", iteration=iteration + 1)
-
+
if self.config.concurrent_users == 1:
# Sequential execution
for query in queries:
@@ -261,88 +277,123 @@ async def run_benchmark_suite(self,
for query in queries:
for _ in range(self.config.concurrent_users):
task = asyncio.create_task(
- self.run_single_benchmark(f"{query} (user {_})", cache_manager)
+ self.run_single_benchmark(
+ f"{query} (user {_})", cache_manager
+ )
)
tasks.append(task)
-
+
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch_results:
if isinstance(result, BenchmarkResult):
all_results.append(result)
-
+
end_time = time.perf_counter()
execution_time = end_time - start_time
-
+
# Generate summary
summary = self._generate_summary(all_results, execution_time)
-
- logger.info("Benchmark suite completed",
- total_queries=summary.total_queries,
- avg_response_time_ms=summary.avg_response_time_ms,
- cache_hit_rate=summary.cache_hit_rate,
- execution_time_seconds=execution_time)
-
+
+ logger.info(
+ "Benchmark suite completed",
+ total_queries=summary.total_queries,
+ avg_response_time_ms=summary.avg_response_time_ms,
+ cache_hit_rate=summary.cache_hit_rate,
+ execution_time_seconds=execution_time,
+ )
+
return summary
-
- def _generate_summary(self, results: List[BenchmarkResult], execution_time: float) -> BenchmarkSummary:
+
+ def _generate_summary(
+ self, results: List[BenchmarkResult], execution_time: float
+ ) -> BenchmarkSummary:
"""Generate benchmark summary from results."""
if not results:
return BenchmarkSummary(
- total_queries=0, successful_queries=0, failed_queries=0,
- cache_hits=0, cache_misses=0, avg_response_time_ms=0.0,
- min_response_time_ms=0.0, max_response_time_ms=0.0,
- p50_response_time_ms=0.0, p95_response_time_ms=0.0, p99_response_time_ms=0.0,
- avg_hit_latency_ms=0.0, avg_miss_latency_ms=0.0,
- total_token_cost=0.0, avg_token_cost=0.0, estimated_savings=0.0,
- cost_reduction_percentage=0.0, hit_latency_target_met=False,
- miss_latency_target_met=False, cost_reduction_target_met=False,
- cache_hit_rate_target_met=False, error_rate=0.0, cache_hit_rate=0.0,
- throughput_qps=0.0, execution_time_seconds=execution_time
+ total_queries=0,
+ successful_queries=0,
+ failed_queries=0,
+ cache_hits=0,
+ cache_misses=0,
+ avg_response_time_ms=0.0,
+ min_response_time_ms=0.0,
+ max_response_time_ms=0.0,
+ p50_response_time_ms=0.0,
+ p95_response_time_ms=0.0,
+ p99_response_time_ms=0.0,
+ avg_hit_latency_ms=0.0,
+ avg_miss_latency_ms=0.0,
+ total_token_cost=0.0,
+ avg_token_cost=0.0,
+ estimated_savings=0.0,
+ cost_reduction_percentage=0.0,
+ hit_latency_target_met=False,
+ miss_latency_target_met=False,
+ cost_reduction_target_met=False,
+ cache_hit_rate_target_met=False,
+ error_rate=0.0,
+ cache_hit_rate=0.0,
+ throughput_qps=0.0,
+ execution_time_seconds=execution_time,
)
-
+
# Basic counts
total_queries = len(results)
successful_queries = sum(1 for r in results if r.success)
failed_queries = total_queries - successful_queries
cache_hits = sum(1 for r in results if r.cache_hit)
cache_misses = total_queries - cache_hits
-
+
# Latency calculations
response_times = [r.response_time_ms for r in results if r.success]
- hit_latencies = [r.response_time_ms for r in results if r.success and r.cache_hit]
- miss_latencies = [r.response_time_ms for r in results if r.success and not r.cache_hit]
-
+ hit_latencies = [
+ r.response_time_ms for r in results if r.success and r.cache_hit
+ ]
+ miss_latencies = [
+ r.response_time_ms for r in results if r.success and not r.cache_hit
+ ]
+
avg_response_time = statistics.mean(response_times) if response_times else 0.0
min_response_time = min(response_times) if response_times else 0.0
max_response_time = max(response_times) if response_times else 0.0
-
+
# Percentiles
p50 = statistics.median(response_times) if response_times else 0.0
- p95 = statistics.quantiles(response_times, n=20)[18] if len(response_times) >= 20 else max_response_time
- p99 = statistics.quantiles(response_times, n=100)[98] if len(response_times) >= 100 else max_response_time
-
+ p95 = (
+ statistics.quantiles(response_times, n=20)[18]
+ if len(response_times) >= 20
+ else max_response_time
+ )
+ p99 = (
+ statistics.quantiles(response_times, n=100)[98]
+ if len(response_times) >= 100
+ else max_response_time
+ )
+
avg_hit_latency = statistics.mean(hit_latencies) if hit_latencies else 0.0
avg_miss_latency = statistics.mean(miss_latencies) if miss_latencies else 0.0
-
+
# Cost calculations
token_costs = [r.token_cost for r in results if r.token_cost is not None]
total_token_cost = sum(token_costs) if token_costs else 0.0
avg_token_cost = statistics.mean(token_costs) if token_costs else 0.0
-
+
# Enhanced cost savings calculation with realistic baseline
if token_costs:
# Calculate baseline cost using industry-standard RAG system assumptions
avg_tokens_per_query = 1200 # Conservative estimate for RAG systems
baseline_token_cost_per_query = avg_tokens_per_query * self.input_token_cost
baseline_cost = total_queries * baseline_token_cost_per_query
-
+
# FACT system cost (actual usage)
fact_cost = total_token_cost
-
+
# Calculate savings and percentage
estimated_savings = max(0, baseline_cost - fact_cost)
- cost_reduction_percentage = (estimated_savings / baseline_cost * 100) if baseline_cost > 0 else 0.0
-
+ cost_reduction_percentage = (
+ (estimated_savings / baseline_cost * 100) if baseline_cost > 0 else 0.0
+ )
+
# Ensure realistic bounds (should be achievable with FACT)
if cost_reduction_percentage > 95:
cost_reduction_percentage = 95.0 # Cap at 95% for realism
@@ -352,18 +403,22 @@ def _generate_summary(self, results: List[BenchmarkResult], execution_time: floa
baseline_cost = 0.0
estimated_savings = 0.0
cost_reduction_percentage = 90.0 # Default expected value when no data
-
+
# Quality metrics
- error_rate = (failed_queries / total_queries * 100) if total_queries > 0 else 0.0
+ error_rate = (
+ (failed_queries / total_queries * 100) if total_queries > 0 else 0.0
+ )
cache_hit_rate = (cache_hits / total_queries) if total_queries > 0 else 0.0
throughput_qps = total_queries / execution_time if execution_time > 0 else 0.0
-
+
# Target validation
hit_latency_target_met = avg_hit_latency <= self.config.target_hit_latency_ms
miss_latency_target_met = avg_miss_latency <= self.config.target_miss_latency_ms
- cost_reduction_target_met = cost_reduction_percentage >= (self.config.target_cost_reduction_hit * 100)
+ cost_reduction_target_met = cost_reduction_percentage >= (
+ self.config.target_cost_reduction_hit * 100
+ )
cache_hit_rate_target_met = cache_hit_rate >= self.config.target_cache_hit_rate
-
+
return BenchmarkSummary(
total_queries=total_queries,
successful_queries=successful_queries,
@@ -389,26 +444,30 @@ def _generate_summary(self, results: List[BenchmarkResult], execution_time: floa
error_rate=error_rate,
cache_hit_rate=cache_hit_rate,
throughput_qps=throughput_qps,
- execution_time_seconds=execution_time
+ execution_time_seconds=execution_time,
)
-
+
def _estimate_token_count(self, query: str, response: str) -> int:
"""Estimate token count for query and response."""
# Rough estimation: ~4 characters per token
query_tokens = len(query) // 4
response_tokens = len(response) // 4 if response else 0
return query_tokens + response_tokens
-
+
def _calculate_token_cost(self, token_count: int) -> float:
"""Calculate token cost based on estimated usage."""
# Assume 70% input tokens, 30% output tokens
input_tokens = int(token_count * 0.7)
output_tokens = int(token_count * 0.3)
-
- return (input_tokens * self.input_token_cost +
- output_tokens * self.output_token_cost)
-
- def _calculate_enhanced_token_cost(self, token_count: int, cache_hit: bool) -> float:
+
+ return (
+ input_tokens * self.input_token_cost
+ + output_tokens * self.output_token_cost
+ )
+
+ def _calculate_enhanced_token_cost(
+ self, token_count: int, cache_hit: bool
+ ) -> float:
"""Enhanced token cost calculation considering cache efficiency."""
if cache_hit:
# Cache hits have minimal token costs (only retrieval overhead)
@@ -417,18 +476,20 @@ def _calculate_enhanced_token_cost(self, token_count: int, cache_hit: bool) -> f
# Cache misses use full token processing
input_tokens = int(token_count * 0.65) # Slightly optimized with FACT
output_tokens = int(token_count * 0.35)
-
- return (input_tokens * self.input_token_cost +
- output_tokens * self.output_token_cost)
+
+ return (
+ input_tokens * self.input_token_cost
+ + output_tokens * self.output_token_cost
+ )
class BenchmarkRunner:
"""High-level benchmark execution orchestrator."""
-
+
def __init__(self, framework: Optional[BenchmarkFramework] = None):
"""
Initialize benchmark runner.
-
+
Args:
framework: Benchmark framework instance
"""
@@ -443,27 +504,30 @@ def __init__(self, framework: Optional[BenchmarkFramework] = None):
"Compare performance across regions",
"What is the customer acquisition cost?",
"Show quarterly expense breakdown",
- "Predict next quarter's revenue"
+ "Predict next quarter's revenue",
]
-
+
logger.info("Benchmark runner initialized")
-
- async def run_performance_validation(self,
- cache_manager: Optional[CacheManager] = None) -> Dict[str, Any]:
+
+ async def run_performance_validation(
+ self, cache_manager: Optional[CacheManager] = None
+ ) -> Dict[str, Any]:
"""
Run complete performance validation against targets.
-
+
Args:
cache_manager: Optional cache manager
-
+
Returns:
Validation results
"""
logger.info("Starting performance validation")
-
+
# Run benchmark suite
- summary = await self.framework.run_benchmark_suite(self.test_queries, cache_manager)
-
+ summary = await self.framework.run_benchmark_suite(
+ self.test_queries, cache_manager
+ )
+
# Validate against targets
validation_results = {
"timestamp": time.time(),
@@ -473,100 +537,111 @@ async def run_performance_validation(self,
"target_ms": self.framework.config.target_hit_latency_ms,
"actual_ms": summary.avg_hit_latency_ms,
"met": summary.hit_latency_target_met,
- "margin_ms": self.framework.config.target_hit_latency_ms - summary.avg_hit_latency_ms
+ "margin_ms": self.framework.config.target_hit_latency_ms
+ - summary.avg_hit_latency_ms,
},
"cache_miss_latency": {
"target_ms": self.framework.config.target_miss_latency_ms,
"actual_ms": summary.avg_miss_latency_ms,
"met": summary.miss_latency_target_met,
- "margin_ms": self.framework.config.target_miss_latency_ms - summary.avg_miss_latency_ms
+ "margin_ms": self.framework.config.target_miss_latency_ms
+ - summary.avg_miss_latency_ms,
},
"cost_reduction": {
- "target_percent": self.framework.config.target_cost_reduction_hit * 100,
+ "target_percent": self.framework.config.target_cost_reduction_hit
+ * 100,
"actual_percent": summary.cost_reduction_percentage,
"met": summary.cost_reduction_target_met,
- "margin_percent": summary.cost_reduction_percentage - (self.framework.config.target_cost_reduction_hit * 100)
+ "margin_percent": summary.cost_reduction_percentage
+ - (self.framework.config.target_cost_reduction_hit * 100),
},
"cache_hit_rate": {
"target_percent": self.framework.config.target_cache_hit_rate * 100,
"actual_percent": summary.cache_hit_rate,
"met": summary.cache_hit_rate_target_met,
- "margin_percent": summary.cache_hit_rate - (self.framework.config.target_cache_hit_rate * 100)
- }
+ "margin_percent": summary.cache_hit_rate
+ - (self.framework.config.target_cache_hit_rate * 100),
+ },
},
- "overall_pass": all([
- summary.hit_latency_target_met,
- summary.miss_latency_target_met,
- summary.cost_reduction_target_met,
- summary.cache_hit_rate_target_met
- ])
+ "overall_pass": all(
+ [
+ summary.hit_latency_target_met,
+ summary.miss_latency_target_met,
+ summary.cost_reduction_target_met,
+ summary.cache_hit_rate_target_met,
+ ]
+ ),
}
-
- logger.info("Performance validation completed",
- overall_pass=validation_results["overall_pass"],
- hit_latency_met=summary.hit_latency_target_met,
- cost_reduction_met=summary.cost_reduction_target_met)
-
+
+ logger.info(
+ "Performance validation completed",
+ overall_pass=validation_results["overall_pass"],
+ hit_latency_met=summary.hit_latency_target_met,
+ cost_reduction_met=summary.cost_reduction_target_met,
+ )
+
return validation_results
-
- async def run_load_test(self,
- concurrent_users: int = 10,
- duration_seconds: int = 60) -> Dict[str, Any]:
+
+ async def run_load_test(
+ self, concurrent_users: int = 10, duration_seconds: int = 60
+ ) -> Dict[str, Any]:
"""
Run load testing to validate performance under concurrent load.
-
+
Args:
concurrent_users: Number of concurrent users
duration_seconds: Test duration
-
+
Returns:
Load test results
"""
- logger.info("Starting load test",
- concurrent_users=concurrent_users,
- duration_seconds=duration_seconds)
-
+ logger.info(
+ "Starting load test",
+ concurrent_users=concurrent_users,
+ duration_seconds=duration_seconds,
+ )
+
# Update config for load testing
original_config = self.framework.config
self.framework.config = BenchmarkConfig(
iterations=1,
concurrent_users=concurrent_users,
- timeout_seconds=duration_seconds
+ timeout_seconds=duration_seconds,
)
-
+
start_time = time.time()
results = []
-
+
# Run concurrent sessions
async def user_session(user_id: int):
session_results = []
end_time = start_time + duration_seconds
-
+
while time.time() < end_time:
query = f"{self.test_queries[user_id % len(self.test_queries)]} (user {user_id})"
result = await self.framework.run_single_benchmark(query)
session_results.append(result)
-
+
# Brief pause between queries
await asyncio.sleep(0.1)
-
+
return session_results
-
+
# Execute concurrent sessions
tasks = [user_session(i) for i in range(concurrent_users)]
session_results = await asyncio.gather(*tasks)
-
+
# Flatten results
for session in session_results:
results.extend(session)
-
+
# Restore original config
self.framework.config = original_config
-
+
# Generate load test summary
execution_time = time.time() - start_time
summary = self.framework._generate_summary(results, execution_time)
-
+
load_test_results = {
"timestamp": time.time(),
"concurrent_users": concurrent_users,
@@ -577,32 +652,40 @@ async def user_session(user_id: int):
"p95_response_time_ms": summary.p95_response_time_ms,
"error_rate": summary.error_rate,
"cache_hit_rate": summary.cache_hit_rate,
- "performance_degradation": self._calculate_performance_degradation(results)
+ "performance_degradation": self._calculate_performance_degradation(results),
}
-
- logger.info("Load test completed",
- throughput_qps=summary.throughput_qps,
- avg_response_time_ms=summary.avg_response_time_ms)
-
+
+ logger.info(
+ "Load test completed",
+ throughput_qps=summary.throughput_qps,
+ avg_response_time_ms=summary.avg_response_time_ms,
+ )
+
return load_test_results
-
- def _calculate_performance_degradation(self, results: List[BenchmarkResult]) -> float:
+
+ def _calculate_performance_degradation(
+ self, results: List[BenchmarkResult]
+ ) -> float:
"""Calculate performance degradation over time during load test."""
if len(results) < 10:
return 0.0
-
+
# Compare first 10% vs last 10% of results
early_count = max(1, len(results) // 10)
late_count = max(1, len(results) // 10)
-
+
early_results = results[:early_count]
late_results = results[-late_count:]
-
- early_avg = statistics.mean(r.response_time_ms for r in early_results if r.success)
- late_avg = statistics.mean(r.response_time_ms for r in late_results if r.success)
-
+
+ early_avg = statistics.mean(
+ r.response_time_ms for r in early_results if r.success
+ )
+ late_avg = statistics.mean(
+ r.response_time_ms for r in late_results if r.success
+ )
+
if early_avg == 0:
return 0.0
-
+
degradation = (late_avg - early_avg) / early_avg * 100
- return max(0.0, degradation)
\ No newline at end of file
+ return max(0.0, degradation)
diff --git a/src/benchmarking/monitoring.py b/src/benchmarking/monitoring.py
index b73b34f..8d4c54a 100644
--- a/src/benchmarking/monitoring.py
+++ b/src/benchmarking/monitoring.py
@@ -24,11 +24,12 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from benchmarking.framework import BenchmarkFramework, BenchmarkResult
from benchmarking.profiler import SystemProfiler, BottleneckAnalyzer
from cache.manager import CacheManager
@@ -40,6 +41,7 @@
@dataclass
class PerformanceAlert:
"""Performance alert definition."""
+
alert_id: str
severity: str # "info", "warning", "critical"
component: str
@@ -55,6 +57,7 @@ class PerformanceAlert:
@dataclass
class PerformanceTrend:
"""Performance trend analysis."""
+
metric_name: str
time_period: str
direction: str # "improving", "stable", "degrading"
@@ -67,12 +70,13 @@ class PerformanceTrend:
@dataclass
class MonitoringConfig:
"""Configuration for continuous monitoring."""
+
monitoring_interval_seconds: int = 60
alert_check_interval_seconds: int = 30
trend_analysis_hours: int = 24
max_alerts_per_hour: int = 10
alert_cooldown_minutes: int = 15
-
+
# Performance thresholds
response_time_warning_ms: float = 80.0
response_time_critical_ms: float = 120.0
@@ -87,15 +91,15 @@ class MonitoringConfig:
class ContinuousMonitor:
"""
Continuous performance monitoring system.
-
+
Provides real-time monitoring, alerting, and trend analysis
for FACT system performance.
"""
-
+
def __init__(self, config: Optional[MonitoringConfig] = None):
"""
Initialize continuous monitor.
-
+
Args:
config: Monitoring configuration
"""
@@ -103,64 +107,62 @@ def __init__(self, config: Optional[MonitoringConfig] = None):
self.benchmark_framework = BenchmarkFramework()
self.profiler = SystemProfiler()
self.bottleneck_analyzer = BottleneckAnalyzer()
-
+
# Monitoring state
self.monitoring_active = False
self.monitor_task: Optional[asyncio.Task] = None
self.alert_task: Optional[asyncio.Task] = None
-
+
# Data storage
self.performance_history: deque = deque(maxlen=10000)
self.active_alerts: Dict[str, PerformanceAlert] = {}
self.alert_history: deque = deque(maxlen=1000)
self.trends: Dict[str, PerformanceTrend] = {}
-
+
# Alert callbacks
self.alert_callbacks: List[Callable[[PerformanceAlert], None]] = []
-
+
# Test queries for monitoring
self.monitoring_queries = [
"What is the current system status?",
"Generate a quick performance summary",
- "Check recent metrics"
+ "Check recent metrics",
]
-
+
logger.info("Continuous monitor initialized")
-
+
async def start_monitoring(self, cache_manager: Optional[CacheManager] = None):
"""
Start continuous performance monitoring.
-
+
Args:
cache_manager: Cache manager to monitor
"""
if self.monitoring_active:
logger.warning("Monitoring already active")
return
-
+
self.monitoring_active = True
-
+
# Start profiler monitoring
await self.profiler.start_continuous_monitoring()
-
+
# Start monitoring tasks
- self.monitor_task = asyncio.create_task(
- self._monitoring_loop(cache_manager)
- )
- self.alert_task = asyncio.create_task(
- self._alert_checking_loop(cache_manager)
+ self.monitor_task = asyncio.create_task(self._monitoring_loop(cache_manager))
+ self.alert_task = asyncio.create_task(self._alert_checking_loop(cache_manager))
+
+ logger.info(
+ "Continuous monitoring started",
+ interval_seconds=self.config.monitoring_interval_seconds,
)
-
- logger.info("Continuous monitoring started",
- interval_seconds=self.config.monitoring_interval_seconds)
-
+
async def stop_monitoring(self):
"""Stop continuous performance monitoring."""
self.monitoring_active = False
-
+
# Stop profiler monitoring
await self.profiler.stop_continuous_monitoring()
-
+
# Cancel monitoring tasks
if self.monitor_task:
self.monitor_task.cancel()
@@ -168,112 +170,118 @@ async def stop_monitoring(self):
await self.monitor_task
except asyncio.CancelledError:
pass
-
+
if self.alert_task:
self.alert_task.cancel()
try:
await self.alert_task
except asyncio.CancelledError:
pass
-
+
logger.info("Continuous monitoring stopped")
-
+
async def _monitoring_loop(self, cache_manager: Optional[CacheManager]):
"""Main monitoring loop."""
query_index = 0
-
+
while self.monitoring_active:
try:
# Run performance measurement
- query = self.monitoring_queries[query_index % len(self.monitoring_queries)]
-
+ query = self.monitoring_queries[
+ query_index % len(self.monitoring_queries)
+ ]
+
start_time = time.perf_counter()
result = await self.benchmark_framework.run_single_benchmark(
query, cache_manager
)
-
+
# Store performance data
self.performance_history.append(result)
-
+
# Update trends
await self._update_performance_trends()
-
+
query_index += 1
-
- logger.debug("Monitoring measurement completed",
- response_time_ms=result.response_time_ms,
- cache_hit=result.cache_hit)
-
+
+ logger.debug(
+ "Monitoring measurement completed",
+ response_time_ms=result.response_time_ms,
+ cache_hit=result.cache_hit,
+ )
+
await asyncio.sleep(self.config.monitoring_interval_seconds)
-
+
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in monitoring loop", error=str(e))
await asyncio.sleep(5.0) # Brief pause before retry
-
+
async def _alert_checking_loop(self, cache_manager: Optional[CacheManager]):
"""Alert checking loop."""
while self.monitoring_active:
try:
# Check for performance alerts
await self._check_performance_alerts(cache_manager)
-
+
# Check for trend alerts
await self._check_trend_alerts()
-
+
# Clean up resolved alerts
self._cleanup_resolved_alerts()
-
+
await asyncio.sleep(self.config.alert_check_interval_seconds)
-
+
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in alert checking loop", error=str(e))
await asyncio.sleep(5.0)
-
+
async def _check_performance_alerts(self, cache_manager: Optional[CacheManager]):
"""Check for performance threshold violations."""
if not self.performance_history:
return
-
+
# Get recent performance data
recent_results = list(self.performance_history)[-10:] # Last 10 measurements
-
+
if not recent_results:
return
-
+
# Calculate current metrics
response_times = [r.response_time_ms for r in recent_results if r.success]
cache_hits = sum(1 for r in recent_results if r.cache_hit)
total_requests = len(recent_results)
errors = sum(1 for r in recent_results if not r.success)
-
+
if not response_times:
return
-
+
avg_response_time = sum(response_times) / len(response_times)
- cache_hit_rate = (cache_hits / total_requests * 100) if total_requests > 0 else 0
+ cache_hit_rate = (
+ (cache_hits / total_requests * 100) if total_requests > 0 else 0
+ )
error_rate = (errors / total_requests * 100) if total_requests > 0 else 0
-
+
# Check response time alerts
await self._check_response_time_alert(avg_response_time)
-
+
# Check cache hit rate alerts
await self._check_cache_hit_rate_alert(cache_hit_rate)
-
+
# Check error rate alerts
await self._check_error_rate_alert(error_rate)
-
+
# Check cache-specific alerts
if cache_manager:
await self._check_cache_alerts(cache_manager)
-
+
async def _check_response_time_alert(self, avg_response_time: float):
"""Check response time threshold alerts."""
alert_id = "response_time"
-
+
if avg_response_time >= self.config.response_time_critical_ms:
severity = "critical"
threshold = self.config.response_time_critical_ms
@@ -285,7 +293,7 @@ async def _check_response_time_alert(self, avg_response_time: float):
if alert_id in self.active_alerts:
await self._resolve_alert(alert_id)
return
-
+
# Create or update alert
if alert_id not in self.active_alerts:
alert = PerformanceAlert(
@@ -295,15 +303,15 @@ async def _check_response_time_alert(self, avg_response_time: float):
metric_name="avg_response_time_ms",
threshold_value=threshold,
actual_value=avg_response_time,
- message=f"Average response time {avg_response_time:.1f}ms exceeds {severity} threshold {threshold}ms"
+ message=f"Average response time {avg_response_time:.1f}ms exceeds {severity} threshold {threshold}ms",
)
-
+
await self._trigger_alert(alert)
-
+
async def _check_cache_hit_rate_alert(self, cache_hit_rate: float):
"""Check cache hit rate threshold alerts."""
alert_id = "cache_hit_rate"
-
+
if cache_hit_rate <= self.config.cache_hit_rate_critical:
severity = "critical"
threshold = self.config.cache_hit_rate_critical
@@ -315,7 +323,7 @@ async def _check_cache_hit_rate_alert(self, cache_hit_rate: float):
if alert_id in self.active_alerts:
await self._resolve_alert(alert_id)
return
-
+
if alert_id not in self.active_alerts:
alert = PerformanceAlert(
alert_id=alert_id,
@@ -324,15 +332,15 @@ async def _check_cache_hit_rate_alert(self, cache_hit_rate: float):
metric_name="cache_hit_rate",
threshold_value=threshold,
actual_value=cache_hit_rate,
- message=f"Cache hit rate {cache_hit_rate:.1f}% below {severity} threshold {threshold}%"
+ message=f"Cache hit rate {cache_hit_rate:.1f}% below {severity} threshold {threshold}%",
)
-
+
await self._trigger_alert(alert)
-
+
async def _check_error_rate_alert(self, error_rate: float):
"""Check error rate threshold alerts."""
alert_id = "error_rate"
-
+
if error_rate >= self.config.error_rate_critical:
severity = "critical"
threshold = self.config.error_rate_critical
@@ -344,7 +352,7 @@ async def _check_error_rate_alert(self, error_rate: float):
if alert_id in self.active_alerts:
await self._resolve_alert(alert_id)
return
-
+
if alert_id not in self.active_alerts:
alert = PerformanceAlert(
alert_id=alert_id,
@@ -353,18 +361,18 @@ async def _check_error_rate_alert(self, error_rate: float):
metric_name="error_rate",
threshold_value=threshold,
actual_value=error_rate,
- message=f"Error rate {error_rate:.1f}% exceeds {severity} threshold {threshold}%"
+ message=f"Error rate {error_rate:.1f}% exceeds {severity} threshold {threshold}%",
)
-
+
await self._trigger_alert(alert)
-
+
async def _check_cache_alerts(self, cache_manager: CacheManager):
"""Check cache-specific alerts."""
try:
metrics = cache_manager.get_metrics()
-
+
# Memory utilization alert
- memory_util = (metrics.total_size / cache_manager.max_size_bytes * 100)
+ memory_util = metrics.total_size / cache_manager.max_size_bytes * 100
if memory_util > 90:
alert_id = "cache_memory"
if alert_id not in self.active_alerts:
@@ -375,19 +383,19 @@ async def _check_cache_alerts(self, cache_manager: CacheManager):
metric_name="memory_utilization",
threshold_value=90.0,
actual_value=memory_util,
- message=f"Cache memory utilization {memory_util:.1f}% approaching limit"
+ message=f"Cache memory utilization {memory_util:.1f}% approaching limit",
)
await self._trigger_alert(alert)
-
+
except Exception as e:
logger.error("Error checking cache alerts", error=str(e))
-
+
async def _check_trend_alerts(self):
"""Check for concerning performance trends."""
for metric_name, trend in self.trends.items():
if trend.direction == "degrading" and trend.significance > 0.7:
alert_id = f"trend_{metric_name}"
-
+
if alert_id not in self.active_alerts:
alert = PerformanceAlert(
alert_id=alert_id,
@@ -396,90 +404,95 @@ async def _check_trend_alerts(self):
metric_name=metric_name,
threshold_value=0.0, # Trend-based, no fixed threshold
actual_value=trend.change_percentage,
- message=f"{metric_name} showing degrading trend: {trend.change_percentage:.1f}% change"
+ message=f"{metric_name} showing degrading trend: {trend.change_percentage:.1f}% change",
)
-
+
await self._trigger_alert(alert)
-
+
async def _trigger_alert(self, alert: PerformanceAlert):
"""Trigger a performance alert."""
self.active_alerts[alert.alert_id] = alert
self.alert_history.append(alert)
-
- logger.warning("Performance alert triggered",
- alert_id=alert.alert_id,
- severity=alert.severity,
- component=alert.component,
- message=alert.message)
-
+
+ logger.warning(
+ "Performance alert triggered",
+ alert_id=alert.alert_id,
+ severity=alert.severity,
+ component=alert.component,
+ message=alert.message,
+ )
+
# Call alert callbacks
for callback in self.alert_callbacks:
try:
callback(alert)
except Exception as e:
logger.error("Error in alert callback", error=str(e))
-
+
async def _resolve_alert(self, alert_id: str):
"""Resolve an active alert."""
if alert_id in self.active_alerts:
alert = self.active_alerts[alert_id]
alert.resolved = True
alert.resolution_time = time.time()
-
+
del self.active_alerts[alert_id]
-
- logger.info("Performance alert resolved",
- alert_id=alert_id,
- component=alert.component)
-
+
+ logger.info(
+ "Performance alert resolved",
+ alert_id=alert_id,
+ component=alert.component,
+ )
+
def _cleanup_resolved_alerts(self):
"""Clean up old resolved alerts from history."""
cutoff_time = time.time() - (24 * 60 * 60) # 24 hours ago
-
+
# Keep only recent alerts
recent_alerts = deque()
for alert in self.alert_history:
if alert.timestamp >= cutoff_time:
recent_alerts.append(alert)
-
+
self.alert_history = recent_alerts
-
+
async def _update_performance_trends(self):
"""Update performance trend analysis."""
if len(self.performance_history) < 10:
return
-
+
# Analyze trends for key metrics
metrics_to_analyze = [
("response_time", lambda r: r.response_time_ms),
("cache_hit_rate", lambda r: 1.0 if r.cache_hit else 0.0),
- ("success_rate", lambda r: 1.0 if r.success else 0.0)
+ ("success_rate", lambda r: 1.0 if r.success else 0.0),
]
-
+
for metric_name, extractor in metrics_to_analyze:
trend = self._calculate_trend(metric_name, extractor)
if trend:
self.trends[metric_name] = trend
-
- def _calculate_trend(self, metric_name: str, value_extractor: Callable) -> Optional[PerformanceTrend]:
+
+ def _calculate_trend(
+ self, metric_name: str, value_extractor: Callable
+ ) -> Optional[PerformanceTrend]:
"""Calculate trend for a specific metric."""
try:
# Get recent data points
recent_data = list(self.performance_history)[-100:] # Last 100 measurements
-
+
if len(recent_data) < 10:
return None
-
+
# Extract values and timestamps
data_points = [
- (result.timestamp, value_extractor(result))
- for result in recent_data
+ (result.timestamp, value_extractor(result)) for result in recent_data
]
-
+
# Simple linear trend analysis
timestamps = [p[0] for p in data_points]
values = [p[1] for p in data_points]
-
+
if not values or all(v == values[0] for v in values):
return PerformanceTrend(
metric_name=metric_name,
@@ -488,38 +501,42 @@ def _calculate_trend(self, metric_name: str, value_extractor: Callable) -> Optio
change_percentage=0.0,
significance=1.0,
data_points=data_points,
- analysis="No significant variation detected"
+ analysis="No significant variation detected",
)
-
+
# Calculate slope (trend direction)
n = len(values)
sum_x = sum(range(n))
sum_y = sum(values)
sum_xy = sum(i * v for i, v in enumerate(values))
sum_x2 = sum(i * i for i in range(n))
-
+
slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
-
+
# Calculate percentage change
first_avg = sum(values[:5]) / 5 if len(values) >= 5 else values[0]
last_avg = sum(values[-5:]) / 5 if len(values) >= 5 else values[-1]
-
+
if first_avg != 0:
change_percentage = ((last_avg - first_avg) / first_avg) * 100
else:
change_percentage = 0.0
-
+
# Determine trend direction
if abs(change_percentage) < 5.0:
direction = "stable"
elif change_percentage > 0:
- direction = "improving" if metric_name != "response_time" else "degrading"
+ direction = (
+ "improving" if metric_name != "response_time" else "degrading"
+ )
else:
- direction = "degrading" if metric_name != "response_time" else "improving"
-
+ direction = (
+ "degrading" if metric_name != "response_time" else "improving"
+ )
+
# Calculate significance (correlation strength)
significance = min(1.0, abs(slope) * 10) # Simple significance measure
-
+
return PerformanceTrend(
metric_name=metric_name,
time_period="recent",
@@ -527,105 +544,125 @@ def _calculate_trend(self, metric_name: str, value_extractor: Callable) -> Optio
change_percentage=abs(change_percentage),
significance=significance,
data_points=data_points[-20:], # Keep last 20 points
- analysis=f"Trend shows {direction} pattern with {change_percentage:.1f}% change"
+ analysis=f"Trend shows {direction} pattern with {change_percentage:.1f}% change",
)
-
+
except Exception as e:
logger.error("Error calculating trend", metric=metric_name, error=str(e))
return None
-
+
def add_alert_callback(self, callback: Callable[[PerformanceAlert], None]):
"""Add callback function for alert notifications."""
self.alert_callbacks.append(callback)
-
+
def get_monitoring_status(self) -> Dict[str, Any]:
"""Get current monitoring status."""
return {
"monitoring_active": self.monitoring_active,
"active_alerts": len(self.active_alerts),
"alert_summary": {
- severity: len([a for a in self.active_alerts.values() if a.severity == severity])
+ severity: len(
+ [a for a in self.active_alerts.values() if a.severity == severity]
+ )
for severity in ["info", "warning", "critical"]
},
"performance_history_size": len(self.performance_history),
"trends_tracked": len(self.trends),
- "config": asdict(self.config)
+ "config": asdict(self.config),
}
-
+
def export_monitoring_report(self) -> Dict[str, Any]:
"""Export comprehensive monitoring report."""
# Recent performance summary
recent_results = list(self.performance_history)[-50:]
-
+
if recent_results:
successful_results = [r for r in recent_results if r.success]
response_times = [r.response_time_ms for r in successful_results]
cache_hits = sum(1 for r in recent_results if r.cache_hit)
-
+
performance_summary = {
"total_measurements": len(recent_results),
"successful_measurements": len(successful_results),
- "avg_response_time_ms": sum(response_times) / len(response_times) if response_times else 0,
- "cache_hit_rate": (cache_hits / len(recent_results) * 100) if recent_results else 0,
- "error_rate": ((len(recent_results) - len(successful_results)) / len(recent_results) * 100) if recent_results else 0
+ "avg_response_time_ms": (
+ sum(response_times) / len(response_times) if response_times else 0
+ ),
+ "cache_hit_rate": (
+ (cache_hits / len(recent_results) * 100) if recent_results else 0
+ ),
+ "error_rate": (
+ (
+ (len(recent_results) - len(successful_results))
+ / len(recent_results)
+ * 100
+ )
+ if recent_results
+ else 0
+ ),
}
else:
performance_summary = {}
-
+
return {
"timestamp": time.time(),
"monitoring_period": "recent",
"status": self.get_monitoring_status(),
"performance_summary": performance_summary,
"active_alerts": [asdict(alert) for alert in self.active_alerts.values()],
- "recent_alert_history": [asdict(alert) for alert in list(self.alert_history)[-10:]],
+ "recent_alert_history": [
+ asdict(alert) for alert in list(self.alert_history)[-10:]
+ ],
"trends": {name: asdict(trend) for name, trend in self.trends.items()},
- "recommendations": self._generate_monitoring_recommendations()
+ "recommendations": self._generate_monitoring_recommendations(),
}
-
+
def _generate_monitoring_recommendations(self) -> List[str]:
"""Generate recommendations based on monitoring data."""
recommendations = []
-
+
# Check for active critical alerts
- critical_alerts = [a for a in self.active_alerts.values() if a.severity == "critical"]
+ critical_alerts = [
+ a for a in self.active_alerts.values() if a.severity == "critical"
+ ]
if critical_alerts:
recommendations.append("Address critical performance alerts immediately")
-
+
# Check for degrading trends
- degrading_trends = [t for t in self.trends.values() if t.direction == "degrading"]
+ degrading_trends = [
+ t for t in self.trends.values() if t.direction == "degrading"
+ ]
if degrading_trends:
recommendations.append("Investigate degrading performance trends")
-
+
# General recommendations
if len(self.performance_history) < 100:
recommendations.append("Allow more time for comprehensive trend analysis")
-
+
if not recommendations:
recommendations.append("Performance monitoring shows healthy system status")
-
+
return recommendations
class PerformanceTracker:
"""
Lightweight performance tracking for specific operations.
-
+
Provides focused tracking for specific performance metrics
without full continuous monitoring overhead.
"""
-
+
def __init__(self):
"""Initialize performance tracker."""
self.tracked_operations: Dict[str, List[float]] = {}
self.operation_metadata: Dict[str, Dict[str, Any]] = {}
-
+
logger.info("Performance tracker initialized")
-
+
def track_operation(self, operation_name: str, duration_ms: float, **metadata):
"""
Track a single operation performance.
-
+
Args:
operation_name: Name of the operation
duration_ms: Operation duration in milliseconds
@@ -634,36 +671,40 @@ def track_operation(self, operation_name: str, duration_ms: float, **metadata):
if operation_name not in self.tracked_operations:
self.tracked_operations[operation_name] = []
self.operation_metadata[operation_name] = {}
-
+
self.tracked_operations[operation_name].append(duration_ms)
-
+
# Update metadata
for key, value in metadata.items():
if key not in self.operation_metadata[operation_name]:
self.operation_metadata[operation_name][key] = []
self.operation_metadata[operation_name][key].append(value)
-
+
# Keep only recent measurements (last 1000)
if len(self.tracked_operations[operation_name]) > 1000:
- self.tracked_operations[operation_name] = self.tracked_operations[operation_name][-1000:]
-
+ self.tracked_operations[operation_name] = self.tracked_operations[
+ operation_name
+ ][-1000:]
+
# Also trim metadata
for key in self.operation_metadata[operation_name]:
if len(self.operation_metadata[operation_name][key]) > 1000:
- self.operation_metadata[operation_name][key] = self.operation_metadata[operation_name][key][-1000:]
-
+ self.operation_metadata[operation_name][key] = (
+ self.operation_metadata[operation_name][key][-1000:]
+ )
+
def get_operation_stats(self, operation_name: str) -> Dict[str, Any]:
"""Get statistics for a tracked operation."""
if operation_name not in self.tracked_operations:
return {"error": "Operation not tracked"}
-
+
durations = self.tracked_operations[operation_name]
-
+
if not durations:
return {"error": "No data available"}
-
+
import statistics
-
+
stats = {
"operation_name": operation_name,
"total_executions": len(durations),
@@ -673,24 +714,30 @@ def get_operation_stats(self, operation_name: str) -> Dict[str, Any]:
"median_duration_ms": statistics.median(durations),
"std_deviation": statistics.stdev(durations) if len(durations) > 1 else 0,
}
-
+
# Add percentiles if enough data
if len(durations) >= 10:
- stats.update({
- "p90_duration_ms": statistics.quantiles(durations, n=10)[8],
- "p95_duration_ms": statistics.quantiles(durations, n=20)[18],
- "p99_duration_ms": statistics.quantiles(durations, n=100)[98] if len(durations) >= 100 else max(durations)
- })
-
+ stats.update(
+ {
+ "p90_duration_ms": statistics.quantiles(durations, n=10)[8],
+ "p95_duration_ms": statistics.quantiles(durations, n=20)[18],
+ "p99_duration_ms": (
+ statistics.quantiles(durations, n=100)[98]
+ if len(durations) >= 100
+ else max(durations)
+ ),
+ }
+ )
+
return stats
-
+
def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
"""Get statistics for all tracked operations."""
return {
operation: self.get_operation_stats(operation)
for operation in self.tracked_operations.keys()
}
-
+
def clear_tracking(self, operation_name: Optional[str] = None):
"""Clear tracking data for specific operation or all operations."""
if operation_name:
@@ -699,5 +746,5 @@ def clear_tracking(self, operation_name: Optional[str] = None):
else:
self.tracked_operations.clear()
self.operation_metadata.clear()
-
- logger.info("Tracking data cleared", operation=operation_name or "all")
\ No newline at end of file
+
+ logger.info("Tracking data cleared", operation=operation_name or "all")
diff --git a/src/benchmarking/profiler.py b/src/benchmarking/profiler.py
index 429586a..b8f7ea7 100644
--- a/src/benchmarking/profiler.py
+++ b/src/benchmarking/profiler.py
@@ -22,11 +22,12 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from cache.manager import CacheManager
from monitoring.metrics import get_metrics_collector
@@ -36,6 +37,7 @@
@dataclass
class ProfilePoint:
"""Individual profiling measurement point."""
+
name: str
start_time: float
end_time: float
@@ -48,6 +50,7 @@ class ProfilePoint:
@dataclass
class SystemSnapshot:
"""System resource snapshot."""
+
timestamp: float
cpu_percent: float
memory_percent: float
@@ -63,6 +66,7 @@ class SystemSnapshot:
@dataclass
class BottleneckAnalysis:
"""Analysis of system bottlenecks."""
+
component: str
severity: str # "low", "medium", "high", "critical"
impact_percentage: float
@@ -74,6 +78,7 @@ class BottleneckAnalysis:
@dataclass
class ProfileResult:
"""Complete profiling result."""
+
execution_time_ms: float
profile_points: List[ProfilePoint]
system_snapshots: List[SystemSnapshot]
@@ -86,15 +91,15 @@ class ProfileResult:
class SystemProfiler:
"""
Advanced system profiler for performance analysis.
-
+
Provides detailed profiling of system components, resource usage,
and performance bottleneck identification.
"""
-
+
def __init__(self, sampling_interval: float = 0.1):
"""
Initialize system profiler.
-
+
Args:
sampling_interval: Resource sampling interval in seconds
"""
@@ -104,24 +109,23 @@ def __init__(self, sampling_interval: float = 0.1):
self.active_profiles: Dict[str, float] = {}
self._monitoring = False
self._monitor_task: Optional[asyncio.Task] = None
-
+
# Performance thresholds
self.thresholds = {
"cpu_percent": 80.0,
"memory_percent": 85.0,
"response_time_ms": 100.0,
"cache_latency_ms": 50.0,
- "db_latency_ms": 10.0
+ "db_latency_ms": 10.0,
}
-
- logger.info("System profiler initialized",
- sampling_interval=sampling_interval)
-
+
+ logger.info("System profiler initialized", sampling_interval=sampling_interval)
+
@asynccontextmanager
async def profile_operation(self, operation_name: str, **metadata):
"""
Context manager for profiling operations.
-
+
Args:
operation_name: Name of the operation being profiled
**metadata: Additional metadata to include
@@ -129,18 +133,18 @@ async def profile_operation(self, operation_name: str, **metadata):
start_time = time.perf_counter()
start_cpu = psutil.cpu_percent()
start_memory = psutil.virtual_memory().used / (1024 * 1024)
-
+
try:
yield
finally:
end_time = time.perf_counter()
end_cpu = psutil.cpu_percent()
end_memory = psutil.virtual_memory().used / (1024 * 1024)
-
+
duration_ms = (end_time - start_time) * 1000
avg_cpu = (start_cpu + end_cpu) / 2
avg_memory = (start_memory + end_memory) / 2
-
+
profile_point = ProfilePoint(
name=operation_name,
start_time=start_time,
@@ -148,65 +152,67 @@ async def profile_operation(self, operation_name: str, **metadata):
duration_ms=duration_ms,
cpu_percent=avg_cpu,
memory_mb=avg_memory,
- metadata=metadata
+ metadata=metadata,
)
-
+
self.profile_points.append(profile_point)
-
- logger.debug("Operation profiled",
- operation=operation_name,
- duration_ms=duration_ms,
- cpu_percent=avg_cpu)
-
+
+ logger.debug(
+ "Operation profiled",
+ operation=operation_name,
+ duration_ms=duration_ms,
+ cpu_percent=avg_cpu,
+ )
+
async def start_continuous_monitoring(self):
"""Start continuous system resource monitoring."""
if self._monitoring:
return
-
+
self._monitoring = True
self._monitor_task = asyncio.create_task(self._monitor_resources())
-
+
logger.info("Started continuous monitoring")
-
+
async def stop_continuous_monitoring(self):
"""Stop continuous system resource monitoring."""
self._monitoring = False
-
+
if self._monitor_task:
self._monitor_task.cancel()
try:
await self._monitor_task
except asyncio.CancelledError:
pass
-
+
logger.info("Stopped continuous monitoring")
-
+
async def _monitor_resources(self):
"""Monitor system resources continuously."""
while self._monitoring:
try:
snapshot = self._take_system_snapshot()
self.system_snapshots.append(snapshot)
-
+
# Keep only recent snapshots (last 1000)
if len(self.system_snapshots) > 1000:
self.system_snapshots = self.system_snapshots[-1000:]
-
+
await asyncio.sleep(self.sampling_interval)
-
+
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error monitoring resources", error=str(e))
await asyncio.sleep(1.0)
-
+
def _take_system_snapshot(self) -> SystemSnapshot:
"""Take a snapshot of current system resources."""
try:
memory = psutil.virtual_memory()
disk_io = psutil.disk_io_counters()
network_io = psutil.net_io_counters()
-
+
return SystemSnapshot(
timestamp=time.time(),
cpu_percent=psutil.cpu_percent(),
@@ -214,177 +220,200 @@ def _take_system_snapshot(self) -> SystemSnapshot:
memory_available_mb=memory.available / (1024 * 1024),
disk_io_read_mb=disk_io.read_bytes / (1024 * 1024) if disk_io else 0,
disk_io_write_mb=disk_io.write_bytes / (1024 * 1024) if disk_io else 0,
- network_sent_mb=network_io.bytes_sent / (1024 * 1024) if network_io else 0,
- network_recv_mb=network_io.bytes_recv / (1024 * 1024) if network_io else 0,
+ network_sent_mb=(
+ network_io.bytes_sent / (1024 * 1024) if network_io else 0
+ ),
+ network_recv_mb=(
+ network_io.bytes_recv / (1024 * 1024) if network_io else 0
+ ),
process_count=len(psutil.pids()),
- thread_count=threading.active_count()
+ thread_count=threading.active_count(),
)
except Exception as e:
logger.error("Error taking system snapshot", error=str(e))
return SystemSnapshot(
timestamp=time.time(),
- cpu_percent=0, memory_percent=0, memory_available_mb=0,
- disk_io_read_mb=0, disk_io_write_mb=0,
- network_sent_mb=0, network_recv_mb=0,
- process_count=0, thread_count=0
+ cpu_percent=0,
+ memory_percent=0,
+ memory_available_mb=0,
+ disk_io_read_mb=0,
+ disk_io_write_mb=0,
+ network_sent_mb=0,
+ network_recv_mb=0,
+ process_count=0,
+ thread_count=0,
)
-
- async def profile_complete_operation(self,
- operation: Callable,
- operation_name: str,
- *args, **kwargs) -> Tuple[Any, ProfileResult]:
+
+ async def profile_complete_operation(
+ self, operation: Callable, operation_name: str, *args, **kwargs
+ ) -> Tuple[Any, ProfileResult]:
"""
Profile a complete operation with detailed analysis.
-
+
Args:
operation: Async operation to profile
operation_name: Name for the operation
*args, **kwargs: Arguments for the operation
-
+
Returns:
Tuple of (operation_result, profile_result)
"""
# Clear previous profiling data
self.profile_points.clear()
self.system_snapshots.clear()
-
+
# Start monitoring
await self.start_continuous_monitoring()
-
+
start_time = time.perf_counter()
-
+
try:
# Execute operation with profiling
async with self.profile_operation(operation_name):
result = await operation(*args, **kwargs)
-
+
end_time = time.perf_counter()
execution_time_ms = (end_time - start_time) * 1000
-
+
# Stop monitoring
await self.stop_continuous_monitoring()
-
+
# Analyze results
bottlenecks = self._analyze_bottlenecks()
performance_summary = self._generate_performance_summary()
recommendations = self._generate_optimization_recommendations(bottlenecks)
-
+
profile_result = ProfileResult(
execution_time_ms=execution_time_ms,
profile_points=self.profile_points.copy(),
system_snapshots=self.system_snapshots.copy(),
bottlenecks=bottlenecks,
performance_summary=performance_summary,
- optimization_recommendations=recommendations
+ optimization_recommendations=recommendations,
)
-
- logger.info("Operation profiling completed",
- operation=operation_name,
- execution_time_ms=execution_time_ms,
- bottlenecks_found=len(bottlenecks))
-
+
+ logger.info(
+ "Operation profiling completed",
+ operation=operation_name,
+ execution_time_ms=execution_time_ms,
+ bottlenecks_found=len(bottlenecks),
+ )
+
return result, profile_result
-
+
except Exception as e:
await self.stop_continuous_monitoring()
logger.error("Error during profiling", error=str(e))
raise
-
+
def _analyze_bottlenecks(self) -> List[BottleneckAnalysis]:
"""Analyze profiling data to identify bottlenecks."""
bottlenecks = []
-
+
# Analyze CPU bottlenecks
cpu_bottlenecks = self._analyze_cpu_bottlenecks()
bottlenecks.extend(cpu_bottlenecks)
-
+
# Analyze memory bottlenecks
memory_bottlenecks = self._analyze_memory_bottlenecks()
bottlenecks.extend(memory_bottlenecks)
-
+
# Analyze operation latency bottlenecks
latency_bottlenecks = self._analyze_latency_bottlenecks()
bottlenecks.extend(latency_bottlenecks)
-
+
# Analyze I/O bottlenecks
io_bottlenecks = self._analyze_io_bottlenecks()
bottlenecks.extend(io_bottlenecks)
-
+
# Sort by severity and impact
- bottlenecks.sort(key=lambda x: (
- {"critical": 4, "high": 3, "medium": 2, "low": 1}[x.severity],
- x.impact_percentage
- ), reverse=True)
-
+ bottlenecks.sort(
+ key=lambda x: (
+ {"critical": 4, "high": 3, "medium": 2, "low": 1}[x.severity],
+ x.impact_percentage,
+ ),
+ reverse=True,
+ )
+
return bottlenecks
-
+
def _analyze_cpu_bottlenecks(self) -> List[BottleneckAnalysis]:
"""Analyze CPU usage patterns for bottlenecks."""
bottlenecks = []
-
+
if not self.system_snapshots:
return bottlenecks
-
+
cpu_values = [s.cpu_percent for s in self.system_snapshots]
avg_cpu = sum(cpu_values) / len(cpu_values)
max_cpu = max(cpu_values)
-
+
if avg_cpu > self.thresholds["cpu_percent"]:
- severity = "critical" if avg_cpu > 95 else "high" if avg_cpu > 85 else "medium"
-
- bottlenecks.append(BottleneckAnalysis(
- component="CPU",
- severity=severity,
- impact_percentage=avg_cpu,
- description=f"High CPU utilization averaging {avg_cpu:.1f}%",
- recommendations=[
- "Consider optimizing CPU-intensive operations",
- "Implement asynchronous processing where possible",
- "Review algorithm efficiency",
- "Scale horizontally if sustained high load"
- ],
- metrics={"avg_cpu": avg_cpu, "max_cpu": max_cpu}
- ))
-
+ severity = (
+ "critical" if avg_cpu > 95 else "high" if avg_cpu > 85 else "medium"
+ )
+
+ bottlenecks.append(
+ BottleneckAnalysis(
+ component="CPU",
+ severity=severity,
+ impact_percentage=avg_cpu,
+ description=f"High CPU utilization averaging {avg_cpu:.1f}%",
+ recommendations=[
+ "Consider optimizing CPU-intensive operations",
+ "Implement asynchronous processing where possible",
+ "Review algorithm efficiency",
+ "Scale horizontally if sustained high load",
+ ],
+ metrics={"avg_cpu": avg_cpu, "max_cpu": max_cpu},
+ )
+ )
+
return bottlenecks
-
+
def _analyze_memory_bottlenecks(self) -> List[BottleneckAnalysis]:
"""Analyze memory usage patterns for bottlenecks."""
bottlenecks = []
-
+
if not self.system_snapshots:
return bottlenecks
-
+
memory_values = [s.memory_percent for s in self.system_snapshots]
avg_memory = sum(memory_values) / len(memory_values)
max_memory = max(memory_values)
-
+
if avg_memory > self.thresholds["memory_percent"]:
- severity = "critical" if avg_memory > 95 else "high" if avg_memory > 90 else "medium"
-
- bottlenecks.append(BottleneckAnalysis(
- component="Memory",
- severity=severity,
- impact_percentage=avg_memory,
- description=f"High memory utilization averaging {avg_memory:.1f}%",
- recommendations=[
- "Review cache size configuration",
- "Implement memory-efficient data structures",
- "Consider garbage collection tuning",
- "Monitor for memory leaks"
- ],
- metrics={"avg_memory": avg_memory, "max_memory": max_memory}
- ))
-
+ severity = (
+ "critical"
+ if avg_memory > 95
+ else "high" if avg_memory > 90 else "medium"
+ )
+
+ bottlenecks.append(
+ BottleneckAnalysis(
+ component="Memory",
+ severity=severity,
+ impact_percentage=avg_memory,
+ description=f"High memory utilization averaging {avg_memory:.1f}%",
+ recommendations=[
+ "Review cache size configuration",
+ "Implement memory-efficient data structures",
+ "Consider garbage collection tuning",
+ "Monitor for memory leaks",
+ ],
+ metrics={"avg_memory": avg_memory, "max_memory": max_memory},
+ )
+ )
+
return bottlenecks
-
+
def _analyze_latency_bottlenecks(self) -> List[BottleneckAnalysis]:
"""Analyze operation latency for bottlenecks."""
bottlenecks = []
-
+
if not self.profile_points:
return bottlenecks
-
+
# Group by operation type
operation_groups = {}
for point in self.profile_points:
@@ -392,92 +421,111 @@ def _analyze_latency_bottlenecks(self) -> List[BottleneckAnalysis]:
if op_type not in operation_groups:
operation_groups[op_type] = []
operation_groups[op_type].append(point.duration_ms)
-
+
for op_type, durations in operation_groups.items():
avg_duration = sum(durations) / len(durations)
max_duration = max(durations)
-
+
# Check against thresholds
- threshold = self.thresholds.get(f"{op_type.lower()}_latency_ms",
- self.thresholds["response_time_ms"])
-
+ threshold = self.thresholds.get(
+ f"{op_type.lower()}_latency_ms", self.thresholds["response_time_ms"]
+ )
+
if avg_duration > threshold:
severity = "critical" if avg_duration > threshold * 2 else "high"
impact = min(100, (avg_duration / threshold - 1) * 100)
-
- bottlenecks.append(BottleneckAnalysis(
- component=f"{op_type} Latency",
- severity=severity,
- impact_percentage=impact,
- description=f"{op_type} operations averaging {avg_duration:.1f}ms (threshold: {threshold}ms)",
- recommendations=[
- f"Optimize {op_type.lower()} operations",
- "Review database query efficiency",
- "Consider caching frequently accessed data",
- "Implement connection pooling"
- ],
- metrics={"avg_duration": avg_duration, "max_duration": max_duration, "threshold": threshold}
- ))
-
+
+ bottlenecks.append(
+ BottleneckAnalysis(
+ component=f"{op_type} Latency",
+ severity=severity,
+ impact_percentage=impact,
+ description=f"{op_type} operations averaging {avg_duration:.1f}ms (threshold: {threshold}ms)",
+ recommendations=[
+ f"Optimize {op_type.lower()} operations",
+ "Review database query efficiency",
+ "Consider caching frequently accessed data",
+ "Implement connection pooling",
+ ],
+ metrics={
+ "avg_duration": avg_duration,
+ "max_duration": max_duration,
+ "threshold": threshold,
+ },
+ )
+ )
+
return bottlenecks
-
+
def _analyze_io_bottlenecks(self) -> List[BottleneckAnalysis]:
"""Analyze I/O patterns for bottlenecks."""
bottlenecks = []
-
+
if len(self.system_snapshots) < 2:
return bottlenecks
-
+
# Calculate I/O rates
first_snapshot = self.system_snapshots[0]
last_snapshot = self.system_snapshots[-1]
time_diff = last_snapshot.timestamp - first_snapshot.timestamp
-
+
if time_diff <= 0:
return bottlenecks
-
- disk_read_rate = (last_snapshot.disk_io_read_mb - first_snapshot.disk_io_read_mb) / time_diff
- disk_write_rate = (last_snapshot.disk_io_write_mb - first_snapshot.disk_io_write_mb) / time_diff
-
+
+ disk_read_rate = (
+ last_snapshot.disk_io_read_mb - first_snapshot.disk_io_read_mb
+ ) / time_diff
+ disk_write_rate = (
+ last_snapshot.disk_io_write_mb - first_snapshot.disk_io_write_mb
+ ) / time_diff
+
# High disk I/O threshold (MB/s)
high_io_threshold = 50.0
-
+
if disk_read_rate > high_io_threshold:
- bottlenecks.append(BottleneckAnalysis(
- component="Disk I/O Read",
- severity="medium",
- impact_percentage=min(100, disk_read_rate / high_io_threshold * 100),
- description=f"High disk read rate: {disk_read_rate:.1f} MB/s",
- recommendations=[
- "Consider SSD storage for better performance",
- "Implement read caching",
- "Optimize database queries to reduce disk reads",
- "Review file access patterns"
- ],
- metrics={"read_rate_mbs": disk_read_rate}
- ))
-
+ bottlenecks.append(
+ BottleneckAnalysis(
+ component="Disk I/O Read",
+ severity="medium",
+ impact_percentage=min(
+ 100, disk_read_rate / high_io_threshold * 100
+ ),
+ description=f"High disk read rate: {disk_read_rate:.1f} MB/s",
+ recommendations=[
+ "Consider SSD storage for better performance",
+ "Implement read caching",
+ "Optimize database queries to reduce disk reads",
+ "Review file access patterns",
+ ],
+ metrics={"read_rate_mbs": disk_read_rate},
+ )
+ )
+
if disk_write_rate > high_io_threshold:
- bottlenecks.append(BottleneckAnalysis(
- component="Disk I/O Write",
- severity="medium",
- impact_percentage=min(100, disk_write_rate / high_io_threshold * 100),
- description=f"High disk write rate: {disk_write_rate:.1f} MB/s",
- recommendations=[
- "Implement write batching",
- "Consider asynchronous writes",
- "Review logging levels and output",
- "Optimize cache write-back policies"
- ],
- metrics={"write_rate_mbs": disk_write_rate}
- ))
-
+ bottlenecks.append(
+ BottleneckAnalysis(
+ component="Disk I/O Write",
+ severity="medium",
+ impact_percentage=min(
+ 100, disk_write_rate / high_io_threshold * 100
+ ),
+ description=f"High disk write rate: {disk_write_rate:.1f} MB/s",
+ recommendations=[
+ "Implement write batching",
+ "Consider asynchronous writes",
+ "Review logging levels and output",
+ "Optimize cache write-back policies",
+ ],
+ metrics={"write_rate_mbs": disk_write_rate},
+ )
+ )
+
return bottlenecks
-
+
def _generate_performance_summary(self) -> Dict[str, Any]:
"""Generate performance summary from profiling data."""
summary = {}
-
+
# Operation performance
if self.profile_points:
durations = [p.duration_ms for p in self.profile_points]
@@ -486,45 +534,53 @@ def _generate_performance_summary(self) -> Dict[str, Any]:
"avg_duration_ms": sum(durations) / len(durations),
"min_duration_ms": min(durations),
"max_duration_ms": max(durations),
- "total_duration_ms": sum(durations)
+ "total_duration_ms": sum(durations),
}
-
+
# Resource utilization
if self.system_snapshots:
cpu_values = [s.cpu_percent for s in self.system_snapshots]
memory_values = [s.memory_percent for s in self.system_snapshots]
-
+
summary["resources"] = {
"avg_cpu_percent": sum(cpu_values) / len(cpu_values),
"max_cpu_percent": max(cpu_values),
"avg_memory_percent": sum(memory_values) / len(memory_values),
"max_memory_percent": max(memory_values),
- "monitoring_duration_seconds": self.system_snapshots[-1].timestamp - self.system_snapshots[0].timestamp
+ "monitoring_duration_seconds": self.system_snapshots[-1].timestamp
+ - self.system_snapshots[0].timestamp,
}
-
+
return summary
-
- def _generate_optimization_recommendations(self,
- bottlenecks: List[BottleneckAnalysis]) -> List[str]:
+
+ def _generate_optimization_recommendations(
+ self, bottlenecks: List[BottleneckAnalysis]
+ ) -> List[str]:
"""Generate optimization recommendations based on bottlenecks."""
recommendations = []
-
+
# High-priority recommendations based on critical bottlenecks
critical_bottlenecks = [b for b in bottlenecks if b.severity == "critical"]
if critical_bottlenecks:
- recommendations.append("Address critical performance bottlenecks immediately")
+ recommendations.append(
+ "Address critical performance bottlenecks immediately"
+ )
for bottleneck in critical_bottlenecks:
- recommendations.extend(bottleneck.recommendations[:2]) # Top 2 recommendations
-
+ recommendations.extend(
+ bottleneck.recommendations[:2]
+ ) # Top 2 recommendations
+
# General optimization recommendations
- recommendations.extend([
- "Implement comprehensive monitoring and alerting",
- "Consider performance testing under various load conditions",
- "Review and optimize database queries and indexes",
- "Implement efficient caching strategies",
- "Consider horizontal scaling for high-load scenarios"
- ])
-
+ recommendations.extend(
+ [
+ "Implement comprehensive monitoring and alerting",
+ "Consider performance testing under various load conditions",
+ "Review and optimize database queries and indexes",
+ "Implement efficient caching strategies",
+ "Consider horizontal scaling for high-load scenarios",
+ ]
+ )
+
# Remove duplicates while preserving order
seen = set()
unique_recommendations = []
@@ -532,126 +588,148 @@ def _generate_optimization_recommendations(self,
if rec not in seen:
seen.add(rec)
unique_recommendations.append(rec)
-
+
return unique_recommendations[:10] # Return top 10 recommendations
class BottleneckAnalyzer:
"""
Specialized analyzer for identifying and categorizing bottlenecks.
-
+
Provides advanced analysis of system bottlenecks with actionable
optimization recommendations.
"""
-
+
def __init__(self):
"""Initialize bottleneck analyzer."""
self.analysis_history: List[BottleneckAnalysis] = []
logger.info("Bottleneck analyzer initialized")
-
- def analyze_cache_performance(self, cache_manager: CacheManager) -> List[BottleneckAnalysis]:
+
+ def analyze_cache_performance(
+ self, cache_manager: CacheManager
+ ) -> List[BottleneckAnalysis]:
"""Analyze cache performance for bottlenecks."""
bottlenecks = []
-
+
try:
metrics = cache_manager.get_metrics()
-
+
# Cache hit rate analysis
if metrics.hit_rate < 60.0: # Below target
severity = "high" if metrics.hit_rate < 40 else "medium"
-
- bottlenecks.append(BottleneckAnalysis(
- component="Cache Hit Rate",
- severity=severity,
- impact_percentage=100 - metrics.hit_rate,
- description=f"Cache hit rate {metrics.hit_rate:.1f}% below optimal",
- recommendations=[
- "Review cache warming strategies",
- "Increase cache size if memory allows",
- "Optimize cache eviction policies",
- "Analyze query patterns for better caching"
- ],
- metrics={"hit_rate": metrics.hit_rate, "target": 60.0}
- ))
-
+
+ bottlenecks.append(
+ BottleneckAnalysis(
+ component="Cache Hit Rate",
+ severity=severity,
+ impact_percentage=100 - metrics.hit_rate,
+ description=f"Cache hit rate {metrics.hit_rate:.1f}% below optimal",
+ recommendations=[
+ "Review cache warming strategies",
+ "Increase cache size if memory allows",
+ "Optimize cache eviction policies",
+ "Analyze query patterns for better caching",
+ ],
+ metrics={"hit_rate": metrics.hit_rate, "target": 60.0},
+ )
+ )
+
# Cache memory utilization
if metrics.total_size > cache_manager.max_size_bytes * 0.9:
- bottlenecks.append(BottleneckAnalysis(
- component="Cache Memory",
- severity="medium",
- impact_percentage=90.0,
- description="Cache approaching memory limit",
- recommendations=[
- "Increase cache size allocation",
- "Implement more aggressive eviction",
- "Review cached content for optimization",
- "Consider distributed caching"
- ],
- metrics={"utilization": metrics.total_size / cache_manager.max_size_bytes}
- ))
-
+ bottlenecks.append(
+ BottleneckAnalysis(
+ component="Cache Memory",
+ severity="medium",
+ impact_percentage=90.0,
+ description="Cache approaching memory limit",
+ recommendations=[
+ "Increase cache size allocation",
+ "Implement more aggressive eviction",
+ "Review cached content for optimization",
+ "Consider distributed caching",
+ ],
+ metrics={
+ "utilization": metrics.total_size
+ / cache_manager.max_size_bytes
+ },
+ )
+ )
+
except Exception as e:
logger.error("Error analyzing cache performance", error=str(e))
-
+
return bottlenecks
-
- def analyze_query_patterns(self, recent_queries: List[str]) -> List[BottleneckAnalysis]:
+
+ def analyze_query_patterns(
+ self, recent_queries: List[str]
+ ) -> List[BottleneckAnalysis]:
"""Analyze query patterns for potential bottlenecks."""
bottlenecks = []
-
+
if not recent_queries:
return bottlenecks
-
+
# Analyze query complexity
complex_queries = [q for q in recent_queries if len(q) > 200]
if len(complex_queries) > len(recent_queries) * 0.3: # >30% complex queries
- bottlenecks.append(BottleneckAnalysis(
- component="Query Complexity",
- severity="medium",
- impact_percentage=len(complex_queries) / len(recent_queries) * 100,
- description=f"{len(complex_queries)} of {len(recent_queries)} queries are complex",
- recommendations=[
- "Break down complex queries into simpler components",
- "Implement query preprocessing and optimization",
- "Consider query result caching",
- "Review query construction patterns"
- ],
- metrics={"complex_ratio": len(complex_queries) / len(recent_queries)}
- ))
-
+ bottlenecks.append(
+ BottleneckAnalysis(
+ component="Query Complexity",
+ severity="medium",
+ impact_percentage=len(complex_queries) / len(recent_queries) * 100,
+ description=f"{len(complex_queries)} of {len(recent_queries)} queries are complex",
+ recommendations=[
+ "Break down complex queries into simpler components",
+ "Implement query preprocessing and optimization",
+ "Consider query result caching",
+ "Review query construction patterns",
+ ],
+ metrics={
+ "complex_ratio": len(complex_queries) / len(recent_queries)
+ },
+ )
+ )
+
return bottlenecks
-
- def generate_bottleneck_report(self,
- bottlenecks: List[BottleneckAnalysis]) -> Dict[str, Any]:
+
+ def generate_bottleneck_report(
+ self, bottlenecks: List[BottleneckAnalysis]
+ ) -> Dict[str, Any]:
"""Generate comprehensive bottleneck analysis report."""
if not bottlenecks:
return {
"summary": "No significant bottlenecks detected",
"total_bottlenecks": 0,
"severity_breakdown": {},
- "recommendations": ["Continue monitoring for performance trends"]
+ "recommendations": ["Continue monitoring for performance trends"],
}
-
+
# Severity breakdown
severity_counts = {}
for bottleneck in bottlenecks:
- severity_counts[bottleneck.severity] = severity_counts.get(bottleneck.severity, 0) + 1
-
+ severity_counts[bottleneck.severity] = (
+ severity_counts.get(bottleneck.severity, 0) + 1
+ )
+
# Top recommendations
all_recommendations = []
for bottleneck in bottlenecks:
all_recommendations.extend(bottleneck.recommendations)
-
+
# Remove duplicates and prioritize
unique_recommendations = list(dict.fromkeys(all_recommendations))
-
+
report = {
"timestamp": time.time(),
"summary": f"Identified {len(bottlenecks)} performance bottlenecks",
"total_bottlenecks": len(bottlenecks),
"severity_breakdown": severity_counts,
- "critical_components": [b.component for b in bottlenecks if b.severity == "critical"],
- "high_impact_components": [b.component for b in bottlenecks if b.impact_percentage > 70],
+ "critical_components": [
+ b.component for b in bottlenecks if b.severity == "critical"
+ ],
+ "high_impact_components": [
+ b.component for b in bottlenecks if b.impact_percentage > 70
+ ],
"recommendations": unique_recommendations[:10],
"detailed_analysis": [
{
@@ -659,10 +737,10 @@ def generate_bottleneck_report(self,
"severity": b.severity,
"impact": b.impact_percentage,
"description": b.description,
- "top_recommendations": b.recommendations[:3]
+ "top_recommendations": b.recommendations[:3],
}
for b in bottlenecks
- ]
+ ],
}
-
- return report
\ No newline at end of file
+
+ return report
diff --git a/src/benchmarking/visualization.py b/src/benchmarking/visualization.py
index 4f641db..a2d9b97 100644
--- a/src/benchmarking/visualization.py
+++ b/src/benchmarking/visualization.py
@@ -22,11 +22,12 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from benchmarking.framework import BenchmarkResult, BenchmarkSummary
from benchmarking.comparisons import ComparisonResult
from benchmarking.profiler import ProfileResult, BottleneckAnalysis
@@ -38,6 +39,7 @@
@dataclass
class ChartData:
"""Data structure for chart generation."""
+
chart_type: str # "line", "bar", "scatter", "histogram"
title: str
x_label: str
@@ -49,6 +51,7 @@ class ChartData:
@dataclass
class ReportSection:
"""Individual section of a benchmark report."""
+
section_id: str
title: str
content_type: str # "text", "table", "chart", "metrics"
@@ -59,6 +62,7 @@ class ReportSection:
@dataclass
class BenchmarkReport:
"""Complete benchmark report structure."""
+
report_id: str
title: str
generated_at: float
@@ -72,49 +76,49 @@ class BenchmarkReport:
class BenchmarkVisualizer:
"""
Comprehensive visualization system for benchmark data.
-
+
Generates charts, tables, and visual reports for performance analysis.
"""
-
+
def __init__(self):
"""Initialize benchmark visualizer."""
self.color_scheme = {
- "fact": "#2E8B57", # Sea Green
- "rag": "#CD5C5C", # Indian Red
+ "fact": "#2E8B57", # Sea Green
+ "rag": "#CD5C5C", # Indian Red
"cache_hit": "#32CD32", # Lime Green
- "cache_miss": "#FF6347", # Tomato
- "warning": "#FFD700", # Gold
- "critical": "#DC143C", # Crimson
- "success": "#228B22", # Forest Green
- "neutral": "#708090" # Slate Gray
+ "cache_miss": "#FF6347", # Tomato
+ "warning": "#FFD700", # Gold
+ "critical": "#DC143C", # Crimson
+ "success": "#228B22", # Forest Green
+ "neutral": "#708090", # Slate Gray
}
-
+
logger.info("Benchmark visualizer initialized")
-
- def create_latency_comparison_chart(self,
- fact_results: List[BenchmarkResult],
- rag_results: List[BenchmarkResult]) -> ChartData:
+
+ def create_latency_comparison_chart(
+ self, fact_results: List[BenchmarkResult], rag_results: List[BenchmarkResult]
+ ) -> ChartData:
"""Create latency comparison chart between FACT and RAG."""
# Extract latency data
fact_latencies = [r.response_time_ms for r in fact_results if r.success]
rag_latencies = [r.response_time_ms for r in rag_results if r.success]
-
+
# Create comparison data
data_series = [
{
"name": "FACT System",
"data": fact_latencies,
"color": self.color_scheme["fact"],
- "type": "box_plot"
+ "type": "box_plot",
},
{
"name": "Traditional RAG",
"data": rag_latencies,
- "color": self.color_scheme["rag"],
- "type": "box_plot"
- }
+ "color": self.color_scheme["rag"],
+ "type": "box_plot",
+ },
]
-
+
return ChartData(
chart_type="box_plot",
title="Response Time Comparison: FACT vs Traditional RAG",
@@ -122,39 +126,64 @@ def create_latency_comparison_chart(self,
y_label="Response Time (ms)",
data_series=data_series,
metadata={
- "fact_avg": sum(fact_latencies) / len(fact_latencies) if fact_latencies else 0,
- "rag_avg": sum(rag_latencies) / len(rag_latencies) if rag_latencies else 0,
- "improvement_factor": (sum(rag_latencies) / len(rag_latencies)) / (sum(fact_latencies) / len(fact_latencies)) if fact_latencies and rag_latencies else 0
- }
+ "fact_avg": (
+ sum(fact_latencies) / len(fact_latencies) if fact_latencies else 0
+ ),
+ "rag_avg": (
+ sum(rag_latencies) / len(rag_latencies) if rag_latencies else 0
+ ),
+ "improvement_factor": (
+ (sum(rag_latencies) / len(rag_latencies))
+ / (sum(fact_latencies) / len(fact_latencies))
+ if fact_latencies and rag_latencies
+ else 0
+ ),
+ },
)
-
+
def create_cache_performance_chart(self, benchmark_summary) -> ChartData:
"""Create cache hit/miss performance visualization."""
# Handle both BenchmarkSummary and List[BenchmarkResult]
- if hasattr(benchmark_summary, 'avg_hit_latency_ms'):
+ if hasattr(benchmark_summary, "avg_hit_latency_ms"):
# BenchmarkSummary object
- hit_latencies = [benchmark_summary.avg_hit_latency_ms] if benchmark_summary.cache_hits > 0 else []
- miss_latencies = [benchmark_summary.avg_miss_latency_ms] if benchmark_summary.cache_misses > 0 else []
+ hit_latencies = (
+ [benchmark_summary.avg_hit_latency_ms]
+ if benchmark_summary.cache_hits > 0
+ else []
+ )
+ miss_latencies = (
+ [benchmark_summary.avg_miss_latency_ms]
+ if benchmark_summary.cache_misses > 0
+ else []
+ )
else:
# List of BenchmarkResult objects
- hit_latencies = [r.response_time_ms for r in benchmark_summary if r.success and r.cache_hit]
- miss_latencies = [r.response_time_ms for r in benchmark_summary if r.success and not r.cache_hit]
-
+ hit_latencies = [
+ r.response_time_ms
+ for r in benchmark_summary
+ if r.success and r.cache_hit
+ ]
+ miss_latencies = [
+ r.response_time_ms
+ for r in benchmark_summary
+ if r.success and not r.cache_hit
+ ]
+
data_series = [
{
"name": "Cache Hits",
"data": hit_latencies,
"color": self.color_scheme["cache_hit"],
- "type": "histogram"
+ "type": "histogram",
},
{
- "name": "Cache Misses",
+ "name": "Cache Misses",
"data": miss_latencies,
"color": self.color_scheme["cache_miss"],
- "type": "histogram"
- }
+ "type": "histogram",
+ },
]
-
+
return ChartData(
chart_type="histogram",
title="Response Time Distribution: Cache Hits vs Misses",
@@ -162,41 +191,55 @@ def create_cache_performance_chart(self, benchmark_summary) -> ChartData:
y_label="Frequency",
data_series=data_series,
metadata={
- "hit_avg": sum(hit_latencies) / len(hit_latencies) if hit_latencies else 0,
- "miss_avg": sum(miss_latencies) / len(miss_latencies) if miss_latencies else 0,
+ "hit_avg": (
+ sum(hit_latencies) / len(hit_latencies) if hit_latencies else 0
+ ),
+ "miss_avg": (
+ sum(miss_latencies) / len(miss_latencies) if miss_latencies else 0
+ ),
"hit_count": len(hit_latencies),
"miss_count": len(miss_latencies),
- "cache_hit_rate": len(hit_latencies) / (len(hit_latencies) + len(miss_latencies)) * 100 if (len(hit_latencies) + len(miss_latencies)) > 0 else 0
- }
+ "cache_hit_rate": (
+ len(hit_latencies)
+ / (len(hit_latencies) + len(miss_latencies))
+ * 100
+ if (len(hit_latencies) + len(miss_latencies)) > 0
+ else 0
+ ),
+ },
)
-
- def create_performance_timeline_chart(self, results: List[BenchmarkResult]) -> ChartData:
+
+ def create_performance_timeline_chart(
+ self, results: List[BenchmarkResult]
+ ) -> ChartData:
"""Create performance over time timeline chart."""
# Sort results by timestamp
sorted_results = sorted(results, key=lambda r: r.timestamp)
-
+
# Create time series data
timestamps = [r.timestamp for r in sorted_results]
response_times = [r.response_time_ms for r in sorted_results]
cache_hits = [1 if r.cache_hit else 0 for r in sorted_results]
-
+
data_series = [
{
"name": "Response Time",
"data": list(zip(timestamps, response_times)),
"color": self.color_scheme["fact"],
"type": "line",
- "y_axis": "left"
+ "y_axis": "left",
},
{
"name": "Cache Hit Rate",
- "data": list(zip(timestamps, self._calculate_rolling_cache_hit_rate(cache_hits))),
+ "data": list(
+ zip(timestamps, self._calculate_rolling_cache_hit_rate(cache_hits))
+ ),
"color": self.color_scheme["cache_hit"],
- "type": "line",
- "y_axis": "right"
- }
+ "type": "line",
+ "y_axis": "right",
+ },
]
-
+
return ChartData(
chart_type="line",
title="Performance Timeline",
@@ -205,29 +248,41 @@ def create_performance_timeline_chart(self, results: List[BenchmarkResult]) -> C
data_series=data_series,
metadata={
"total_measurements": len(results),
- "time_span_hours": (max(timestamps) - min(timestamps)) / 3600 if timestamps else 0
- }
+ "time_span_hours": (
+ (max(timestamps) - min(timestamps)) / 3600 if timestamps else 0
+ ),
+ },
)
-
- def create_cost_comparison_chart(self, comparison_result: ComparisonResult) -> ChartData:
+
+ def create_cost_comparison_chart(
+ self, comparison_result: ComparisonResult
+ ) -> ChartData:
"""Create cost comparison visualization."""
fact_cost = comparison_result.fact_metrics.avg_token_cost
rag_cost = comparison_result.rag_metrics.avg_token_cost
-
+
data_series = [
{
"name": "Cost Comparison",
"data": [
- {"category": "FACT System", "value": fact_cost, "color": self.color_scheme["fact"]},
- {"category": "Traditional RAG", "value": rag_cost, "color": self.color_scheme["rag"]}
+ {
+ "category": "FACT System",
+ "value": fact_cost,
+ "color": self.color_scheme["fact"],
+ },
+ {
+ "category": "Traditional RAG",
+ "value": rag_cost,
+ "color": self.color_scheme["rag"],
+ },
],
- "type": "bar"
+ "type": "bar",
}
]
-
+
savings = rag_cost - fact_cost if rag_cost > fact_cost else 0
savings_percentage = (savings / rag_cost * 100) if rag_cost > 0 else 0
-
+
return ChartData(
chart_type="bar",
title="Token Cost Comparison per Query",
@@ -238,11 +293,13 @@ def create_cost_comparison_chart(self, comparison_result: ComparisonResult) -> C
"fact_cost": fact_cost,
"rag_cost": rag_cost,
"absolute_savings": savings,
- "percentage_savings": savings_percentage
- }
+ "percentage_savings": savings_percentage,
+ },
)
-
- def create_bottleneck_analysis_chart(self, bottlenecks: List[BottleneckAnalysis]) -> ChartData:
+
+ def create_bottleneck_analysis_chart(
+ self, bottlenecks: List[BottleneckAnalysis]
+ ) -> ChartData:
"""Create bottleneck analysis visualization."""
if not bottlenecks:
return ChartData(
@@ -251,20 +308,22 @@ def create_bottleneck_analysis_chart(self, bottlenecks: List[BottleneckAnalysis]
x_label="Components",
y_label="Impact (%)",
data_series=[],
- metadata={"status": "healthy"}
+ metadata={"status": "healthy"},
)
-
+
# Sort bottlenecks by impact
- sorted_bottlenecks = sorted(bottlenecks, key=lambda b: b.impact_percentage, reverse=True)
-
+ sorted_bottlenecks = sorted(
+ bottlenecks, key=lambda b: b.impact_percentage, reverse=True
+ )
+
# Create severity color mapping
severity_colors = {
"critical": self.color_scheme["critical"],
"high": "#FF8C00", # Dark Orange
"medium": self.color_scheme["warning"],
- "low": "#90EE90" # Light Green
+ "low": "#90EE90", # Light Green
}
-
+
data_series = [
{
"name": "Bottleneck Impact",
@@ -272,15 +331,17 @@ def create_bottleneck_analysis_chart(self, bottlenecks: List[BottleneckAnalysis]
{
"category": b.component,
"value": b.impact_percentage,
- "color": severity_colors.get(b.severity, self.color_scheme["neutral"]),
- "severity": b.severity
+ "color": severity_colors.get(
+ b.severity, self.color_scheme["neutral"]
+ ),
+ "severity": b.severity,
}
for b in sorted_bottlenecks[:10] # Top 10 bottlenecks
],
- "type": "bar"
+ "type": "bar",
}
]
-
+
return ChartData(
chart_type="bar",
title="Performance Bottleneck Analysis",
@@ -289,23 +350,27 @@ def create_bottleneck_analysis_chart(self, bottlenecks: List[BottleneckAnalysis]
data_series=data_series,
metadata={
"total_bottlenecks": len(bottlenecks),
- "critical_count": sum(1 for b in bottlenecks if b.severity == "critical"),
- "high_count": sum(1 for b in bottlenecks if b.severity == "high")
- }
+ "critical_count": sum(
+ 1 for b in bottlenecks if b.severity == "critical"
+ ),
+ "high_count": sum(1 for b in bottlenecks if b.severity == "high"),
+ },
)
-
- def _calculate_rolling_cache_hit_rate(self, cache_hits: List[int], window_size: int = 10) -> List[float]:
+
+ def _calculate_rolling_cache_hit_rate(
+ self, cache_hits: List[int], window_size: int = 10
+ ) -> List[float]:
"""Calculate rolling cache hit rate."""
rolling_rates = []
-
+
for i in range(len(cache_hits)):
start_idx = max(0, i - window_size + 1)
- window_data = cache_hits[start_idx:i + 1]
+ window_data = cache_hits[start_idx : i + 1]
hit_rate = sum(window_data) / len(window_data) * 100
rolling_rates.append(hit_rate)
-
+
return rolling_rates
-
+
def export_chart_data_json(self, chart: ChartData) -> str:
"""Export chart data as JSON for external visualization tools."""
return json.dumps(asdict(chart), indent=2, default=str)
@@ -314,76 +379,94 @@ def export_chart_data_json(self, chart: ChartData) -> str:
class ReportGenerator:
"""
Comprehensive report generation system.
-
+
Creates detailed HTML, JSON, and text reports from benchmark data.
"""
-
+
def __init__(self, visualizer: Optional[BenchmarkVisualizer] = None):
"""
Initialize report generator.
-
+
Args:
visualizer: Benchmark visualizer instance
"""
self.visualizer = visualizer or BenchmarkVisualizer()
-
+
logger.info("Report generator initialized")
-
- def generate_comprehensive_report(self,
- benchmark_summary: BenchmarkSummary,
- comparison_result: Optional[ComparisonResult] = None,
- profile_result: Optional[ProfileResult] = None,
- alerts: Optional[List[PerformanceAlert]] = None) -> BenchmarkReport:
+
+ def generate_comprehensive_report(
+ self,
+ benchmark_summary: BenchmarkSummary,
+ comparison_result: Optional[ComparisonResult] = None,
+ profile_result: Optional[ProfileResult] = None,
+ alerts: Optional[List[PerformanceAlert]] = None,
+ ) -> BenchmarkReport:
"""
Generate comprehensive benchmark report.
-
+
Args:
benchmark_summary: Main benchmark results
comparison_result: Optional comparison with RAG
profile_result: Optional profiling results
alerts: Optional performance alerts
-
+
Returns:
Complete benchmark report
"""
report_id = f"fact_benchmark_{int(time.time())}"
-
+
# Generate report sections
sections = []
charts = []
-
+
# Executive Summary
- sections.append(self._create_executive_summary_section(benchmark_summary, comparison_result))
-
+ sections.append(
+ self._create_executive_summary_section(benchmark_summary, comparison_result)
+ )
+
# Performance Metrics Section
sections.append(self._create_performance_metrics_section(benchmark_summary))
-
+
# Target Compliance Section
sections.append(self._create_target_compliance_section(benchmark_summary))
-
+
# Comparison Analysis (if available)
if comparison_result:
sections.append(self._create_comparison_analysis_section(comparison_result))
- charts.append(self.visualizer.create_cost_comparison_chart(comparison_result))
-
+ charts.append(
+ self.visualizer.create_cost_comparison_chart(comparison_result)
+ )
+
# Bottleneck Analysis (if available)
if profile_result:
sections.append(self._create_bottleneck_analysis_section(profile_result))
- charts.append(self.visualizer.create_bottleneck_analysis_chart(profile_result.bottlenecks))
-
+ charts.append(
+ self.visualizer.create_bottleneck_analysis_chart(
+ profile_result.bottlenecks
+ )
+ )
+
# Alerts Section (if available)
if alerts:
sections.append(self._create_alerts_section(alerts))
-
+
# Recommendations Section
- sections.append(self._create_recommendations_section(benchmark_summary, comparison_result, profile_result))
-
+ sections.append(
+ self._create_recommendations_section(
+ benchmark_summary, comparison_result, profile_result
+ )
+ )
+
# Generate overall recommendations
- recommendations = self._generate_overall_recommendations(benchmark_summary, comparison_result, profile_result, alerts)
-
+ recommendations = self._generate_overall_recommendations(
+ benchmark_summary, comparison_result, profile_result, alerts
+ )
+
# Create summary
- summary = self._create_report_summary(benchmark_summary, comparison_result, profile_result)
-
+ summary = self._create_report_summary(
+ benchmark_summary, comparison_result, profile_result
+ )
+
return BenchmarkReport(
report_id=report_id,
title="FACT System Performance Benchmark Report",
@@ -394,25 +477,39 @@ def generate_comprehensive_report(self,
recommendations=recommendations,
raw_data={
"benchmark_summary": asdict(benchmark_summary),
- "comparison_result": asdict(comparison_result) if comparison_result else None,
+ "comparison_result": (
+ asdict(comparison_result) if comparison_result else None
+ ),
"profile_result": asdict(profile_result) if profile_result else None,
- "alerts": [asdict(alert) for alert in alerts] if alerts else None
- }
+ "alerts": [asdict(alert) for alert in alerts] if alerts else None,
+ },
)
-
- def _create_executive_summary_section(self,
- benchmark_summary: BenchmarkSummary,
- comparison_result: Optional[ComparisonResult]) -> ReportSection:
+
+ def _create_executive_summary_section(
+ self,
+ benchmark_summary: BenchmarkSummary,
+ comparison_result: Optional[ComparisonResult],
+ ) -> ReportSection:
"""Create executive summary section."""
# Calculate key metrics
- overall_performance = "EXCELLENT" if benchmark_summary.avg_response_time_ms < 50 else \
- "GOOD" if benchmark_summary.avg_response_time_ms < 100 else \
- "NEEDS_IMPROVEMENT"
-
- cache_performance = "EXCELLENT" if benchmark_summary.cache_hit_rate > 70 else \
- "GOOD" if benchmark_summary.cache_hit_rate > 50 else \
- "NEEDS_IMPROVEMENT"
-
+ overall_performance = (
+ "EXCELLENT"
+ if benchmark_summary.avg_response_time_ms < 50
+ else (
+ "GOOD"
+ if benchmark_summary.avg_response_time_ms < 100
+ else "NEEDS_IMPROVEMENT"
+ )
+ )
+
+ cache_performance = (
+ "EXCELLENT"
+ if benchmark_summary.cache_hit_rate > 70
+ else (
+ "GOOD" if benchmark_summary.cache_hit_rate > 50 else "NEEDS_IMPROVEMENT"
+ )
+ )
+
content = {
"overall_assessment": overall_performance,
"cache_assessment": cache_performance,
@@ -420,32 +517,38 @@ def _create_executive_summary_section(self,
"avg_response_time_ms": benchmark_summary.avg_response_time_ms,
"cache_hit_rate": benchmark_summary.cache_hit_rate,
"error_rate": benchmark_summary.error_rate,
- "throughput_qps": benchmark_summary.throughput_qps
+ "throughput_qps": benchmark_summary.throughput_qps,
},
"target_compliance": {
"hit_latency_target_met": benchmark_summary.hit_latency_target_met,
"miss_latency_target_met": benchmark_summary.miss_latency_target_met,
"cost_reduction_target_met": benchmark_summary.cost_reduction_target_met,
- "cache_hit_rate_target_met": benchmark_summary.cache_hit_rate_target_met
- }
+ "cache_hit_rate_target_met": benchmark_summary.cache_hit_rate_target_met,
+ },
}
-
+
if comparison_result:
content["improvement_summary"] = {
- "latency_improvement": comparison_result.improvement_factors.get("latency", 1.0),
- "cost_savings_percentage": comparison_result.cost_savings.get("percentage", 0.0),
- "recommendation": comparison_result.recommendation
+ "latency_improvement": comparison_result.improvement_factors.get(
+ "latency", 1.0
+ ),
+ "cost_savings_percentage": comparison_result.cost_savings.get(
+ "percentage", 0.0
+ ),
+ "recommendation": comparison_result.recommendation,
}
-
+
return ReportSection(
section_id="executive_summary",
title="Executive Summary",
content_type="metrics",
content=content,
- priority=1
+ priority=1,
)
-
- def _create_performance_metrics_section(self, benchmark_summary: BenchmarkSummary) -> ReportSection:
+
+ def _create_performance_metrics_section(
+ self, benchmark_summary: BenchmarkSummary
+ ) -> ReportSection:
"""Create detailed performance metrics section."""
content = {
"latency_metrics": {
@@ -454,84 +557,106 @@ def _create_performance_metrics_section(self, benchmark_summary: BenchmarkSummar
"maximum_ms": benchmark_summary.max_response_time_ms,
"p50_ms": benchmark_summary.p50_response_time_ms,
"p95_ms": benchmark_summary.p95_response_time_ms,
- "p99_ms": benchmark_summary.p99_response_time_ms
+ "p99_ms": benchmark_summary.p99_response_time_ms,
},
"cache_metrics": {
"hit_rate_percent": benchmark_summary.cache_hit_rate,
"hit_count": benchmark_summary.cache_hits,
"miss_count": benchmark_summary.cache_misses,
"avg_hit_latency_ms": benchmark_summary.avg_hit_latency_ms,
- "avg_miss_latency_ms": benchmark_summary.avg_miss_latency_ms
+ "avg_miss_latency_ms": benchmark_summary.avg_miss_latency_ms,
},
"reliability_metrics": {
- "success_rate_percent": (benchmark_summary.successful_queries / benchmark_summary.total_queries * 100) if benchmark_summary.total_queries > 0 else 0,
+ "success_rate_percent": (
+ (
+ benchmark_summary.successful_queries
+ / benchmark_summary.total_queries
+ * 100
+ )
+ if benchmark_summary.total_queries > 0
+ else 0
+ ),
"error_rate_percent": benchmark_summary.error_rate,
"total_queries": benchmark_summary.total_queries,
- "failed_queries": benchmark_summary.failed_queries
+ "failed_queries": benchmark_summary.failed_queries,
},
"cost_metrics": {
"total_cost_usd": benchmark_summary.total_token_cost,
"avg_cost_per_query_usd": benchmark_summary.avg_token_cost,
"estimated_savings_usd": benchmark_summary.estimated_savings,
- "cost_reduction_percentage": benchmark_summary.cost_reduction_percentage
- }
+ "cost_reduction_percentage": benchmark_summary.cost_reduction_percentage,
+ },
}
-
+
return ReportSection(
section_id="performance_metrics",
title="Detailed Performance Metrics",
content_type="table",
content=content,
- priority=1
+ priority=1,
)
-
- def _create_target_compliance_section(self, benchmark_summary: BenchmarkSummary) -> ReportSection:
+
+ def _create_target_compliance_section(
+ self, benchmark_summary: BenchmarkSummary
+ ) -> ReportSection:
"""Create target compliance analysis section."""
targets = {
"cache_hit_latency": {
"target_ms": 48.0,
"actual_ms": benchmark_summary.avg_hit_latency_ms,
"met": benchmark_summary.hit_latency_target_met,
- "status": "PASS" if benchmark_summary.hit_latency_target_met else "FAIL"
+ "status": (
+ "PASS" if benchmark_summary.hit_latency_target_met else "FAIL"
+ ),
},
"cache_miss_latency": {
"target_ms": 140.0,
"actual_ms": benchmark_summary.avg_miss_latency_ms,
"met": benchmark_summary.miss_latency_target_met,
- "status": "PASS" if benchmark_summary.miss_latency_target_met else "FAIL"
+ "status": (
+ "PASS" if benchmark_summary.miss_latency_target_met else "FAIL"
+ ),
},
"cost_reduction": {
"target_percent": 75.0,
"actual_percent": benchmark_summary.cost_reduction_percentage,
"met": benchmark_summary.cost_reduction_target_met,
- "status": "PASS" if benchmark_summary.cost_reduction_target_met else "FAIL"
+ "status": (
+ "PASS" if benchmark_summary.cost_reduction_target_met else "FAIL"
+ ),
},
"cache_hit_rate": {
"target_percent": 60.0,
"actual_percent": benchmark_summary.cache_hit_rate,
"met": benchmark_summary.cache_hit_rate_target_met,
- "status": "PASS" if benchmark_summary.cache_hit_rate_target_met else "FAIL"
- }
+ "status": (
+ "PASS" if benchmark_summary.cache_hit_rate_target_met else "FAIL"
+ ),
+ },
}
-
+
overall_compliance = all(target["met"] for target in targets.values())
-
+
content = {
"overall_compliance": overall_compliance,
"overall_status": "PASS" if overall_compliance else "FAIL",
"target_details": targets,
- "compliance_score": sum(1 for target in targets.values() if target["met"]) / len(targets) * 100
+ "compliance_score": sum(1 for target in targets.values() if target["met"])
+ / len(targets)
+ * 100,
}
-
+
return ReportSection(
section_id="target_compliance",
title="Performance Target Compliance",
content_type="table",
content=content,
- priority=1
+ priority=1,
)
-
- def _create_comparison_analysis_section(self, comparison_result: ComparisonResult) -> ReportSection:
+
+ def _create_comparison_analysis_section(
+ self, comparison_result: ComparisonResult
+ ) -> ReportSection:
"""Create RAG comparison analysis section."""
content = {
"performance_improvements": comparison_result.improvement_factors,
@@ -539,121 +664,156 @@ def _create_comparison_analysis_section(self, comparison_result: ComparisonResul
"fact_metrics": asdict(comparison_result.fact_metrics),
"rag_metrics": asdict(comparison_result.rag_metrics),
"detailed_analysis": comparison_result.performance_analysis,
- "recommendation": comparison_result.recommendation
+ "recommendation": comparison_result.recommendation,
}
-
+
return ReportSection(
section_id="comparison_analysis",
title="FACT vs Traditional RAG Comparison",
content_type="table",
content=content,
- priority=1
+ priority=1,
)
-
- def _create_bottleneck_analysis_section(self, profile_result: ProfileResult) -> ReportSection:
+
+ def _create_bottleneck_analysis_section(
+ self, profile_result: ProfileResult
+ ) -> ReportSection:
"""Create bottleneck analysis section."""
- critical_bottlenecks = [b for b in profile_result.bottlenecks if b.severity == "critical"]
- high_bottlenecks = [b for b in profile_result.bottlenecks if b.severity == "high"]
-
+ critical_bottlenecks = [
+ b for b in profile_result.bottlenecks if b.severity == "critical"
+ ]
+ high_bottlenecks = [
+ b for b in profile_result.bottlenecks if b.severity == "high"
+ ]
+
content = {
"total_bottlenecks": len(profile_result.bottlenecks),
"critical_count": len(critical_bottlenecks),
"high_count": len(high_bottlenecks),
"bottleneck_details": [asdict(b) for b in profile_result.bottlenecks],
"performance_summary": profile_result.performance_summary,
- "optimization_recommendations": profile_result.optimization_recommendations
+ "optimization_recommendations": profile_result.optimization_recommendations,
}
-
+
return ReportSection(
section_id="bottleneck_analysis",
title="Performance Bottleneck Analysis",
content_type="table",
content=content,
- priority=2
+ priority=2,
)
-
+
def _create_alerts_section(self, alerts: List[PerformanceAlert]) -> ReportSection:
"""Create performance alerts section."""
critical_alerts = [a for a in alerts if a.severity == "critical"]
warning_alerts = [a for a in alerts if a.severity == "warning"]
-
+
content = {
"total_alerts": len(alerts),
"critical_count": len(critical_alerts),
"warning_count": len(warning_alerts),
"alert_details": [asdict(alert) for alert in alerts],
- "urgent_actions_required": len(critical_alerts) > 0
+ "urgent_actions_required": len(critical_alerts) > 0,
}
-
+
return ReportSection(
section_id="performance_alerts",
title="Performance Alerts",
content_type="table",
content=content,
- priority=1 if critical_alerts else 2
+ priority=1 if critical_alerts else 2,
)
-
- def _create_recommendations_section(self,
- benchmark_summary: BenchmarkSummary,
- comparison_result: Optional[ComparisonResult],
- profile_result: Optional[ProfileResult]) -> ReportSection:
+
+ def _create_recommendations_section(
+ self,
+ benchmark_summary: BenchmarkSummary,
+ comparison_result: Optional[ComparisonResult],
+ profile_result: Optional[ProfileResult],
+ ) -> ReportSection:
"""Create recommendations section."""
recommendations = []
-
+
# Performance-based recommendations
if benchmark_summary.avg_response_time_ms > 100:
- recommendations.append("Investigate high response times - consider caching optimization")
-
+ recommendations.append(
+ "Investigate high response times - consider caching optimization"
+ )
+
if benchmark_summary.cache_hit_rate < 60:
- recommendations.append("Improve cache hit rate through better warming strategies")
-
+ recommendations.append(
+ "Improve cache hit rate through better warming strategies"
+ )
+
if benchmark_summary.error_rate > 5:
- recommendations.append("Address system reliability issues - high error rate detected")
-
+ recommendations.append(
+ "Address system reliability issues - high error rate detected"
+ )
+
# Comparison-based recommendations
- if comparison_result and comparison_result.improvement_factors.get("latency", 1) < 2:
- recommendations.append("Consider additional optimization - latency improvement below target")
-
+ if (
+ comparison_result
+ and comparison_result.improvement_factors.get("latency", 1) < 2
+ ):
+ recommendations.append(
+ "Consider additional optimization - latency improvement below target"
+ )
+
# Bottleneck-based recommendations
if profile_result:
recommendations.extend(profile_result.optimization_recommendations[:3])
-
+
content = {
"priority_recommendations": recommendations[:5],
"all_recommendations": recommendations,
- "implementation_priority": "HIGH" if benchmark_summary.error_rate > 10 else "MEDIUM"
+ "implementation_priority": (
+ "HIGH" if benchmark_summary.error_rate > 10 else "MEDIUM"
+ ),
}
-
+
return ReportSection(
section_id="recommendations",
title="Optimization Recommendations",
content_type="text",
content=content,
- priority=1
+ priority=1,
)
-
- def _create_report_summary(self,
- benchmark_summary: BenchmarkSummary,
- comparison_result: Optional[ComparisonResult],
- profile_result: Optional[ProfileResult]) -> Dict[str, Any]:
+
+ def _create_report_summary(
+ self,
+ benchmark_summary: BenchmarkSummary,
+ comparison_result: Optional[ComparisonResult],
+ profile_result: Optional[ProfileResult],
+ ) -> Dict[str, Any]:
"""Create overall report summary."""
summary = {
"benchmark_execution": {
"total_queries": benchmark_summary.total_queries,
"execution_time_seconds": benchmark_summary.execution_time_seconds,
- "success_rate": (benchmark_summary.successful_queries / benchmark_summary.total_queries * 100) if benchmark_summary.total_queries > 0 else 0
+ "success_rate": (
+ (
+ benchmark_summary.successful_queries
+ / benchmark_summary.total_queries
+ * 100
+ )
+ if benchmark_summary.total_queries > 0
+ else 0
+ ),
},
"performance_grade": self._calculate_performance_grade(benchmark_summary),
- "key_findings": self._extract_key_findings(benchmark_summary, comparison_result, profile_result),
- "action_required": self._determine_action_required(benchmark_summary, profile_result)
+ "key_findings": self._extract_key_findings(
+ benchmark_summary, comparison_result, profile_result
+ ),
+ "action_required": self._determine_action_required(
+ benchmark_summary, profile_result
+ ),
}
-
+
return summary
-
+
def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> str:
"""Calculate overall performance grade."""
score = 0
-
+
# Latency score (40% weight)
if benchmark_summary.avg_response_time_ms <= 50:
score += 40
@@ -663,7 +823,7 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s
score += 20
else:
score += 10
-
+
# Cache hit rate score (30% weight)
if benchmark_summary.cache_hit_rate >= 70:
score += 30
@@ -673,7 +833,7 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s
score += 20
else:
score += 10
-
+
# Error rate score (20% weight)
if benchmark_summary.error_rate <= 1:
score += 20
@@ -683,7 +843,7 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s
score += 10
else:
score += 5
-
+
# Cost efficiency score (10% weight)
if benchmark_summary.cost_reduction_percentage >= 80:
score += 10
@@ -693,7 +853,7 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s
score += 6
else:
score += 3
-
+
# Convert to grade
if score >= 90:
return "A+"
@@ -709,14 +869,16 @@ def _calculate_performance_grade(self, benchmark_summary: BenchmarkSummary) -> s
return "C"
else:
return "D"
-
- def _extract_key_findings(self,
- benchmark_summary: BenchmarkSummary,
- comparison_result: Optional[ComparisonResult],
- profile_result: Optional[ProfileResult]) -> List[str]:
+
+ def _extract_key_findings(
+ self,
+ benchmark_summary: BenchmarkSummary,
+ comparison_result: Optional[ComparisonResult],
+ profile_result: Optional[ProfileResult],
+ ) -> List[str]:
"""Extract key findings from benchmark results."""
findings = []
-
+
# Performance findings
if benchmark_summary.avg_response_time_ms <= 48:
findings.append("Excellent response time performance achieved")
@@ -724,7 +886,7 @@ def _extract_key_findings(self,
findings.append("Good response time performance within targets")
else:
findings.append("Response time performance needs improvement")
-
+
# Cache findings
if benchmark_summary.cache_hit_rate >= 70:
findings.append("Cache performance excellent with high hit rates")
@@ -732,100 +894,127 @@ def _extract_key_findings(self,
findings.append("Cache performance good, meeting minimum targets")
else:
findings.append("Cache performance below optimal, requires attention")
-
+
# Cost findings
if benchmark_summary.cost_reduction_percentage >= 80:
- findings.append("Significant cost savings achieved compared to traditional methods")
+ findings.append(
+ "Significant cost savings achieved compared to traditional methods"
+ )
elif benchmark_summary.cost_reduction_percentage >= 60:
findings.append("Good cost efficiency with substantial savings")
else:
findings.append("Cost efficiency below expectations")
-
+
# Comparison findings
if comparison_result:
- latency_improvement = comparison_result.improvement_factors.get("latency", 1.0)
+ latency_improvement = comparison_result.improvement_factors.get(
+ "latency", 1.0
+ )
if latency_improvement >= 4.0:
findings.append("Outstanding latency improvement over traditional RAG")
elif latency_improvement >= 2.0:
findings.append("Significant latency improvement demonstrated")
else:
findings.append("Latency improvement present but below expectations")
-
+
return findings[:5] # Return top 5 findings
-
- def _determine_action_required(self,
- benchmark_summary: BenchmarkSummary,
- profile_result: Optional[ProfileResult]) -> str:
+
+ def _determine_action_required(
+ self,
+ benchmark_summary: BenchmarkSummary,
+ profile_result: Optional[ProfileResult],
+ ) -> str:
"""Determine required actions based on results."""
critical_issues = []
-
+
if benchmark_summary.error_rate > 10:
critical_issues.append("High error rate")
-
+
if benchmark_summary.avg_response_time_ms > 200:
critical_issues.append("Poor response times")
-
+
if benchmark_summary.cache_hit_rate < 40:
critical_issues.append("Low cache efficiency")
-
+
if profile_result:
- critical_bottlenecks = [b for b in profile_result.bottlenecks if b.severity == "critical"]
+ critical_bottlenecks = [
+ b for b in profile_result.bottlenecks if b.severity == "critical"
+ ]
if critical_bottlenecks:
critical_issues.append("Critical performance bottlenecks")
-
+
if critical_issues:
return f"IMMEDIATE ACTION REQUIRED: {', '.join(critical_issues)}"
- elif benchmark_summary.avg_response_time_ms > 100 or benchmark_summary.cache_hit_rate < 60:
+ elif (
+ benchmark_summary.avg_response_time_ms > 100
+ or benchmark_summary.cache_hit_rate < 60
+ ):
return "OPTIMIZATION RECOMMENDED: Performance improvements needed"
else:
return "MONITORING: System performing within acceptable parameters"
-
- def _generate_overall_recommendations(self,
- benchmark_summary: BenchmarkSummary,
- comparison_result: Optional[ComparisonResult],
- profile_result: Optional[ProfileResult],
- alerts: Optional[List[PerformanceAlert]]) -> List[str]:
+
+ def _generate_overall_recommendations(
+ self,
+ benchmark_summary: BenchmarkSummary,
+ comparison_result: Optional[ComparisonResult],
+ profile_result: Optional[ProfileResult],
+ alerts: Optional[List[PerformanceAlert]],
+ ) -> List[str]:
"""Generate overall optimization recommendations."""
recommendations = []
-
+
# Critical alerts take priority
if alerts:
critical_alerts = [a for a in alerts if a.severity == "critical"]
if critical_alerts:
- recommendations.append("Address critical performance alerts immediately")
-
+ recommendations.append(
+ "Address critical performance alerts immediately"
+ )
+
# Performance-based recommendations
if benchmark_summary.avg_response_time_ms > 100:
- recommendations.append("Optimize response time through caching and query optimization")
-
+ recommendations.append(
+ "Optimize response time through caching and query optimization"
+ )
+
if benchmark_summary.cache_hit_rate < 60:
- recommendations.append("Implement cache warming and improve hit rate strategies")
-
+ recommendations.append(
+ "Implement cache warming and improve hit rate strategies"
+ )
+
if benchmark_summary.error_rate > 5:
recommendations.append("Investigate and resolve system reliability issues")
-
+
# Bottleneck recommendations
if profile_result and profile_result.bottlenecks:
- top_bottleneck = max(profile_result.bottlenecks, key=lambda b: b.impact_percentage)
- recommendations.append(f"Address {top_bottleneck.component.lower()} bottleneck")
-
+ top_bottleneck = max(
+ profile_result.bottlenecks, key=lambda b: b.impact_percentage
+ )
+ recommendations.append(
+ f"Address {top_bottleneck.component.lower()} bottleneck"
+ )
+
# Cost optimization
if benchmark_summary.cost_reduction_percentage < 70:
- recommendations.append("Improve cost efficiency through better caching strategies")
-
+ recommendations.append(
+ "Improve cost efficiency through better caching strategies"
+ )
+
# General recommendations
- recommendations.extend([
- "Implement continuous performance monitoring",
- "Establish performance regression testing",
- "Document optimization strategies and results"
- ])
-
+ recommendations.extend(
+ [
+ "Implement continuous performance monitoring",
+ "Establish performance regression testing",
+ "Document optimization strategies and results",
+ ]
+ )
+
return recommendations[:8] # Return top 8 recommendations
-
+
def export_report_json(self, report: BenchmarkReport) -> str:
"""Export report as JSON."""
return json.dumps(asdict(report), indent=2, default=str)
-
+
def export_report_summary_text(self, report: BenchmarkReport) -> str:
"""Export report summary as text."""
lines = []
@@ -835,19 +1024,25 @@ def export_report_summary_text(self, report: BenchmarkReport) -> str:
lines.append(f"Generated: {datetime.fromtimestamp(report.generated_at)}")
lines.append(f"Report ID: {report.report_id}")
lines.append("")
-
+
# Executive Summary
- exec_summary = next((s for s in report.sections if s.section_id == "executive_summary"), None)
+ exec_summary = next(
+ (s for s in report.sections if s.section_id == "executive_summary"), None
+ )
if exec_summary:
lines.append("EXECUTIVE SUMMARY")
lines.append("-" * 20)
content = exec_summary.content
lines.append(f"Overall Performance: {content['overall_assessment']}")
lines.append(f"Cache Performance: {content['cache_assessment']}")
- lines.append(f"Average Response Time: {content['key_metrics']['avg_response_time_ms']:.1f}ms")
- lines.append(f"Cache Hit Rate: {content['key_metrics']['cache_hit_rate']:.1f}%")
+ lines.append(
+ f"Average Response Time: {content['key_metrics']['avg_response_time_ms']:.1f}ms"
+ )
+ lines.append(
+ f"Cache Hit Rate: {content['key_metrics']['cache_hit_rate']:.1f}%"
+ )
lines.append("")
-
+
# Key Recommendations
if report.recommendations:
lines.append("KEY RECOMMENDATIONS")
@@ -855,10 +1050,10 @@ def export_report_summary_text(self, report: BenchmarkReport) -> str:
for i, rec in enumerate(report.recommendations[:5], 1):
lines.append(f"{i}. {rec}")
lines.append("")
-
+
# Performance Grade
if "performance_grade" in report.summary:
lines.append(f"PERFORMANCE GRADE: {report.summary['performance_grade']}")
lines.append("")
-
- return "\n".join(lines)
\ No newline at end of file
+
+ return "\n".join(lines)
diff --git a/src/cache/validation.py b/src/cache/validation.py
index 0e1796c..7cd9590 100644
--- a/src/cache/validation.py
+++ b/src/cache/validation.py
@@ -582,7 +582,7 @@ async def auto_repair_cache(self, validation_result: ValidationResult) -> Dict[s
repair_summary["warnings_addressed"] += 1
logger.info("Auto-repair completed", **repair_summary)
- return freed_space
+ return repair_summary
except Exception as e:
logger.error("Auto-repair failed", error=str(e))
diff --git a/src/core/agentic_flow.py b/src/core/agentic_flow.py
index 1ffc727..eaca05a 100644
--- a/src/core/agentic_flow.py
+++ b/src/core/agentic_flow.py
@@ -6,7 +6,7 @@
The FACT algorithm combines:
- Fast cache-first query resolution
-- Access pattern optimization
+- Access pattern optimization
- Caching strategy adaptation
- Token-efficient LLM interactions
"""
@@ -32,6 +32,7 @@
@dataclass
class FACTQueryContext:
"""Context for FACT query processing."""
+
query_id: str
user_query: str
query_hash: str
@@ -39,7 +40,7 @@ class FACTQueryContext:
cache_mode: str = "read"
priority: float = 1.0
metadata: Dict[str, Any] = None
-
+
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
@@ -48,6 +49,7 @@ def __post_init__(self):
@dataclass
class FACTResponse:
"""FACT algorithm response with performance metrics."""
+
content: str
cache_hit: bool
execution_time_ms: float
@@ -57,7 +59,7 @@ class FACTResponse:
strategy_used: str
query_context: FACTQueryContext
performance_metrics: Dict[str, Any] = None
-
+
def __post_init__(self):
if self.performance_metrics is None:
self.performance_metrics = {}
@@ -66,22 +68,22 @@ def __post_init__(self):
class FACTAlgorithm:
"""
Core FACT Algorithm Implementation
-
+
Implements the Fast Access Caching Technology algorithm that optimizes
LLM interactions through intelligent caching, access pattern analysis,
and adaptive strategy selection.
-
+
Key Components:
1. Cache-First Query Resolution
- 2. Token-Efficient Content Processing
+ 2. Token-Efficient Content Processing
3. Adaptive Caching Strategies
4. Performance Monitoring and Optimization
"""
-
+
def __init__(self, config: Optional[Config] = None):
"""
Initialize FACT algorithm with configuration.
-
+
Args:
config: Optional configuration instance
"""
@@ -91,20 +93,20 @@ def __init__(self, config: Optional[Config] = None):
self.metrics_collector = get_metrics_collector()
self.tool_registry = get_tool_registry()
self._initialized = False
-
+
# FACT algorithm parameters
self.cache_hit_target_ms = 50 # Target latency for cache hits
- self.cache_miss_target_ms = 140 # Target latency for cache misses
+ self.cache_miss_target_ms = 140 # Target latency for cache misses
self.min_cache_tokens = 500 # Minimum tokens for caching
self.token_efficiency_target = 100.0 # Tokens per KB target
-
+
logger.info("FACT algorithm initialized")
-
+
async def initialize(self) -> None:
"""Initialize FACT algorithm components."""
if self._initialized:
return
-
+
try:
# Initialize cache manager with FACT-specific configuration
cache_config = {
@@ -113,132 +115,144 @@ async def initialize(self) -> None:
"max_size": "10MB",
"ttl_seconds": 3600, # 1 hour TTL
"hit_target_ms": self.cache_hit_target_ms,
- "miss_target_ms": self.cache_miss_target_ms
+ "miss_target_ms": self.cache_miss_target_ms,
}
-
+
self.cache_manager = get_cache_manager(cache_config)
-
+
# Initialize cache optimizer with adaptive strategy
self.cache_optimizer = get_cache_optimizer(CacheStrategy.ADAPTIVE)
-
+
# Start background optimization
asyncio.create_task(self._run_background_optimization())
-
+
self._initialized = True
logger.info("FACT algorithm initialization completed")
-
+
except Exception as e:
logger.error("FACT algorithm initialization failed", error=str(e))
raise FACTError(f"FACT initialization failed: {e}")
-
- async def process_query(self, user_query: str, context: Optional[Dict[str, Any]] = None) -> FACTResponse:
+
+ async def process_query(
+ self, user_query: str, context: Optional[Dict[str, Any]] = None
+ ) -> FACTResponse:
"""
Process query using FACT algorithm.
-
+
This is the main entry point for the FACT algorithm, implementing:
1. Query normalization and hashing
2. Cache-first lookup with performance tracking
3. LLM fallback with tool integration
4. Response caching with strategy optimization
5. Performance metrics collection
-
+
Args:
user_query: User's natural language query
context: Optional context for query processing
-
+
Returns:
FACTResponse with content and performance metrics
"""
if not self._initialized:
await self.initialize()
-
+
start_time = time.perf_counter()
-
+
# Create query context
query_context = self._create_query_context(user_query, context)
-
+
try:
- logger.info("FACT query processing started",
- query_id=query_context.query_id,
- query_hash=query_context.query_hash[:16])
-
+ logger.info(
+ "FACT query processing started",
+ query_id=query_context.query_id,
+ query_hash=query_context.query_hash[:16],
+ )
+
# Phase 1: Cache-first lookup
cached_response = await self._cache_lookup_phase(query_context)
if cached_response:
return cached_response
-
+
# Phase 2: LLM processing with tool integration
llm_response = await self._llm_processing_phase(query_context)
-
+
# Phase 3: Cache storage and optimization
await self._cache_storage_phase(query_context, llm_response)
-
+
# Record performance metrics
execution_time = (time.perf_counter() - start_time) * 1000
self._record_query_metrics(query_context, llm_response, execution_time)
-
+
return llm_response
-
+
except Exception as e:
execution_time = (time.perf_counter() - start_time) * 1000
- logger.error("FACT query processing failed",
- query_id=query_context.query_id,
- error=str(e),
- execution_time_ms=execution_time)
-
+ logger.error(
+ "FACT query processing failed",
+ query_id=query_context.query_id,
+ error=str(e),
+ execution_time_ms=execution_time,
+ )
+
# Record failure metrics
self.metrics_collector.record_tool_execution(
tool_name="fact_algorithm",
success=False,
execution_time=execution_time,
error_type=type(e).__name__,
- metadata={"query_id": query_context.query_id}
+ metadata={"query_id": query_context.query_id},
)
-
+
raise FACTError(f"FACT query processing failed: {e}")
-
- def _create_query_context(self, user_query: str, context: Optional[Dict[str, Any]]) -> FACTQueryContext:
+
+ def _create_query_context(
+ self, user_query: str, context: Optional[Dict[str, Any]]
+ ) -> FACTQueryContext:
"""Create query context with normalization and hashing."""
# Normalize query for consistent hashing
normalized_query = user_query.strip().lower()
-
+
# Generate deterministic hash
query_hash = self.cache_manager.generate_hash(normalized_query)
-
+
# Create unique query ID
query_id = f"fact_{int(time.time() * 1000)}_{query_hash[:8]}"
-
+
return FACTQueryContext(
query_id=query_id,
user_query=user_query,
query_hash=query_hash,
timestamp=time.time(),
- metadata=context or {}
+ metadata=context or {},
)
-
- async def _cache_lookup_phase(self, query_context: FACTQueryContext) -> Optional[FACTResponse]:
+
+ async def _cache_lookup_phase(
+ self, query_context: FACTQueryContext
+ ) -> Optional[FACTResponse]:
"""
Phase 1: Cache-first lookup with performance tracking.
-
+
Implements fast cache lookup with:
- Performance monitoring
- Cache hit optimization
- Access pattern tracking
"""
cache_start = time.perf_counter()
-
+
try:
# Attempt cache lookup
cached_entry = self.cache_manager.get(query_context.query_hash)
-
+
cache_latency = (time.perf_counter() - cache_start) * 1000
-
+
if cached_entry:
- logger.info("FACT cache hit",
- query_id=query_context.query_id,
- cache_latency_ms=cache_latency,
- access_count=cached_entry.access_count)
-
+ logger.info(
+ "FACT cache hit",
+ query_id=query_context.query_id,
+ cache_latency_ms=cache_latency,
+ access_count=cached_entry.access_count,
+ )
+
# Create cache hit response
response = FACTResponse(
content=cached_entry.content,
@@ -252,10 +266,10 @@ async def _cache_lookup_phase(self, query_context: FACTQueryContext) -> Optional
performance_metrics={
"cache_latency_ms": cache_latency,
"access_count": cached_entry.access_count,
- "cache_age_seconds": time.time() - cached_entry.created_at
- }
+ "cache_age_seconds": time.time() - cached_entry.created_at,
+ },
)
-
+
# Record cache hit metrics
self.metrics_collector.record_tool_execution(
tool_name="fact_cache_hit",
@@ -264,59 +278,66 @@ async def _cache_lookup_phase(self, query_context: FACTQueryContext) -> Optional
metadata={
"query_id": query_context.query_id,
"token_count": cached_entry.token_count,
- "access_count": cached_entry.access_count
- }
+ "access_count": cached_entry.access_count,
+ },
)
-
+
return response
else:
- logger.info("FACT cache miss",
- query_id=query_context.query_id,
- cache_latency_ms=cache_latency)
-
+ logger.info(
+ "FACT cache miss",
+ query_id=query_context.query_id,
+ cache_latency_ms=cache_latency,
+ )
+
# Record cache miss metrics
self.metrics_collector.record_tool_execution(
tool_name="fact_cache_miss",
success=True,
execution_time=cache_latency,
- metadata={"query_id": query_context.query_id}
+ metadata={"query_id": query_context.query_id},
)
-
+
return None
-
+
except Exception as e:
cache_latency = (time.perf_counter() - cache_start) * 1000
- logger.warning("Cache lookup failed",
- query_id=query_context.query_id,
- error=str(e),
- cache_latency_ms=cache_latency)
+ logger.warning(
+ "Cache lookup failed",
+ query_id=query_context.query_id,
+ error=str(e),
+ cache_latency_ms=cache_latency,
+ )
return None
-
- async def _llm_processing_phase(self, query_context: FACTQueryContext) -> FACTResponse:
+
+ async def _llm_processing_phase(
+ self, query_context: FACTQueryContext
+ ) -> FACTResponse:
"""
Phase 2: LLM processing with tool integration.
-
+
Implements intelligent LLM interaction with:
- Tool execution optimization
- Token efficiency monitoring
- Response quality assessment
"""
llm_start = time.perf_counter()
-
+
try:
# Get FACT driver for LLM processing
from .driver import get_driver
+
driver = await get_driver(self.config)
-
+
# Process query through driver
response_content = await driver.process_query(query_context.user_query)
-
+
llm_latency = (time.perf_counter() - llm_start) * 1000
-
+
# Calculate token count and efficiency
token_count = CacheEntry._count_tokens(response_content)
cache_efficiency = self._calculate_content_efficiency(response_content)
-
+
# Create LLM response
response = FACTResponse(
content=response_content,
@@ -330,43 +351,51 @@ async def _llm_processing_phase(self, query_context: FACTQueryContext) -> FACTRe
performance_metrics={
"llm_latency_ms": llm_latency,
"token_count": token_count,
- "cache_efficiency": cache_efficiency
- }
+ "cache_efficiency": cache_efficiency,
+ },
)
-
- logger.info("FACT LLM processing completed",
- query_id=query_context.query_id,
- llm_latency_ms=llm_latency,
- token_count=token_count,
- cache_efficiency=cache_efficiency)
-
+
+ logger.info(
+ "FACT LLM processing completed",
+ query_id=query_context.query_id,
+ llm_latency_ms=llm_latency,
+ token_count=token_count,
+ cache_efficiency=cache_efficiency,
+ )
+
return response
-
+
except Exception as e:
llm_latency = (time.perf_counter() - llm_start) * 1000
- logger.error("LLM processing failed",
- query_id=query_context.query_id,
- error=str(e),
- llm_latency_ms=llm_latency)
+ logger.error(
+ "LLM processing failed",
+ query_id=query_context.query_id,
+ error=str(e),
+ llm_latency_ms=llm_latency,
+ )
raise ToolExecutionError(f"LLM processing failed: {e}")
-
- async def _cache_storage_phase(self, query_context: FACTQueryContext, response: FACTResponse) -> None:
+
+ async def _cache_storage_phase(
+ self, query_context: FACTQueryContext, response: FACTResponse
+ ) -> None:
"""
Phase 3: Cache storage with strategy optimization.
-
+
Implements intelligent caching with:
- Content quality assessment
- Storage strategy optimization
- Cache efficiency monitoring
"""
if not response.content or response.token_count < self.min_cache_tokens:
- logger.debug("Content not cached - insufficient tokens",
- query_id=query_context.query_id,
- token_count=response.token_count)
+ logger.debug(
+ "Content not cached - insufficient tokens",
+ query_id=query_context.query_id,
+ token_count=response.token_count,
+ )
return
-
+
storage_start = time.perf_counter()
-
+
try:
# Check if content should be cached based on strategy
should_cache = self.cache_optimizer.should_cache_content(
@@ -374,21 +403,25 @@ async def _cache_storage_phase(self, query_context: FACTQueryContext, response:
{
"query": query_context.user_query,
"token_count": response.token_count,
- "execution_time": response.execution_time_ms
- }
+ "execution_time": response.execution_time_ms,
+ },
)
-
+
if should_cache:
# Store in cache
- cache_entry = self.cache_manager.store(query_context.query_hash, response.content)
-
+ cache_entry = self.cache_manager.store(
+ query_context.query_hash, response.content
+ )
+
storage_latency = (time.perf_counter() - storage_start) * 1000
-
- logger.info("Content cached successfully",
- query_id=query_context.query_id,
- token_count=cache_entry.token_count,
- storage_latency_ms=storage_latency)
-
+
+ logger.info(
+ "Content cached successfully",
+ query_id=query_context.query_id,
+ token_count=cache_entry.token_count,
+ storage_latency_ms=storage_latency,
+ )
+
# Record cache storage metrics
self.metrics_collector.record_tool_execution(
tool_name="fact_cache_store",
@@ -396,41 +429,52 @@ async def _cache_storage_phase(self, query_context: FACTQueryContext, response:
execution_time=storage_latency,
metadata={
"query_id": query_context.query_id,
- "token_count": cache_entry.token_count
- }
+ "token_count": cache_entry.token_count,
+ },
)
else:
- logger.debug("Content not cached - strategy decision",
- query_id=query_context.query_id)
-
+ logger.debug(
+ "Content not cached - strategy decision",
+ query_id=query_context.query_id,
+ )
+
except CacheError as e:
storage_latency = (time.perf_counter() - storage_start) * 1000
- logger.warning("Cache storage failed",
- query_id=query_context.query_id,
- error=str(e),
- storage_latency_ms=storage_latency)
-
+ logger.warning(
+ "Cache storage failed",
+ query_id=query_context.query_id,
+ error=str(e),
+ storage_latency_ms=storage_latency,
+ )
+
def _calculate_cache_efficiency(self, cache_entry: CacheEntry) -> float:
"""Calculate cache efficiency score for an entry."""
- content_size_kb = len(cache_entry.content.encode('utf-8')) / 1024
+ content_size_kb = len(cache_entry.content.encode("utf-8")) / 1024
return cache_entry.token_count / content_size_kb if content_size_kb > 0 else 0.0
-
+
def _calculate_content_efficiency(self, content: str) -> float:
"""Calculate content efficiency for new content."""
if not content:
return 0.0
-
+
token_count = CacheEntry._count_tokens(content)
- content_size_kb = len(content.encode('utf-8')) / 1024
+ content_size_kb = len(content.encode("utf-8")) / 1024
return token_count / content_size_kb if content_size_kb > 0 else 0.0
-
+
def _count_tool_calls(self, content: str) -> int:
"""Estimate number of tool calls from response content."""
# Simple heuristic - count mentions of tool execution patterns
tool_indicators = ["executed", "query", "result", "data"]
- return sum(1 for indicator in tool_indicators if indicator.lower() in content.lower())
-
- def _record_query_metrics(self, query_context: FACTQueryContext, response: FACTResponse, execution_time: float) -> None:
+ return sum(
+ 1 for indicator in tool_indicators if indicator.lower() in content.lower()
+ )
+
+ def _record_query_metrics(
+ self,
+ query_context: FACTQueryContext,
+ response: FACTResponse,
+ execution_time: float,
+ ) -> None:
"""Record comprehensive query processing metrics."""
self.metrics_collector.record_tool_execution(
tool_name="fact_query_complete",
@@ -442,42 +486,45 @@ def _record_query_metrics(self, query_context: FACTQueryContext, response: FACTR
"token_count": response.token_count,
"cache_efficiency": response.cache_efficiency,
"strategy_used": response.strategy_used,
- "tool_calls_count": response.tool_calls_count
- }
+ "tool_calls_count": response.tool_calls_count,
+ },
)
-
+
async def _run_background_optimization(self) -> None:
"""Run background cache optimization."""
logger.info("Starting FACT background optimization")
-
+
while True:
try:
await asyncio.sleep(300) # Run every 5 minutes
-
+
if self.cache_manager and self.cache_optimizer:
- optimization_results = await self.cache_optimizer.optimize_cache(self.cache_manager)
-
- logger.debug("Background optimization completed",
- **optimization_results)
-
+ optimization_results = await self.cache_optimizer.optimize_cache(
+ self.cache_manager
+ )
+
+ logger.debug(
+ "Background optimization completed", **optimization_results
+ )
+
except asyncio.CancelledError:
logger.info("Background optimization cancelled")
break
except Exception as e:
logger.error("Background optimization failed", error=str(e))
# Continue running despite errors
-
+
def get_algorithm_metrics(self) -> Dict[str, Any]:
"""Get FACT algorithm performance metrics."""
if not self.cache_manager:
return {}
-
+
# Get cache metrics
cache_metrics = self.cache_manager.get_metrics()
-
+
# Get system metrics
system_metrics = self.metrics_collector.get_system_metrics()
-
+
# Calculate FACT-specific metrics
fact_metrics = {
"algorithm": "FACT",
@@ -494,10 +541,10 @@ def get_algorithm_metrics(self) -> Dict[str, Any]:
"cache_hit_target_ms": self.cache_hit_target_ms,
"cache_miss_target_ms": self.cache_miss_target_ms,
"min_cache_tokens": self.min_cache_tokens,
- "token_efficiency_target": self.token_efficiency_target
- }
+ "token_efficiency_target": self.token_efficiency_target,
+ },
}
-
+
return fact_metrics
@@ -508,32 +555,34 @@ def get_algorithm_metrics(self) -> Dict[str, Any]:
async def get_fact_algorithm(config: Optional[Config] = None) -> FACTAlgorithm:
"""
Get or create the global FACT algorithm instance.
-
+
Args:
config: Optional configuration
-
+
Returns:
Initialized FACT algorithm instance
"""
global _fact_algorithm
-
+
if _fact_algorithm is None:
_fact_algorithm = FACTAlgorithm(config)
await _fact_algorithm.initialize()
-
+
return _fact_algorithm
-async def process_fact_query(user_query: str, context: Optional[Dict[str, Any]] = None) -> FACTResponse:
+async def process_fact_query(
+ user_query: str, context: Optional[Dict[str, Any]] = None
+) -> FACTResponse:
"""
Process a query using the FACT algorithm.
-
+
Args:
user_query: User's natural language query
context: Optional query context
-
+
Returns:
FACT response with performance metrics
"""
algorithm = await get_fact_algorithm()
- return await algorithm.process_query(user_query, context)
\ No newline at end of file
+ return await algorithm.process_query(user_query, context)
diff --git a/src/core/cli.py b/src/core/cli.py
index 4d5ba5b..16b41ff 100644
--- a/src/core/cli.py
+++ b/src/core/cli.py
@@ -16,128 +16,129 @@
from .config import get_config
from .errors import FACTError, create_user_friendly_message
-
logger = structlog.get_logger(__name__)
class FACTCLi:
"""
Interactive command-line interface for the FACT system.
-
+
Provides user interaction, query processing, and system management
through a console-based interface.
"""
-
+
def __init__(self):
"""Initialize CLI interface."""
self.driver = None
self.running = False
-
+
async def initialize(self) -> None:
"""Initialize the CLI and underlying FACT system."""
try:
logger.info("Initializing FACT CLI")
self.driver = await get_driver()
logger.info("FACT CLI initialized successfully")
-
+
except Exception as e:
logger.error("CLI initialization failed", error=str(e))
print(f"ā Initialization failed: {create_user_friendly_message(e)}")
sys.exit(1)
-
+
async def run_interactive(self) -> None:
"""
Run the interactive CLI loop.
-
+
Handles user input, query processing, and graceful shutdown.
"""
if not self.driver:
await self.initialize()
-
+
self.running = True
-
+
# Print welcome message
print("š FACT System - Fast-Access Cached Tools")
- print("š” Ask questions about financial data. Type 'help' for commands or Ctrl+C to exit.")
+ print(
+ "š” Ask questions about financial data. Type 'help' for commands or Ctrl+C to exit."
+ )
print()
-
+
# Show system status
await self._show_status()
-
+
while self.running:
try:
# Get user input
user_input = input("\nš¬ > ").strip()
-
+
if not user_input:
continue
-
+
# Handle special commands
if await self._handle_command(user_input):
continue
-
+
# Process query through FACT system
print("š¤ Processing your query...")
-
+
response = await self.driver.process_query(user_input)
-
+
# Display response
print("\nš Response:")
print(response)
-
+
except (EOFError, KeyboardInterrupt):
print("\nš Goodbye!")
break
-
+
except Exception as e:
error_message = create_user_friendly_message(e)
print(f"\nā Error: {error_message}")
logger.error("CLI error", error=str(e))
-
+
await self._shutdown()
-
+
async def _handle_command(self, input_text: str) -> bool:
"""
Handle special CLI commands.
-
+
Args:
input_text: User input to check for commands
-
+
Returns:
True if input was a command, False otherwise
"""
command = input_text.lower().strip()
-
- if command in ['help', '?']:
+
+ if command in ["help", "?"]:
await self._show_help()
return True
-
- elif command in ['status', 'stats']:
+
+ elif command in ["status", "stats"]:
await self._show_status()
return True
-
- elif command in ['tools', 'list-tools']:
+
+ elif command in ["tools", "list-tools"]:
await self._show_tools()
return True
-
- elif command in ['schema', 'db-schema']:
+
+ elif command in ["schema", "db-schema"]:
await self._show_schema()
return True
-
- elif command in ['samples', 'examples']:
+
+ elif command in ["samples", "examples"]:
await self._show_sample_queries()
return True
-
- elif command in ['metrics', 'performance']:
+
+ elif command in ["metrics", "performance"]:
await self._show_metrics()
return True
-
- elif command in ['exit', 'quit', 'q']:
+
+ elif command in ["exit", "quit", "q"]:
self.running = False
return True
-
+
return False
-
+
async def _show_help(self) -> None:
"""Display help information."""
help_text = """
@@ -161,57 +162,62 @@ async def _show_help(self) -> None:
The system will automatically use SQL tools to retrieve data and provide answers.
"""
print(help_text)
-
+
async def _show_status(self) -> None:
"""Display system status information."""
try:
if not self.driver:
print("ā System not initialized")
return
-
+
config = get_config()
metrics = self.driver.get_metrics()
-
+
print("š System Status:")
- print(f" ⢠Status: {'ā
Ready' if metrics['initialized'] else 'ā Not Ready'}")
+ print(
+ f" ⢠Status: {'ā
Ready' if metrics['initialized'] else 'ā Not Ready'}"
+ )
print(f" ⢠Database: {config.database_path}")
print(f" ⢠Model: {config.claude_model}")
print(f" ⢠Cache Prefix: {config.cache_prefix}")
- print(f" ⢠Tools Registered: {len(self.driver.tool_registry.list_tools())}")
-
+ print(
+ f" ⢠Tools Registered: {len(self.driver.tool_registry.list_tools())}"
+ )
+
except Exception as e:
print(f"ā Failed to get status: {e}")
-
+
async def _show_tools(self) -> None:
"""Display available tools."""
try:
if not self.driver:
print("ā System not initialized")
return
-
+
tool_names = self.driver.tool_registry.list_tools()
-
+
print("š ļø Available Tools:")
for tool_name in tool_names:
tool_def = self.driver.tool_registry.get_tool(tool_name)
print(f" ⢠{tool_name}: {tool_def.description}")
-
+
print(f"\nTotal: {len(tool_names)} tools")
-
+
except Exception as e:
print(f"ā Failed to list tools: {e}")
-
+
async def _show_schema(self) -> None:
"""Display database schema information."""
try:
if not self.driver:
print("ā System not initialized")
return
-
+
# Use the SQL tool to get schema
from ..tools.connectors.sql import sql_get_schema
+
schema_info = await sql_get_schema()
-
+
if schema_info.get("status") == "success":
print("šļø Database Schema:")
for table in schema_info["tables"]:
@@ -219,41 +225,44 @@ async def _show_schema(self) -> None:
for column in table["columns"]:
nullable = "NULL" if column["nullable"] else "NOT NULL"
pk = " (PRIMARY KEY)" if column["primary_key"] else ""
- print(f" ⢠{column['name']}: {column['type']} {nullable}{pk}")
+ print(
+ f" ⢠{column['name']}: {column['type']} {nullable}{pk}"
+ )
else:
print(f"ā Failed to get schema: {schema_info.get('error')}")
-
+
except Exception as e:
print(f"ā Failed to show schema: {e}")
-
+
async def _show_sample_queries(self) -> None:
"""Display sample SQL queries."""
try:
if not self.driver:
print("ā System not initialized")
return
-
+
# Use the SQL tool to get sample queries
from ..tools.connectors.sql import sql_get_sample_queries
+
samples = sql_get_sample_queries()
-
+
print("š Sample Queries:")
for i, sample in enumerate(samples["sample_queries"], 1):
print(f"\n {i}. {sample['description']}")
print(f" {sample['query']}")
-
+
except Exception as e:
print(f"ā Failed to show samples: {e}")
-
+
async def _show_metrics(self) -> None:
"""Display performance metrics."""
try:
if not self.driver:
print("ā System not initialized")
return
-
+
metrics = self.driver.get_metrics()
-
+
print("š Performance Metrics:")
print(f" ⢠Total Queries: {metrics['total_queries']}")
print(f" ⢠Cache Hit Rate: {metrics['cache_hit_rate']:.1f}%")
@@ -261,17 +270,17 @@ async def _show_metrics(self) -> None:
print(f" ⢠Error Rate: {metrics['error_rate']:.1f}%")
print(f" ⢠Cache Hits: {metrics['cache_hits']}")
print(f" ⢠Cache Misses: {metrics['cache_misses']}")
-
+
except Exception as e:
print(f"ā Failed to show metrics: {e}")
-
+
async def _shutdown(self) -> None:
"""Shutdown the CLI and underlying systems."""
try:
print("\nš Shutting down...")
await shutdown_driver()
print("ā
Shutdown complete")
-
+
except Exception as e:
print(f"ā Shutdown error: {e}")
@@ -279,41 +288,29 @@ async def _shutdown(self) -> None:
async def main() -> None:
"""
Main entry point for the FACT CLI application.
-
+
Handles command-line arguments and starts the appropriate mode.
"""
parser = argparse.ArgumentParser(
description="FACT System - Fast-Access Cached Tools",
- formatter_class=argparse.RawDescriptionHelpFormatter
- )
-
- parser.add_argument(
- "--version",
- action="version",
- version="FACT System v1.0.0"
- )
-
- parser.add_argument(
- "--query",
- type=str,
- help="Process a single query and exit"
- )
-
- parser.add_argument(
- "--config",
- type=str,
- help="Path to configuration file"
+ formatter_class=argparse.RawDescriptionHelpFormatter,
)
-
+
+ parser.add_argument("--version", action="version", version="FACT System v1.0.0")
+
+ parser.add_argument("--query", type=str, help="Process a single query and exit")
+
+ parser.add_argument("--config", type=str, help="Path to configuration file")
+
parser.add_argument(
"--log-level",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
- help="Set logging level"
+ help="Set logging level",
)
-
+
args = parser.parse_args()
-
+
# Configure logging
structlog.configure(
wrapper_class=structlog.make_filtering_bound_logger(
@@ -325,14 +322,14 @@ async def main() -> None:
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.processors.TimeStamper(fmt="ISO"),
- structlog.dev.ConsoleRenderer()
+ structlog.dev.ConsoleRenderer(),
],
cache_logger_on_first_use=True,
)
-
+
try:
cli = FACTCLi()
-
+
if args.query:
# Single query mode
await cli.initialize()
@@ -341,7 +338,7 @@ async def main() -> None:
else:
# Interactive mode
await cli.run_interactive()
-
+
except KeyboardInterrupt:
print("\nš Interrupted by user")
sys.exit(0)
@@ -352,4 +349,4 @@ async def main() -> None:
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/src/core/config.py b/src/core/config.py
index 6161771..6d63681 100644
--- a/src/core/config.py
+++ b/src/core/config.py
@@ -11,32 +11,33 @@
from dotenv import load_dotenv
import structlog
-
logger = structlog.get_logger(__name__)
class ConfigurationError(Exception):
"""Raised when configuration validation fails."""
+
pass
class ConnectionError(Exception):
"""Raised when API connectivity tests fail."""
+
pass
class Config:
"""
Central configuration management for the FACT system.
-
+
Handles environment variable loading, validation, and provides
structured access to system configuration parameters.
"""
-
+
def __init__(self, env_file: Optional[str] = None):
"""
Initialize configuration from environment variables.
-
+
Args:
env_file: Optional path to .env file (defaults to .env)
"""
@@ -44,7 +45,7 @@ def __init__(self, env_file: Optional[str] = None):
self._config: Dict[str, Any] = {}
self._load_environment()
self._validate_required_keys()
-
+
def _load_environment(self) -> None:
"""Load environment variables from .env file if it exists."""
if os.path.exists(self.env_file):
@@ -52,22 +53,19 @@ def _load_environment(self) -> None:
logger.info("Loaded environment configuration", file=self.env_file)
else:
logger.warning("No .env file found, using system environment")
-
+
def _validate_required_keys(self) -> None:
"""
Validate that all required configuration keys are present and valid.
-
+
Raises:
ConfigurationError: If any required keys are missing or invalid
"""
- required_keys = [
- "ANTHROPIC_API_KEY",
- "ARCADE_API_KEY"
- ]
-
+ required_keys = ["ANTHROPIC_API_KEY", "ARCADE_API_KEY"]
+
missing_keys = []
invalid_keys = []
-
+
for key in required_keys:
value = os.getenv(key)
if not value:
@@ -76,19 +74,19 @@ def _validate_required_keys(self) -> None:
missing_keys.append(key) # Treat whitespace-only as missing
elif self._is_placeholder_key(value.strip()):
invalid_keys.append(key)
-
+
if missing_keys:
raise ConfigurationError(
f"Missing required configuration keys: {', '.join(missing_keys)}"
)
-
+
if invalid_keys:
raise ConfigurationError(
f"Invalid placeholder values for keys: {', '.join(invalid_keys)}. Please set real API keys."
)
-
+
logger.info("Configuration validation passed")
-
+
def _is_placeholder_key(self, value: str) -> bool:
"""Check if a configuration value is a placeholder."""
placeholder_patterns = [
@@ -98,35 +96,35 @@ def _is_placeholder_key(self, value: str) -> bool:
"placeholder",
"changeme",
"todo",
- "fix_me"
+ "fix_me",
]
return any(pattern in value.lower() for pattern in placeholder_patterns)
-
+
@property
def anthropic_api_key(self) -> str:
"""Get Anthropic API key."""
return os.getenv("ANTHROPIC_API_KEY", "")
-
+
@property
def arcade_api_key(self) -> str:
"""Get Arcade API key."""
return os.getenv("ARCADE_API_KEY", "")
-
+
@property
def arcade_base_url(self) -> str:
"""Get Arcade base URL."""
return os.getenv("ARCADE_BASE_URL", "https://api.arcade-ai.com")
-
+
@property
def database_path(self) -> str:
"""Get database file path."""
return os.getenv("DATABASE_PATH", "data/fact_demo.db")
-
+
@property
def cache_prefix(self) -> str:
"""Get cache prefix for Claude caching."""
return os.getenv("CACHE_PREFIX", "fact_v1")
-
+
@property
def cache_config(self) -> Dict[str, Any]:
"""Get cache configuration dictionary."""
@@ -136,9 +134,9 @@ def cache_config(self) -> Dict[str, Any]:
"max_size": os.getenv("CACHE_MAX_SIZE", "10MB"),
"ttl_seconds": int(os.getenv("CACHE_TTL_SECONDS", "3600")),
"hit_target_ms": float(os.getenv("CACHE_HIT_TARGET_MS", "30")),
- "miss_target_ms": float(os.getenv("CACHE_MISS_TARGET_MS", "120"))
+ "miss_target_ms": float(os.getenv("CACHE_MISS_TARGET_MS", "120")),
}
-
+
@property
def system_prompt(self) -> str:
"""Get system prompt for Claude."""
@@ -161,32 +159,33 @@ def system_prompt(self) -> str:
Example: If asked "What's TechCorp's revenue?" immediately execute:
SELECT revenue FROM financial_records WHERE company_id = (SELECT id FROM companies WHERE name LIKE '%TechCorp%')
-Always show real data, not placeholders or descriptions of what you would do."""
+Always show real data, not placeholders or descriptions of what you would do.""",
)
-
+
@property
def claude_model(self) -> str:
"""Get Claude model name."""
return os.getenv("CLAUDE_MODEL", "claude-3-haiku-20240307")
+
@property
def max_retries(self) -> int:
"""Get maximum retry attempts for failed operations."""
return int(os.getenv("MAX_RETRIES", "3"))
-
+
@property
def request_timeout(self) -> int:
"""Get request timeout in seconds."""
return int(os.getenv("REQUEST_TIMEOUT", "30"))
-
+
@property
def log_level(self) -> str:
"""Get logging level."""
return os.getenv("LOG_LEVEL", "INFO")
-
+
def to_dict(self) -> Dict[str, Any]:
"""
Export configuration as dictionary (excluding sensitive data).
-
+
Returns:
Dictionary of non-sensitive configuration values
"""
@@ -207,7 +206,7 @@ def to_dict(self) -> Dict[str, Any]:
def get_config() -> Config:
"""
Get global configuration instance.
-
+
Returns:
Configured Config instance
"""
@@ -217,10 +216,10 @@ def get_config() -> Config:
def validate_configuration(config: Config) -> None:
"""
Validate configuration and test connectivity to required services.
-
+
Args:
config: Configuration instance to validate
-
+
Raises:
ConfigurationError: If configuration is invalid
ConnectionError: If service connectivity tests fail
@@ -228,10 +227,10 @@ def validate_configuration(config: Config) -> None:
try:
# Basic configuration validation is done in Config.__init__
logger.info("Configuration validation completed successfully")
-
+
# Log configuration summary (without sensitive data)
logger.info("Configuration summary", config=config.to_dict())
-
+
except Exception as e:
logger.error("Configuration validation failed", error=str(e))
- raise ConfigurationError(f"Configuration validation failed: {e}")
\ No newline at end of file
+ raise ConfigurationError(f"Configuration validation failed: {e}")
diff --git a/src/core/conversation.py b/src/core/conversation.py
index 5b89d07..f9fc846 100644
--- a/src/core/conversation.py
+++ b/src/core/conversation.py
@@ -10,13 +10,13 @@
from dataclasses import dataclass, field
import structlog
-
logger = structlog.get_logger(__name__)
@dataclass
class ConversationTurn:
"""Represents a single conversation turn."""
+
timestamp: float
user_input: str
assistant_response: str
@@ -28,49 +28,54 @@ class ConversationTurn:
@dataclass
class ConversationContext:
"""Manages conversation context and state."""
+
conversation_id: str
turns: List[ConversationTurn] = field(default_factory=list)
current_topic: Optional[str] = None
pending_actions: List[str] = field(default_factory=list)
database_context: Dict[str, Any] = field(default_factory=dict)
-
- def add_turn(self, user_input: str, assistant_response: str,
- tool_calls: List[Dict[str, Any]] = None,
- tool_results: List[Dict[str, Any]] = None) -> None:
+
+ def add_turn(
+ self,
+ user_input: str,
+ assistant_response: str,
+ tool_calls: List[Dict[str, Any]] = None,
+ tool_results: List[Dict[str, Any]] = None,
+ ) -> None:
"""Add a new conversation turn."""
turn = ConversationTurn(
timestamp=time.time(),
user_input=user_input,
assistant_response=assistant_response,
tool_calls=tool_calls or [],
- tool_results=tool_results or []
+ tool_results=tool_results or [],
)
self.turns.append(turn)
-
+
# Detect incomplete responses and add pending actions
incomplete_actions = self.detect_incomplete_response(assistant_response)
for action in incomplete_actions:
self.add_pending_action(action)
-
+
# Keep only last 10 turns to manage memory
if len(self.turns) > 10:
self.turns = self.turns[-10:]
-
+
def get_context_summary(self) -> str:
"""Generate a context summary for the LLM."""
if not self.turns:
return ""
-
+
context_parts = []
-
+
# Add current topic if available
if self.current_topic:
context_parts.append(f"Current conversation topic: {self.current_topic}")
-
+
# Add pending actions
if self.pending_actions:
context_parts.append(f"Pending actions: {', '.join(self.pending_actions)}")
-
+
# Add recent conversation history (last 3 turns)
recent_turns = self.turns[-3:]
if len(recent_turns) > 1: # Only add if there's actual history
@@ -78,10 +83,12 @@ def get_context_summary(self) -> str:
for i, turn in enumerate(recent_turns[:-1], 1): # Exclude current turn
context_parts.append(f" {i}. User: {turn.user_input[:100]}...")
if turn.assistant_response:
- context_parts.append(f" Assistant: {turn.assistant_response[:100]}...")
-
+ context_parts.append(
+ f" Assistant: {turn.assistant_response[:100]}..."
+ )
+
return "\n".join(context_parts) if context_parts else ""
-
+
def detect_topic(self, user_input: str) -> None:
"""Detect and update current conversation topic."""
# Simple keyword-based topic detection
@@ -89,32 +96,32 @@ def detect_topic(self, user_input: str) -> None:
"revenue": ["revenue", "sales", "income", "earnings"],
"companies": ["company", "companies", "business", "organization"],
"financial_analysis": ["compare", "trend", "analysis", "performance"],
- "database_schema": ["schema", "tables", "structure", "database"]
+ "database_schema": ["schema", "tables", "structure", "database"],
}
-
+
user_lower = user_input.lower()
for topic, keywords in topics.items():
if any(keyword in user_lower for keyword in keywords):
self.current_topic = topic
break
-
+
def add_pending_action(self, action: str) -> None:
"""Add a pending action to be completed."""
if action not in self.pending_actions:
self.pending_actions.append(action)
logger.info("Added pending action", action=action)
-
+
def complete_action(self, action: str) -> None:
"""Mark an action as completed."""
if action in self.pending_actions:
self.pending_actions.remove(action)
logger.info("Completed action", action=action)
-
+
def clear_pending_actions(self) -> None:
"""Clear all pending actions."""
self.pending_actions.clear()
logger.info("Cleared all pending actions")
-
+
def detect_incomplete_response(self, response: str) -> List[str]:
"""Detect if response is incomplete and suggest follow-up actions."""
incomplete_indicators = [
@@ -127,82 +134,86 @@ def detect_incomplete_response(self, response: str) -> List[str]:
"I'll show",
"I'll analyze",
"Looking at",
- "I see that"
+ "I see that",
]
-
+
suggested_actions = []
response_lower = response.lower()
-
+
# Check for incomplete indicators
for indicator in incomplete_indicators:
if indicator.lower() in response_lower:
# Extract what was promised
if "retrieve" in response_lower or "query" in response_lower:
- suggested_actions.append("Execute the database query to get the data")
+ suggested_actions.append(
+ "Execute the database query to get the data"
+ )
elif "show" in response_lower or "display" in response_lower:
suggested_actions.append("Display the requested information")
elif "analyze" in response_lower or "compare" in response_lower:
suggested_actions.append("Perform the analysis as mentioned")
break
-
+
# Check for specific incomplete patterns
if "companies in the technology sector" in response_lower:
suggested_actions.append("Query and display Technology sector companies")
elif "revenue trends" in response_lower:
- suggested_actions.append("Query and analyze revenue trends across companies")
+ suggested_actions.append(
+ "Query and analyze revenue trends across companies"
+ )
elif "compare" in response_lower and "companies" in response_lower:
suggested_actions.append("Compare companies as requested")
-
+
return suggested_actions
class ConversationManager:
"""Manages conversation contexts and provides enhanced prompting."""
-
+
def __init__(self):
"""Initialize conversation manager."""
self.conversations: Dict[str, ConversationContext] = {}
self.current_conversation_id: Optional[str] = None
-
+
def start_conversation(self, conversation_id: Optional[str] = None) -> str:
"""Start a new conversation or get existing one."""
if conversation_id is None:
conversation_id = f"conv_{int(time.time() * 1000)}"
-
+
if conversation_id not in self.conversations:
self.conversations[conversation_id] = ConversationContext(conversation_id)
logger.info("Started new conversation", conversation_id=conversation_id)
-
+
self.current_conversation_id = conversation_id
return conversation_id
-
+
def get_current_context(self) -> Optional[ConversationContext]:
"""Get current conversation context."""
if self.current_conversation_id:
return self.conversations.get(self.current_conversation_id)
return None
-
+
def enhance_system_prompt(self, base_prompt: str, user_input: str) -> str:
"""Enhance system prompt with conversation context."""
context = self.get_current_context()
if not context:
return base_prompt
-
+
# Update topic based on current input
context.detect_topic(user_input)
-
+
# Get context summary
context_summary = context.get_context_summary()
-
+
enhanced_prompt = base_prompt
-
+
if context_summary:
enhanced_prompt += f"\n\nCONVERSATION CONTEXT:\n{context_summary}"
-
+
# Add specific guidance based on context
if context.pending_actions:
enhanced_prompt += "\n\nIMPORTANT: You have pending actions to complete. Make sure to follow through with the necessary tool calls to provide complete answers."
-
+
# Add topic-specific guidance
if context.current_topic == "revenue":
enhanced_prompt += "\n\nCONTEXT: User is interested in revenue/financial data. Use SQL tools to get specific numbers and trends."
@@ -210,25 +221,25 @@ def enhance_system_prompt(self, base_prompt: str, user_input: str) -> str:
enhanced_prompt += "\n\nCONTEXT: User is asking about companies. Use SQL tools to get company information and details."
elif context.current_topic == "financial_analysis":
enhanced_prompt += "\n\nCONTEXT: User wants financial analysis/comparisons. Execute queries to get data, then provide insights and comparisons."
-
+
enhanced_prompt += "\n\nCRITICAL: Always complete your analysis. If you identify relevant data to retrieve, execute the SQL queries and provide the actual results, not just acknowledgments."
-
+
return enhanced_prompt
-
+
def should_auto_continue(self) -> bool:
"""Check if system should automatically continue with pending actions."""
context = self.get_current_context()
return bool(context and context.pending_actions)
-
+
def generate_follow_up_prompt(self) -> Optional[str]:
"""Generate a follow-up prompt to complete pending actions."""
context = self.get_current_context()
if not context or not context.pending_actions:
return None
-
+
# Generate prompt based on pending actions
action = context.pending_actions[0] # Take first pending action
-
+
if "query" in action.lower() and "technology sector" in action.lower():
return "Show me all companies in the Technology sector"
elif "revenue trends" in action.lower():
@@ -239,16 +250,22 @@ def generate_follow_up_prompt(self) -> Optional[str]:
return "Please execute the database query to get the data"
else:
return "Please continue and complete the previous response"
-
- def add_turn(self, user_input: str, assistant_response: str,
- tool_calls: List[Dict[str, Any]] = None,
- tool_results: List[Dict[str, Any]] = None) -> None:
+
+ def add_turn(
+ self,
+ user_input: str,
+ assistant_response: str,
+ tool_calls: List[Dict[str, Any]] = None,
+ tool_results: List[Dict[str, Any]] = None,
+ ) -> None:
"""Add a turn to current conversation."""
context = self.get_current_context()
if context:
context.add_turn(user_input, assistant_response, tool_calls, tool_results)
-
- def detect_incomplete_response(self, user_input: str, assistant_response: str) -> bool:
+
+ def detect_incomplete_response(
+ self, user_input: str, assistant_response: str
+ ) -> bool:
"""Detect if the assistant response seems incomplete."""
# Check for common incomplete response patterns
incomplete_indicators = [
@@ -257,9 +274,9 @@ def detect_incomplete_response(self, user_input: str, assistant_response: str) -
"I see that there is",
"I'll retrieve",
"Let me get",
- "I'll check"
+ "I'll check",
]
-
+
# Check if response contains tool identification without execution
if any(indicator in assistant_response for indicator in incomplete_indicators):
# Check if it actually contains meaningful data/results
@@ -269,17 +286,17 @@ def detect_incomplete_response(self, user_input: str, assistant_response: str) -
"shows",
"Total:",
"revenue of",
- "companies are"
+ "companies are",
]
if not any(data in assistant_response for data in data_indicators):
return True
-
+
# Check for very short responses that don't answer the question
if len(assistant_response.strip()) < 100 and "?" not in user_input:
return True
-
+
return False
-
+
def get_continuation_prompt(self) -> str:
"""Get a prompt to encourage continuation of incomplete responses."""
return """Continue with your analysis. Execute the necessary SQL queries to get the actual data and provide specific results to the user. Don't just describe what you will do - do it and show the results."""
@@ -294,4 +311,4 @@ def get_conversation_manager() -> ConversationManager:
global _conversation_manager
if _conversation_manager is None:
_conversation_manager = ConversationManager()
- return _conversation_manager
\ No newline at end of file
+ return _conversation_manager
diff --git a/src/core/driver.py b/src/core/driver.py
index a3a4860..b9fc604 100644
--- a/src/core/driver.py
+++ b/src/core/driver.py
@@ -16,10 +16,17 @@
from .config import Config, get_config, validate_configuration
from .errors import (
- FACTError, ConfigurationError, ConnectionError, ToolExecutionError,
- classify_error, create_user_friendly_message, log_error_with_context,
- provide_graceful_degradation, CacheError
+ FACTError,
+ ConfigurationError,
+ ConnectionError,
+ ToolExecutionError,
+ classify_error,
+ create_user_friendly_message,
+ log_error_with_context,
+ provide_graceful_degradation,
+ CacheError,
)
+
try:
# Try relative imports first (when used as package)
from ..db.connection import DatabaseManager
@@ -32,11 +39,12 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from db.connection import DatabaseManager
from tools.decorators import get_tool_registry
from tools.connectors.sql import initialize_sql_tool
@@ -47,17 +55,18 @@
logger = structlog.get_logger(__name__)
+
class FACTDriver:
"""
-
+
Manages cache control, query processing, tool execution, and system coordination
following the FACT architecture principles.
"""
-
+
def __init__(self, config: Optional[Config] = None):
"""
Initialize FACT driver with configuration.
-
+
Args:
config: Optional configuration instance (creates default if None)
"""
@@ -66,19 +75,19 @@ def __init__(self, config: Optional[Config] = None):
self.tool_registry = get_tool_registry()
self.cache_system: Optional[FACTCacheSystem] = None
self._initialized = False
-
+
# Monitoring and metrics
self.metrics_collector = get_metrics_collector()
-
+
# Cache resilience components
self.cache_circuit_breaker: Optional[CacheCircuitBreaker] = None
self.resilient_cache: Optional[ResilientCacheWrapper] = None
self._cache_degraded = False
-
+
async def initialize(self) -> None:
"""
Initialize the FACT system components.
-
+
Raises:
ConfigurationError: If configuration is invalid
ConnectionError: If service connections fail
@@ -86,54 +95,56 @@ async def initialize(self) -> None:
if self._initialized:
logger.info("FACT driver already initialized")
return
-
+
try:
logger.info("Initializing FACT system")
-
+
# Validate configuration
validate_configuration(self.config)
-
+
# Initialize database
await self._initialize_database()
-
+
# Initialize cache system
await self._initialize_cache()
-
+
# Initialize tools
await self._initialize_tools()
-
+
# Test connections
await self._test_connections()
-
+
self._initialized = True
logger.info("FACT system initialized successfully")
-
+
except Exception as e:
logger.error("FACT system initialization failed", error=str(e))
raise ConfigurationError(f"System initialization failed: {e}")
-
+
async def process_query(self, user_input: str) -> str:
"""
Process a user query through the FACT pipeline.
-
+
Args:
user_input: Natural language query from user
-
+
Returns:
Generated response string
-
+
Raises:
FACTError: If query processing fails
"""
if not self._initialized:
await self.initialize()
-
+
query_id = f"query_{int(time.time() * 1000)}"
start_time = time.time()
-
+
try:
- logger.info("Processing user query", query_id=query_id, query=user_input[:100])
-
+ logger.info(
+ "Processing user query", query_id=query_id, query=user_input[:100]
+ )
+
# Step 1: Check cache first (cache-first pattern) with resilience
cached_response = None
if not self._cache_degraded:
@@ -146,336 +157,361 @@ async def process_query(self, user_input: str) -> str:
cached_response = cache_entry.content
elif self.cache_system:
# Fallback to direct cache system
- cached_response = await self.cache_system.get_cached_response(user_input)
-
+ cached_response = await self.cache_system.get_cached_response(
+ user_input
+ )
+
except CacheError as e:
if "CIRCUIT_BREAKER" in str(e.error_code):
- logger.info("Cache circuit breaker active - proceeding without cache",
- query_id=query_id,
- circuit_state=self.cache_circuit_breaker.get_state().value if self.cache_circuit_breaker else "unknown")
+ logger.info(
+ "Cache circuit breaker active - proceeding without cache",
+ query_id=query_id,
+ circuit_state=(
+ self.cache_circuit_breaker.get_state().value
+ if self.cache_circuit_breaker
+ else "unknown"
+ ),
+ )
else:
- logger.warning("Cache operation failed - continuing without cache",
- query_id=query_id,
- error=str(e))
+ logger.warning(
+ "Cache operation failed - continuing without cache",
+ query_id=query_id,
+ error=str(e),
+ )
except Exception as e:
- logger.warning("Unexpected cache error - continuing without cache",
- query_id=query_id,
- error=str(e))
-
+ logger.warning(
+ "Unexpected cache error - continuing without cache",
+ query_id=query_id,
+ error=str(e),
+ )
+
if cached_response:
# Cache hit - return cached response
end_time = time.time()
latency = (end_time - start_time) * 1000
-
- logger.info("Cache hit - returning cached response",
- query_id=query_id,
- latency_ms=latency)
-
+
+ logger.info(
+ "Cache hit - returning cached response",
+ query_id=query_id,
+ latency_ms=latency,
+ )
+
return cached_response
-
+
# Step 2: Cache miss - process query normally
logger.info("Cache miss - processing query with LLM", query_id=query_id)
-
+
# Prepare messages for LLM
- messages = [
- {
- "role": "user",
- "content": user_input
- }
- ]
-
+ messages = [{"role": "user", "content": user_input}]
+
# Get tool schemas for LLM
tool_schemas = self.tool_registry.export_all_schemas()
-
+
# Make initial LLM call with cache control
response = await self._call_llm_with_cache(
- messages=messages,
- tools=tool_schemas,
- cache_mode="read"
+ messages=messages, tools=tool_schemas, cache_mode="read"
)
-
+
# Handle tool calls if present (Anthropic format)
tool_use_blocks = []
- if hasattr(response, 'content') and response.content:
+ if hasattr(response, "content") and response.content:
for block in response.content:
- if hasattr(block, 'type') and block.type == 'tool_use':
+ if hasattr(block, "type") and block.type == "tool_use":
tool_use_blocks.append(block)
-
+
if tool_use_blocks:
logger.info("Processing tool calls", count=len(tool_use_blocks))
-
+
# Add assistant message with tool_use blocks to conversation
- assistant_message = {
- "role": "assistant",
- "content": response.content
- }
+ assistant_message = {"role": "assistant", "content": response.content}
messages.append(assistant_message)
-
+
# Execute tool calls
tool_results = await self._execute_tool_calls(tool_use_blocks)
-
+
# Add tool results to message history
messages.extend(tool_results)
-
+
# Get final response with tool results
response = await self._call_llm_with_cache(
- messages=messages,
- tools=tool_schemas,
- cache_mode="read"
+ messages=messages, tools=tool_schemas, cache_mode="read"
)
-
+
# Check if the final response also has tool calls
final_tool_blocks = []
- if hasattr(response, 'content') and response.content:
+ if hasattr(response, "content") and response.content:
for block in response.content:
- if hasattr(block, 'type') and block.type == 'tool_use':
+ if hasattr(block, "type") and block.type == "tool_use":
final_tool_blocks.append(block)
-
+
# Execute any final tool calls and continue until we get text response
max_iterations = 5 # Prevent infinite loops
iteration = 0
-
+
while final_tool_blocks and iteration < max_iterations:
- logger.info("Processing final tool calls", count=len(final_tool_blocks), iteration=iteration)
-
+ logger.info(
+ "Processing final tool calls",
+ count=len(final_tool_blocks),
+ iteration=iteration,
+ )
+
# Add assistant message with tool_use blocks
assistant_message = {
"role": "assistant",
- "content": response.content
+ "content": response.content,
}
messages.append(assistant_message)
-
+
# Execute final tool calls
- final_tool_results = await self._execute_tool_calls(final_tool_blocks)
-
+ final_tool_results = await self._execute_tool_calls(
+ final_tool_blocks
+ )
+
# Add final tool results
messages.extend(final_tool_results)
-
+
# Get final response
response = await self._call_llm_with_cache(
- messages=messages,
- tools=tool_schemas,
- cache_mode="read"
+ messages=messages, tools=tool_schemas, cache_mode="read"
)
-
+
# Check if this response also has tool calls
final_tool_blocks = []
- if hasattr(response, 'content') and response.content:
+ if hasattr(response, "content") and response.content:
for block in response.content:
- if hasattr(block, 'type') and block.type == 'tool_use':
+ if hasattr(block, "type") and block.type == "tool_use":
final_tool_blocks.append(block)
-
+
iteration += 1
-
+
# Extract response content from Anthropic SDK response
response_text = ""
- if hasattr(response, 'content') and response.content:
+ if hasattr(response, "content") and response.content:
for block in response.content:
- if hasattr(block, 'type') and block.type == 'text':
+ if hasattr(block, "type") and block.type == "text":
response_text += block.text
-
+
# If no text content found, provide informative message
if not response_text:
logger.warning("No text response received from LLM", query_id=query_id)
response_text = "I apologize, but I was unable to generate a proper response. Please try rephrasing your question."
-
+
# Step 3: Store response in cache for future use with resilience
if not self._cache_degraded and response_text:
try:
if self.resilient_cache:
# Use resilient cache with circuit breaker
query_hash = self.resilient_cache.generate_hash(user_input)
- cache_entry = await self.resilient_cache.store(query_hash, response_text)
+ cache_entry = await self.resilient_cache.store(
+ query_hash, response_text
+ )
if cache_entry:
logger.debug("Response stored in cache", query_id=query_id)
else:
- logger.debug("Response not suitable for caching", query_id=query_id)
+ logger.debug(
+ "Response not suitable for caching", query_id=query_id
+ )
elif self.cache_system:
# Fallback to direct cache system
- cache_stored = await self.cache_system.store_response(user_input, response_text)
+ cache_stored = await self.cache_system.store_response(
+ user_input, response_text
+ )
if cache_stored:
logger.debug("Response stored in cache", query_id=query_id)
else:
- logger.debug("Response not suitable for caching", query_id=query_id)
-
+ logger.debug(
+ "Response not suitable for caching", query_id=query_id
+ )
+
except CacheError as e:
if "CIRCUIT_BREAKER" in str(e.error_code):
- logger.debug("Cache storage skipped - circuit breaker active", query_id=query_id)
+ logger.debug(
+ "Cache storage skipped - circuit breaker active",
+ query_id=query_id,
+ )
else:
- logger.warning("Cache storage failed - continuing without cache storage",
- query_id=query_id,
- error=str(e))
+ logger.warning(
+ "Cache storage failed - continuing without cache storage",
+ query_id=query_id,
+ error=str(e),
+ )
except Exception as e:
- logger.warning("Unexpected cache storage error",
- query_id=query_id,
- error=str(e))
-
+ logger.warning(
+ "Unexpected cache storage error",
+ query_id=query_id,
+ error=str(e),
+ )
+
# Log performance metrics
end_time = time.time()
latency = (end_time - start_time) * 1000
-
- logger.info("Query processed successfully",
- query_id=query_id,
- latency_ms=latency,
- response_length=len(response_text))
-
+
+ logger.info(
+ "Query processed successfully",
+ query_id=query_id,
+ latency_ms=latency,
+ response_length=len(response_text),
+ )
+
return response_text
-
+
except Exception as e:
end_time = time.time()
latency = (end_time - start_time) * 1000
-
+
# Record error in metrics collector
self.metrics_collector.record_tool_execution(
tool_name="fact_query",
success=False,
execution_time=latency,
error_type=type(e).__name__,
- metadata={"query_id": query_id}
+ metadata={"query_id": query_id},
)
-
+
# Log error with context
- log_error_with_context(e, {
- "query_id": query_id,
- "user_input": user_input[:100],
- "latency_ms": latency
- })
-
+ log_error_with_context(
+ e,
+ {
+ "query_id": query_id,
+ "user_input": user_input[:100],
+ "latency_ms": latency,
+ },
+ )
+
# Handle error with graceful degradation
error_category = classify_error(e)
-
+
if error_category in ["connectivity", "tool_execution"]:
return provide_graceful_degradation(error_category)
else:
return create_user_friendly_message(e)
-
+
async def _initialize_database(self) -> None:
"""Initialize database connection and schema."""
try:
self.database_manager = DatabaseManager(self.config.database_path)
await self.database_manager.initialize_database()
logger.info("Database initialized successfully")
-
+
except Exception as e:
logger.error("Database initialization failed", error=str(e))
raise ConfigurationError(f"Database initialization failed: {e}")
-
+
async def _initialize_cache(self) -> None:
"""Initialize cache system with resilience and circuit breaker protection."""
try:
cache_config = self.config.cache_config
-
+
# Initialize base cache system
self.cache_system = await initialize_cache_system(
- config=cache_config,
- enable_background_tasks=True
+ config=cache_config, enable_background_tasks=True
)
-
+
# Initialize circuit breaker for cache resilience
from ..cache.resilience import CircuitBreakerConfig
-
+
circuit_config = CircuitBreakerConfig(
failure_threshold=5, # Open after 5 failures
success_threshold=3, # Close after 3 successes
timeout_seconds=60.0, # Wait 60s before retry
rolling_window_seconds=300.0, # 5-minute window
gradual_recovery=True,
- recovery_factor=0.5 # 50% of requests during recovery
+ recovery_factor=0.5, # 50% of requests during recovery
)
-
+
self.cache_circuit_breaker = CacheCircuitBreaker(circuit_config)
-
+
# Wrap cache system with resilient wrapper
- if hasattr(self.cache_system, 'cache_manager'):
+ if hasattr(self.cache_system, "cache_manager"):
self.resilient_cache = ResilientCacheWrapper(
- self.cache_system.cache_manager,
- self.cache_circuit_breaker
+ self.cache_system.cache_manager, self.cache_circuit_breaker
)
-
+
# Start health monitoring
await self.resilient_cache.start_monitoring()
-
- logger.info("Cache system with resilience initialized successfully",
- prefix=cache_config["prefix"],
- max_size=cache_config["max_size"],
- circuit_breaker_enabled=True)
-
+
+ logger.info(
+ "Cache system with resilience initialized successfully",
+ prefix=cache_config["prefix"],
+ max_size=cache_config["max_size"],
+ circuit_breaker_enabled=True,
+ )
+
except Exception as e:
logger.error("Cache system initialization failed", error=str(e))
# Enable graceful degradation - continue without cache
self._cache_degraded = True
logger.warning("Continuing with cache degradation mode")
-
+
async def _initialize_tools(self) -> None:
"""Initialize and register system tools."""
try:
# Initialize SQL tool with database manager
initialize_sql_tool(self.database_manager)
-
+
# Log registered tools
tool_info = self.tool_registry.get_tool_info()
logger.info("Tools initialized", **tool_info)
-
+
except Exception as e:
logger.error("Tool initialization failed", error=str(e))
raise ConfigurationError(f"Tool initialization failed: {e}")
-
+
async def _test_connections(self) -> None:
"""Test connections to external services."""
# Skip API validation if configured for testing
if os.getenv("SKIP_API_VALIDATION", "false").lower() == "true":
logger.info("Skipping API validation for testing environment")
return
-
+
try:
# Test database connection
if self.database_manager:
await self.database_manager.get_database_info()
logger.info("Database connection test passed")
-
+
# Test LLM connection with direct Anthropic SDK
client = anthropic.Anthropic(api_key=self.config.anthropic_api_key)
test_response = client.messages.create(
model=self.config.claude_model,
messages=[{"role": "user", "content": "Test"}],
- max_tokens=10
+ max_tokens=10,
)
-
+
if test_response:
logger.info("LLM connection test passed")
-
+
except Exception as e:
logger.error("Connection test failed", error=str(e))
raise ConnectionError(f"Service connection test failed: {e}")
-
- async def _call_llm_with_cache(self,
- messages: List[Dict[str, Any]],
- tools: List[Dict[str, Any]],
- cache_mode: str = "read") -> Any:
+
+ async def _call_llm_with_cache(
+ self,
+ messages: List[Dict[str, Any]],
+ tools: List[Dict[str, Any]],
+ cache_mode: str = "read",
+ ) -> Any:
"""
Call LLM with cache control and tool support.
-
+
Args:
messages: Message history for LLM
tools: Available tools for LLM
cache_mode: Cache mode ('read' or 'write')
-
+
Returns:
LLM response object
"""
try:
# Configure cache control
- cache_control = {
- "mode": cache_mode,
- "prefix": self.config.cache_prefix
- }
-
+ cache_control = {"mode": cache_mode, "prefix": self.config.cache_prefix}
+
# Track cache behavior
# Cache hits/misses can be tracked via tool execution metadata if needed
-
+
# Make LLM call with direct Anthropic SDK
client = anthropic.Anthropic(api_key=self.config.anthropic_api_key)
-
+
# Anthropic API requires system prompt as separate parameter, not in messages
response = client.messages.create(
model=self.config.claude_model,
@@ -486,46 +522,47 @@ async def _call_llm_with_cache(self,
tools=tools if tools else None,
tool_choice={"type": "any"} if tools else None,
)
-
+
return response
-
+
except Exception as e:
logger.error("LLM call failed", error=str(e), cache_mode=cache_mode)
raise ConnectionError(f"LLM call failed: {e}")
-
+
async def _execute_tool_calls(self, tool_calls: List[Any]) -> List[Dict[str, Any]]:
"""
Execute tool calls and format results as messages.
-
+
Args:
tool_calls: List of tool calls from LLM
-
+
Returns:
List of tool result messages
"""
tool_messages = []
-
+
for call in tool_calls:
try:
# Record tool execution in metrics_collector
-
+
# Extract tool information (Anthropic format)
tool_name = call.name
tool_args = call.input
-
+
# tool_args should already be a dict in Anthropic format
if isinstance(tool_args, str):
import json
+
tool_args = json.loads(tool_args)
# Get tool definition
tool_definition = self.tool_registry.get_tool(tool_name)
-
+
# Execute tool with proper async handling
if asyncio.iscoroutinefunction(tool_definition.function):
result = await tool_definition.function(**tool_args)
else:
result = tool_definition.function(**tool_args)
-
+
# Format as tool message (Anthropic format)
tool_message = {
"role": "user",
@@ -533,33 +570,35 @@ async def _execute_tool_calls(self, tool_calls: List[Any]) -> List[Dict[str, Any
{
"type": "tool_result",
"tool_use_id": call.id,
- "content": str(result) if not isinstance(result, str) else result
+ "content": (
+ str(result) if not isinstance(result, str) else result
+ ),
}
- ]
+ ],
}
-
- logger.info("Tool executed successfully",
- tool_name=tool_name,
- execution_time=result.get("execution_time_ms", 0))
+
+ logger.info(
+ "Tool executed successfully",
+ tool_name=tool_name,
+ execution_time=result.get("execution_time_ms", 0),
+ )
self.metrics_collector.record_tool_execution(
tool_name=tool_name,
success=True,
execution_time=result.get("execution_time_ms", 0),
- metadata={"args": tool_args}
+ metadata={"args": tool_args},
)
-
+
except Exception as e:
- logger.error("Tool execution failed",
- tool_name=tool_name,
- error=str(e))
+ logger.error("Tool execution failed", tool_name=tool_name, error=str(e))
self.metrics_collector.record_tool_execution(
tool_name=tool_name,
success=False,
execution_time=0,
error_type=str(e),
- metadata={"args": tool_args}
+ metadata={"args": tool_args},
)
-
+
# Format error as tool message (Anthropic format)
tool_message = {
"role": "user",
@@ -567,29 +606,31 @@ async def _execute_tool_calls(self, tool_calls: List[Any]) -> List[Dict[str, Any
{
"type": "tool_result",
"tool_use_id": call.id,
- "content": str({
- "error": "Tool execution failed",
- "details": str(e),
- "status": "failed"
- })
+ "content": str(
+ {
+ "error": "Tool execution failed",
+ "details": str(e),
+ "status": "failed",
+ }
+ ),
}
- ]
+ ],
}
-
+
tool_messages.append(tool_message)
-
+
return tool_messages
-
+
def get_metrics(self) -> Dict[str, Any]:
"""
Get system performance metrics.
-
+
Returns:
Dictionary containing performance metrics
"""
# Use the unified metrics collector for system metrics
sys_metrics = self.metrics_collector.get_system_metrics()
-
+
# Get cache metrics including circuit breaker metrics
cache_metrics = {}
if not self._cache_degraded:
@@ -597,11 +638,11 @@ def get_metrics(self) -> Dict[str, Any]:
if self.resilient_cache:
# Get comprehensive metrics from resilient cache
resilient_metrics = self.resilient_cache.get_metrics()
-
+
# Extract cache metrics
cache_data = resilient_metrics.get("cache", {})
circuit_data = resilient_metrics.get("circuit_breaker", {})
-
+
cache_metrics = {
"cache_hit_rate": cache_data.get("hit_rate", 0),
"cache_hits": cache_data.get("cache_hits", 0),
@@ -609,13 +650,20 @@ def get_metrics(self) -> Dict[str, Any]:
"cache_total_entries": cache_data.get("total_entries", 0),
"cache_total_size": cache_data.get("total_size", 0),
"cache_token_efficiency": cache_data.get("token_efficiency", 0),
-
# Circuit breaker metrics
"circuit_breaker_state": circuit_data.get("state", "unknown"),
- "circuit_breaker_failures": circuit_data.get("failure_count", 0),
- "circuit_breaker_successes": circuit_data.get("success_count", 0),
- "circuit_breaker_failure_rate": circuit_data.get("failure_rate", 0),
- "circuit_breaker_state_changes": circuit_data.get("state_changes", 0)
+ "circuit_breaker_failures": circuit_data.get(
+ "failure_count", 0
+ ),
+ "circuit_breaker_successes": circuit_data.get(
+ "success_count", 0
+ ),
+ "circuit_breaker_failure_rate": circuit_data.get(
+ "failure_rate", 0
+ ),
+ "circuit_breaker_state_changes": circuit_data.get(
+ "state_changes", 0
+ ),
}
elif self.cache_system:
# Fallback to basic cache metrics
@@ -627,7 +675,7 @@ def get_metrics(self) -> Dict[str, Any]:
"cache_total_entries": basic_cache_metrics.total_entries,
"cache_total_size": basic_cache_metrics.total_size,
"cache_token_efficiency": basic_cache_metrics.token_efficiency,
- "circuit_breaker_state": "disabled"
+ "circuit_breaker_state": "disabled",
}
except Exception as e:
logger.warning("Failed to get cache metrics", error=str(e))
@@ -635,7 +683,7 @@ def get_metrics(self) -> Dict[str, Any]:
"cache_hit_rate": 0,
"cache_hits": 0,
"cache_misses": 0,
- "circuit_breaker_state": "error"
+ "circuit_breaker_state": "error",
}
else:
cache_metrics = {
@@ -643,21 +691,21 @@ def get_metrics(self) -> Dict[str, Any]:
"cache_hits": 0,
"cache_misses": 0,
"cache_degraded": True,
- "circuit_breaker_state": "degraded"
+ "circuit_breaker_state": "degraded",
}
-
+
return {
"total_queries": sys_metrics.total_executions,
"tool_executions": sys_metrics.total_executions,
"error_rate": sys_metrics.error_rate,
"initialized": self._initialized,
- **cache_metrics
+ **cache_metrics,
}
-
+
async def shutdown(self) -> None:
"""Gracefully shutdown the FACT system."""
logger.info("Shutting down FACT system")
-
+
# Shutdown resilient cache monitoring
if self.resilient_cache:
try:
@@ -665,7 +713,7 @@ async def shutdown(self) -> None:
self.resilient_cache = None
except Exception as e:
logger.warning("Error stopping cache monitoring", error=str(e))
-
+
# Shutdown cache circuit breaker
if self.cache_circuit_breaker:
try:
@@ -673,17 +721,17 @@ async def shutdown(self) -> None:
self.cache_circuit_breaker = None
except Exception as e:
logger.warning("Error stopping circuit breaker", error=str(e))
-
+
# Shutdown cache system
if self.cache_system:
await self.cache_system.shutdown()
self.cache_system = None
-
+
# Close database connections
if self.database_manager:
# Database manager handles its own cleanup
pass
-
+
self._initialized = False
self._cache_degraded = False
logger.info("FACT system shutdown complete")
@@ -696,26 +744,26 @@ async def shutdown(self) -> None:
async def get_driver(config: Optional[Config] = None) -> FACTDriver:
"""
Get or create the global FACT driver instance.
-
+
Args:
config: Optional configuration (only used for first creation)
-
+
Returns:
Initialized FACTDriver instance
"""
global _driver_instance
-
+
if _driver_instance is None:
_driver_instance = FACTDriver(config)
await _driver_instance.initialize()
-
+
return _driver_instance
async def shutdown_driver() -> None:
"""Shutdown the global driver instance."""
global _driver_instance
-
+
if _driver_instance:
await _driver_instance.shutdown()
_driver_instance = None
@@ -724,14 +772,14 @@ async def shutdown_driver() -> None:
async def process_user_query(query: str) -> str:
"""
Process user query using the global driver instance.
-
+
This is a compatibility wrapper for the benchmarking framework.
-
+
Args:
query: User query string
-
+
Returns:
Response string
"""
driver = await get_driver()
- return await driver.process_query(query)
\ No newline at end of file
+ return await driver.process_query(query)
diff --git a/src/core/errors.py b/src/core/errors.py
index f08ca62..9dd8fc9 100644
--- a/src/core/errors.py
+++ b/src/core/errors.py
@@ -8,17 +8,21 @@
from typing import Optional, Dict, Any
import structlog
-
logger = structlog.get_logger(__name__)
class FACTError(Exception):
"""Base exception class for all FACT system errors."""
-
- def __init__(self, message: str, error_code: Optional[str] = None, context: Optional[Dict[str, Any]] = None):
+
+ def __init__(
+ self,
+ message: str,
+ error_code: Optional[str] = None,
+ context: Optional[Dict[str, Any]] = None,
+ ):
"""
Initialize FACT error.
-
+
Args:
message: Human-readable error message
error_code: Optional error code for categorization
@@ -32,91 +36,106 @@ def __init__(self, message: str, error_code: Optional[str] = None, context: Opti
class ConfigurationError(FACTError):
"""Raised when system configuration is invalid or missing."""
+
pass
class ConnectionError(FACTError):
"""Raised when API connectivity tests fail."""
+
pass
class AuthenticationError(FACTError):
"""Raised when API authentication fails."""
+
pass
class ValidationError(FACTError):
"""Raised when input validation fails."""
+
pass
class ToolExecutionError(FACTError):
"""Raised when tool execution fails."""
+
pass
class ToolValidationError(FACTError):
"""Raised when tool validation fails."""
+
pass
class ToolNotFoundError(FACTError):
"""Raised when a requested tool is not found."""
+
pass
class UnauthorizedError(FACTError):
"""Raised when user lacks authorization for a tool."""
+
pass
class InvalidArgumentsError(FACTError):
"""Raised when tool arguments are invalid."""
+
pass
class DatabaseError(FACTError):
"""Raised when database operations fail."""
+
pass
class SecurityError(FACTError):
"""Raised when security violations are detected."""
+
pass
class InvalidSQLError(FACTError):
"""Raised when SQL statements are invalid or dangerous."""
+
pass
class CacheError(FACTError):
"""Raised when cache operations fail."""
+
pass
class ToolRegistrationError(FACTError):
"""Raised when tool registration fails."""
+
pass
class FinalRetryError(FACTError):
"""Raised when maximum retry attempts are exceeded."""
+
pass
def classify_error(error: Exception) -> str:
"""
Classify an error into a category for handling strategies.
-
+
Args:
error: Exception to classify
-
+
Returns:
String category of the error
"""
error_type = type(error)
-
+
# Map error types to categories
error_categories = {
ConfigurationError: "configuration",
@@ -135,29 +154,31 @@ def classify_error(error: Exception) -> str:
ToolRegistrationError: "tool_registration",
FinalRetryError: "connectivity",
}
-
+
category = error_categories.get(error_type, "unknown")
-
- logger.debug("Error classified",
- error_type=error_type.__name__,
- category=category,
- message=str(error))
-
+
+ logger.debug(
+ "Error classified",
+ error_type=error_type.__name__,
+ category=category,
+ message=str(error),
+ )
+
return category
def create_user_friendly_message(error: Exception) -> str:
"""
Create a user-friendly error message from an exception.
-
+
Args:
error: Exception to convert
-
+
Returns:
User-friendly error message
"""
category = classify_error(error)
-
+
# Default user-friendly messages by category
friendly_messages = {
"configuration": "System configuration error. Please check your setup.",
@@ -171,18 +192,20 @@ def create_user_friendly_message(error: Exception) -> str:
"tool_registration": "Tool registration failed. Please check your tool configuration.",
"unknown": "An unexpected error occurred. Please try again later.",
}
-
+
# Use specific message if available, otherwise use category default
- if hasattr(error, 'message'):
+ if hasattr(error, "message"):
return error.message
-
+
return friendly_messages.get(category, str(error))
-def log_error_with_context(error: Exception, context: Optional[Dict[str, Any]] = None) -> None:
+def log_error_with_context(
+ error: Exception, context: Optional[Dict[str, Any]] = None
+) -> None:
"""
Log error with full context information.
-
+
Args:
error: Exception to log
context: Additional context information
@@ -192,25 +215,25 @@ def log_error_with_context(error: Exception, context: Optional[Dict[str, Any]] =
"error_message": str(error),
"error_category": classify_error(error),
}
-
+
# Add specific error context if available
- if hasattr(error, 'context'):
+ if hasattr(error, "context"):
error_context.update(error.context)
-
+
# Add provided context
if context:
error_context.update(context)
-
+
logger.error("Error occurred", **error_context)
def provide_graceful_degradation(failed_component: str) -> str:
"""
Provide graceful degradation message for failed components.
-
+
Args:
failed_component: Name of the component that failed
-
+
Returns:
Graceful degradation message
"""
@@ -221,14 +244,13 @@ def provide_graceful_degradation(failed_component: str) -> str:
"anthropic": "Claude API is temporarily unavailable. Please try again later.",
"arcade": "Tool execution service is temporarily unavailable. Please try again later.",
}
-
+
message = degradation_messages.get(
- failed_component,
- "System is experiencing issues. Please try again later."
+ failed_component, "System is experiencing issues. Please try again later."
)
-
- logger.warning("Graceful degradation activated",
- component=failed_component,
- message=message)
-
- return message
\ No newline at end of file
+
+ logger.warning(
+ "Graceful degradation activated", component=failed_component, message=message
+ )
+
+ return message
diff --git a/src/db/connection.py b/src/db/connection.py
index 4d01985..4a1f181 100644
--- a/src/db/connection.py
+++ b/src/db/connection.py
@@ -22,24 +22,25 @@
SAMPLE_COMPANIES,
SAMPLE_FINANCIAL_RECORDS,
QueryResult,
- validate_schema_integrity
+ validate_schema_integrity,
)
except ImportError:
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from core.errors import DatabaseError, SecurityError, InvalidSQLError
from db.models import (
DATABASE_SCHEMA,
SAMPLE_COMPANIES,
SAMPLE_FINANCIAL_RECORDS,
QueryResult,
- validate_schema_integrity
+ validate_schema_integrity,
)
@@ -49,14 +50,14 @@
class AsyncConnectionPool:
"""
Async connection pool for SQLite database connections.
-
+
Provides connection reuse and management to reduce connection overhead.
"""
-
+
def __init__(self, database_path: str, pool_size: int = 10):
"""
Initialize connection pool.
-
+
Args:
database_path: Path to SQLite database
pool_size: Maximum number of connections in pool
@@ -66,7 +67,7 @@ def __init__(self, database_path: str, pool_size: int = 10):
self.pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
self.created_connections = 0
self._lock = asyncio.Lock()
-
+
async def get_connection(self) -> aiosqlite.Connection:
"""Get a connection from the pool or create a new one."""
try:
@@ -80,15 +81,17 @@ async def get_connection(self) -> aiosqlite.Connection:
if self.created_connections < self.pool_size:
conn = await aiosqlite.connect(self.database_path)
self.created_connections += 1
- logger.debug("Created new pooled connection",
- total_connections=self.created_connections)
+ logger.debug(
+ "Created new pooled connection",
+ total_connections=self.created_connections,
+ )
return conn
else:
# Wait for a connection to become available
conn = await self.pool.get()
logger.debug("Retrieved connection from pool after wait")
return conn
-
+
async def return_connection(self, conn: aiosqlite.Connection):
"""Return a connection to the pool."""
try:
@@ -99,9 +102,10 @@ async def return_connection(self, conn: aiosqlite.Connection):
await conn.close()
async with self._lock:
self.created_connections -= 1
- logger.debug("Closed excess connection",
- total_connections=self.created_connections)
-
+ logger.debug(
+ "Closed excess connection", total_connections=self.created_connections
+ )
+
async def close_all(self):
"""Close all connections in the pool."""
connections_closed = 0
@@ -112,25 +116,27 @@ async def close_all(self):
connections_closed += 1
except asyncio.QueueEmpty:
break
-
+
async with self._lock:
self.created_connections = 0
-
- logger.info("Closed all pooled connections", connections_closed=connections_closed)
+
+ logger.info(
+ "Closed all pooled connections", connections_closed=connections_closed
+ )
class DatabaseManager:
"""
Manages SQLite database connections and operations for the FACT system.
-
+
Provides secure database access with read-only query validation,
connection pooling, and performance monitoring.
"""
-
+
def __init__(self, database_path: str, pool_size: int = 10):
"""
Initialize database manager with connection pooling.
-
+
Args:
database_path: Path to SQLite database file
pool_size: Maximum number of connections in pool
@@ -141,18 +147,18 @@ def __init__(self, database_path: str, pool_size: int = 10):
self._ensure_directory_exists()
self._query_plan_cache = {} # Cache for validated queries
self._cache_max_size = 1000
-
+
def _ensure_directory_exists(self) -> None:
"""Ensure the database directory exists."""
db_dir = os.path.dirname(self.database_path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
logger.info("Created database directory", path=db_dir)
-
+
async def initialize_database(self) -> None:
"""
Initialize database with schema and sample data.
-
+
Raises:
DatabaseError: If database initialization fails
"""
@@ -161,41 +167,50 @@ async def initialize_database(self) -> None:
# Execute schema creation
await db.executescript(DATABASE_SCHEMA)
await db.commit()
-
+
# Check if data already exists
cursor = await db.execute("SELECT COUNT(*) FROM companies")
company_count = (await cursor.fetchone())[0]
await cursor.close()
-
+
# Check if financial_data table has data
cursor = await db.execute("SELECT COUNT(*) FROM financial_data")
financial_data_count = (await cursor.fetchone())[0]
await cursor.close()
-
+
# Check if benchmarks table has data
cursor = await db.execute("SELECT COUNT(*) FROM benchmarks")
benchmarks_count = (await cursor.fetchone())[0]
await cursor.close()
-
+
if company_count == 0:
# Batch insert sample companies for better performance
- await db.executemany("""
+ await db.executemany(
+ """
INSERT INTO companies (name, symbol, sector, founded_year, employees, market_cap)
VALUES (:name, :symbol, :sector, :founded_year, :employees, :market_cap)
- """, SAMPLE_COMPANIES)
-
+ """,
+ SAMPLE_COMPANIES,
+ )
+
# Batch insert sample financial records
- await db.executemany("""
+ await db.executemany(
+ """
INSERT INTO financial_records (company_id, quarter, year, revenue, profit, expenses)
VALUES (:company_id, :quarter, :year, :revenue, :profit, :expenses)
- """, SAMPLE_FINANCIAL_RECORDS)
-
+ """,
+ SAMPLE_FINANCIAL_RECORDS,
+ )
+
# Batch insert into financial_data table for validation compatibility
- await db.executemany("""
+ await db.executemany(
+ """
INSERT INTO financial_data (company_id, quarter, year, revenue, profit, expenses)
VALUES (:company_id, :quarter, :year, :revenue, :profit, :expenses)
- """, SAMPLE_FINANCIAL_RECORDS)
-
+ """,
+ SAMPLE_FINANCIAL_RECORDS,
+ )
+
# Insert sample benchmark data
sample_benchmarks = [
{
@@ -205,7 +220,7 @@ async def initialize_database(self) -> None:
"cache_hit_rate": 0.0,
"average_response_time_ms": 0.0,
"success_rate": 1.0,
- "notes": "Initial system setup benchmark"
+ "notes": "Initial system setup benchmark",
},
{
"test_name": "cache_warming",
@@ -214,33 +229,39 @@ async def initialize_database(self) -> None:
"cache_hit_rate": 0.0,
"average_response_time_ms": 18.02,
"success_rate": 1.0,
- "notes": "Cache warming performance test"
- }
+ "notes": "Cache warming performance test",
+ },
]
-
+
for benchmark in sample_benchmarks:
- await db.execute("""
+ await db.execute(
+ """
INSERT INTO benchmarks (test_name, duration_ms, queries_executed, cache_hit_rate,
average_response_time_ms, success_rate, notes)
VALUES (:test_name, :duration_ms, :queries_executed, :cache_hit_rate,
:average_response_time_ms, :success_rate, :notes)
- """, benchmark)
-
+ """,
+ benchmark,
+ )
+
await db.commit()
logger.info("Database initialized with sample data")
-
+
# Handle the case where companies exist but other tables don't
elif financial_data_count == 0:
logger.info("Adding missing financial_data records")
# Insert sample financial records into financial_data table
for record in SAMPLE_FINANCIAL_RECORDS:
- await db.execute("""
+ await db.execute(
+ """
INSERT INTO financial_data (company_id, quarter, year, revenue, profit, expenses)
VALUES (:company_id, :quarter, :year, :revenue, :profit, :expenses)
- """, record)
+ """,
+ record,
+ )
await db.commit()
logger.info("Financial data populated")
-
+
if benchmarks_count == 0:
logger.info("Adding missing benchmark records")
# Insert sample benchmark data
@@ -252,7 +273,7 @@ async def initialize_database(self) -> None:
"cache_hit_rate": 0.0,
"average_response_time_ms": 0.0,
"success_rate": 1.0,
- "notes": "Initial system setup benchmark"
+ "notes": "Initial system setup benchmark",
},
{
"test_name": "cache_warming",
@@ -261,108 +282,139 @@ async def initialize_database(self) -> None:
"cache_hit_rate": 0.0,
"average_response_time_ms": 18.02,
"success_rate": 1.0,
- "notes": "Cache warming performance test"
- }
+ "notes": "Cache warming performance test",
+ },
]
-
+
for benchmark in sample_benchmarks:
- await db.execute("""
+ await db.execute(
+ """
INSERT INTO benchmarks (test_name, duration_ms, queries_executed, cache_hit_rate,
average_response_time_ms, success_rate, notes)
VALUES (:test_name, :duration_ms, :queries_executed, :cache_hit_rate,
:average_response_time_ms, :success_rate, :notes)
- """, benchmark)
+ """,
+ benchmark,
+ )
await db.commit()
logger.info("Benchmark data populated")
else:
- logger.info("Database already contains data, skipping sample data insertion")
-
+ logger.info(
+ "Database already contains data, skipping sample data insertion"
+ )
+
# Validate schema integrity
- cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table'")
+ cursor = await db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table'"
+ )
tables = await cursor.fetchall()
await cursor.close()
-
+
tables_info = [{"name": table[0]} for table in tables]
if not validate_schema_integrity(tables_info):
raise DatabaseError("Database schema validation failed")
-
+
except Exception as e:
logger.error("Database initialization failed", error=str(e))
raise DatabaseError(f"Failed to initialize database: {e}")
+
def validate_sql_query(self, statement: str) -> None:
"""
Validate SQL query for security and syntax with caching.
-
+
Args:
statement: SQL statement to validate
-
+
Raises:
SecurityError: If statement contains dangerous operations
InvalidSQLError: If statement has syntax errors
"""
# Generate cache key for query validation
import hashlib
+
query_hash = hashlib.md5(statement.encode()).hexdigest()
-
+
# Check cache first
if query_hash in self._query_plan_cache:
logger.debug("Query validation cache hit", statement=statement[:100])
return
-
+
normalized_statement = statement.lower().strip()
-
+
# Security check: allow SELECT statements and safe PRAGMA queries
is_select = normalized_statement.startswith("select")
is_safe_pragma = normalized_statement.startswith("pragma table_info")
-
+
if not (is_select or is_safe_pragma):
- raise SecurityError("Only SELECT statements and PRAGMA table_info queries are allowed")
-
+ raise SecurityError(
+ "Only SELECT statements and PRAGMA table_info queries are allowed"
+ )
+
# Enhanced dangerous keyword detection (excluding safe pragma)
dangerous_keywords = [
- "drop", "delete", "update", "insert", "alter", "create",
- "truncate", "replace", "merge", "exec", "execute",
- "attach", "detach", "vacuum", "reindex", "analyze"
+ "drop",
+ "delete",
+ "update",
+ "insert",
+ "alter",
+ "create",
+ "truncate",
+ "replace",
+ "merge",
+ "exec",
+ "execute",
+ "attach",
+ "detach",
+ "vacuum",
+ "reindex",
+ "analyze",
]
-
+
# For PRAGMA queries, only allow table_info
- if normalized_statement.startswith("pragma") and not normalized_statement.startswith("pragma table_info"):
+ if normalized_statement.startswith(
+ "pragma"
+ ) and not normalized_statement.startswith("pragma table_info"):
raise SecurityError("Only PRAGMA table_info queries are allowed")
-
+
# Check for dangerous keywords with word boundaries
import re
+
for keyword in dangerous_keywords:
- pattern = r'\b' + re.escape(keyword) + r'\b'
+ pattern = r"\b" + re.escape(keyword) + r"\b"
if re.search(pattern, normalized_statement, re.IGNORECASE):
raise SecurityError(f"Dangerous SQL keyword detected: {keyword}")
-
+
# Check for SQL injection patterns
# Check for SQL injection patterns (skip for safe PRAGMA queries)
if not normalized_statement.startswith("pragma table_info"):
injection_patterns = [
- r'--', # SQL comments
- r'/\*.*?\*/', # Multi-line comments
- r';\s*\w+', # Multiple statements
- r'\bunion\s+select\b', # Union injection
- r'\bor\s+1\s*=\s*1\b', # Always true conditions
- r'\band\s+1\s*=\s*1\b', # Always true conditions
- r'\bor\s+\'.*?\'\s*=\s*\'.*?\'', # Suspicious OR with string comparisons
- r'\'.*?\'\s*or\s*\'.*?\'', # Injection with OR between quotes
- r'\\x[0-9a-f]{2}', # Hex encoding
+ r"--", # SQL comments
+ r"/\*.*?\*/", # Multi-line comments
+ r";\s*\w+", # Multiple statements
+ r"\bunion\s+select\b", # Union injection
+ r"\bor\s+1\s*=\s*1\b", # Always true conditions
+ r"\band\s+1\s*=\s*1\b", # Always true conditions
+ r"\bor\s+\'.*?\'\s*=\s*\'.*?\'", # Suspicious OR with string comparisons
+ r"\'.*?\'\s*or\s*\'.*?\'", # Injection with OR between quotes
+ r"\\x[0-9a-f]{2}", # Hex encoding
]
-
+
for pattern in injection_patterns:
if re.search(pattern, normalized_statement, re.IGNORECASE):
- raise SecurityError(f"Potential SQL injection pattern detected: {pattern} in query: {normalized_statement[:100]}")
+ raise SecurityError(
+ f"Potential SQL injection pattern detected: {pattern} in query: {normalized_statement[:100]}"
+ )
# Limit query complexity
if len(statement) > 5000:
raise SecurityError("Query too long - potential DoS attack")
-
+
# Count nested subqueries to prevent complex injection attacks
- subquery_count = normalized_statement.count('select')
+ subquery_count = normalized_statement.count("select")
if subquery_count > 5:
- raise SecurityError("Too many nested subqueries - potential injection attack")
-
+ raise SecurityError(
+ "Too many nested subqueries - potential injection attack"
+ )
+
# Basic syntax validation using actual database connection
try:
# Parse SQL to check syntax (without executing)
@@ -370,59 +422,65 @@ def validate_sql_query(self, statement: str) -> None:
conn.execute(f"EXPLAIN QUERY PLAN {statement}")
except sqlite3.Error as e:
raise InvalidSQLError(f"SQL syntax error: {e}")
-
+
# Cache successful validation
if len(self._query_plan_cache) >= self._cache_max_size:
# Simple eviction: remove oldest entries
oldest_keys = list(self._query_plan_cache.keys())[:100]
for key in oldest_keys:
del self._query_plan_cache[key]
-
+
self._query_plan_cache[query_hash] = time.time()
- logger.debug("SQL query validation passed and cached", statement=statement[:100])
-
+ logger.debug(
+ "SQL query validation passed and cached", statement=statement[:100]
+ )
+
def _is_valid_table_name(self, table_name: str) -> bool:
"""
Validate table name to prevent SQL injection.
-
+
Args:
table_name: Table name to validate
-
+
Returns:
True if table name is valid
"""
import re
-
+
# Table name should only contain alphanumeric characters and underscores
- if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', table_name):
+ if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", table_name):
return False
-
+
# Reasonable length limit
if len(table_name) > 64:
return False
-
+
# Check against known system tables that should not be accessed
forbidden_tables = {
- 'sqlite_master', 'sqlite_temp_master', 'sqlite_sequence',
- 'sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4'
+ "sqlite_master",
+ "sqlite_temp_master",
+ "sqlite_sequence",
+ "sqlite_stat1",
+ "sqlite_stat2",
+ "sqlite_stat3",
+ "sqlite_stat4",
}
-
+
if table_name.lower() in forbidden_tables:
return False
-
+
return True
- logger.debug("SQL query validation passed", statement=statement[:100])
-
+
async def execute_query(self, statement: str) -> QueryResult:
"""
Execute a validated SQL query using connection pool.
-
+
Args:
statement: SQL SELECT statement to execute
-
+
Returns:
QueryResult containing rows, metadata, and timing
-
+
Raises:
DatabaseError: If query execution fails
SecurityError: If statement violates security rules
@@ -430,71 +488,75 @@ async def execute_query(self, statement: str) -> QueryResult:
"""
# Validate query first (with caching)
self.validate_sql_query(statement)
-
+
start_time = time.time()
-
+
try:
# Get connection from pool
db = await self.connection_pool.get_connection()
-
+
try:
# Enable row factory for dictionary-like access
db.row_factory = aiosqlite.Row
-
+
cursor = await db.execute(statement)
rows = await cursor.fetchall()
-
+
# Convert rows to dictionaries
structured_results = []
columns = []
-
+
if rows:
# Get column names from first row
columns = list(rows[0].keys())
-
+
# Convert each row to dictionary
for row in rows:
row_dict = {col: row[col] for col in columns}
structured_results.append(row_dict)
-
+
await cursor.close()
-
+
end_time = time.time()
execution_time_ms = (end_time - start_time) * 1000
-
+
result = QueryResult(
rows=structured_results,
row_count=len(structured_results),
columns=columns,
- execution_time_ms=execution_time_ms
+ execution_time_ms=execution_time_ms,
)
-
- logger.info("Query executed successfully",
- statement=statement[:100],
- row_count=result.row_count,
- execution_time_ms=execution_time_ms)
-
+
+ logger.info(
+ "Query executed successfully",
+ statement=statement[:100],
+ row_count=result.row_count,
+ execution_time_ms=execution_time_ms,
+ )
+
return result
-
+
finally:
# Always return connection to pool
await self.connection_pool.return_connection(db)
-
+
except Exception as e:
end_time = time.time()
execution_time_ms = (end_time - start_time) * 1000
-
- logger.error("Query execution failed",
- statement=statement[:100],
- error=str(e),
- execution_time_ms=execution_time_ms)
-
+
+ logger.error(
+ "Query execution failed",
+ statement=statement[:100],
+ error=str(e),
+ execution_time_ms=execution_time_ms,
+ )
+
raise DatabaseError(f"Query execution failed: {e}")
-
+
async def get_database_info(self) -> Dict[str, Any]:
"""
Get database metadata and statistics.
-
+
Returns:
Dictionary containing database information
"""
@@ -507,35 +569,41 @@ async def get_database_info(self) -> Dict[str, Any]:
""")
tables = await cursor.fetchall()
await cursor.close()
-
+
table_info = {}
for (table_name,) in tables:
# Validate table name to prevent injection
if not self._is_valid_table_name(table_name):
- logger.warning("Invalid table name detected", table_name=table_name)
+ logger.warning(
+ "Invalid table name detected", table_name=table_name
+ )
continue
-
+
# Get row count for each table using parameterized query
# Note: Table names cannot be parameterized, so we validate them first
- cursor = await db.execute(f"SELECT COUNT(*) FROM \"{table_name}\"")
+ cursor = await db.execute(f'SELECT COUNT(*) FROM "{table_name}"')
count = (await cursor.fetchone())[0]
await cursor.close()
table_info[table_name] = {"row_count": count}
-
+
# Get database file size
- file_size = os.path.getsize(self.database_path) if os.path.exists(self.database_path) else 0
-
+ file_size = (
+ os.path.getsize(self.database_path)
+ if os.path.exists(self.database_path)
+ else 0
+ )
+
return {
"database_path": self.database_path,
"file_size_bytes": file_size,
"tables": table_info,
- "total_tables": len(tables)
+ "total_tables": len(tables),
}
-
+
except Exception as e:
logger.error("Failed to get database info", error=str(e))
raise DatabaseError(f"Failed to get database info: {e}")
-
+
async def cleanup(self):
"""
Cleanup database resources including connection pool.
@@ -545,12 +613,12 @@ async def cleanup(self):
logger.info("Database manager cleanup completed")
except Exception as e:
logger.error("Database cleanup failed", error=str(e))
-
+
@asynccontextmanager
async def get_connection(self):
"""
Get an async database connection context manager.
-
+
Yields:
aiosqlite.Connection: Database connection
"""
@@ -568,11 +636,11 @@ async def get_connection(self):
def create_database_manager(database_path: str) -> DatabaseManager:
"""
Create and initialize a database manager instance.
-
+
Args:
database_path: Path to SQLite database file
-
+
Returns:
Configured DatabaseManager instance
"""
- return DatabaseManager(database_path)
\ No newline at end of file
+ return DatabaseManager(database_path)
diff --git a/src/db/models.py b/src/db/models.py
index 859b9a2..41e4a03 100644
--- a/src/db/models.py
+++ b/src/db/models.py
@@ -10,13 +10,13 @@
from datetime import datetime
import structlog
-
logger = structlog.get_logger(__name__)
@dataclass
class FinancialRecord:
"""Represents a financial record in the database."""
+
id: int
company: str
quarter: str
@@ -31,6 +31,7 @@ class FinancialRecord:
@dataclass
class Company:
"""Represents a company in the database."""
+
id: int
name: str
symbol: str
@@ -45,6 +46,7 @@ class Company:
@dataclass
class QueryResult:
"""Represents the result of a database query."""
+
rows: List[Dict[str, Any]]
row_count: int
columns: List[str]
@@ -131,7 +133,7 @@ class QueryResult:
"sector": "Technology",
"founded_year": 1995,
"employees": 50000,
- "market_cap": 250000000000.0
+ "market_cap": 250000000000.0,
},
{
"name": "FinanceFirst LLC",
@@ -139,7 +141,7 @@ class QueryResult:
"sector": "Financial Services",
"founded_year": 1988,
"employees": 25000,
- "market_cap": 125000000000.0
+ "market_cap": 125000000000.0,
},
{
"name": "HealthTech Solutions",
@@ -147,7 +149,7 @@ class QueryResult:
"sector": "Healthcare",
"founded_year": 2005,
"employees": 15000,
- "market_cap": 75000000000.0
+ "market_cap": 75000000000.0,
},
{
"name": "Green Energy Corp",
@@ -155,7 +157,7 @@ class QueryResult:
"sector": "Energy",
"founded_year": 2010,
"employees": 8000,
- "market_cap": 45000000000.0
+ "market_cap": 45000000000.0,
},
{
"name": "RetailMax Group",
@@ -163,52 +165,223 @@ class QueryResult:
"sector": "Retail",
"founded_year": 1975,
"employees": 120000,
- "market_cap": 95000000000.0
- }
+ "market_cap": 95000000000.0,
+ },
]
SAMPLE_FINANCIAL_RECORDS = [
# TechCorp Inc. (company_id: 1)
- {"company_id": 1, "quarter": "Q1", "year": 2025, "revenue": 25000000000.0, "profit": 5000000000.0, "expenses": 20000000000.0},
- {"company_id": 1, "quarter": "Q4", "year": 2024, "revenue": 28000000000.0, "profit": 6000000000.0, "expenses": 22000000000.0},
- {"company_id": 1, "quarter": "Q3", "year": 2024, "revenue": 26000000000.0, "profit": 5500000000.0, "expenses": 20500000000.0},
- {"company_id": 1, "quarter": "Q2", "year": 2024, "revenue": 24000000000.0, "profit": 4800000000.0, "expenses": 19200000000.0},
- {"company_id": 1, "quarter": "Q1", "year": 2024, "revenue": 23000000000.0, "profit": 4600000000.0, "expenses": 18400000000.0},
-
+ {
+ "company_id": 1,
+ "quarter": "Q1",
+ "year": 2025,
+ "revenue": 25000000000.0,
+ "profit": 5000000000.0,
+ "expenses": 20000000000.0,
+ },
+ {
+ "company_id": 1,
+ "quarter": "Q4",
+ "year": 2024,
+ "revenue": 28000000000.0,
+ "profit": 6000000000.0,
+ "expenses": 22000000000.0,
+ },
+ {
+ "company_id": 1,
+ "quarter": "Q3",
+ "year": 2024,
+ "revenue": 26000000000.0,
+ "profit": 5500000000.0,
+ "expenses": 20500000000.0,
+ },
+ {
+ "company_id": 1,
+ "quarter": "Q2",
+ "year": 2024,
+ "revenue": 24000000000.0,
+ "profit": 4800000000.0,
+ "expenses": 19200000000.0,
+ },
+ {
+ "company_id": 1,
+ "quarter": "Q1",
+ "year": 2024,
+ "revenue": 23000000000.0,
+ "profit": 4600000000.0,
+ "expenses": 18400000000.0,
+ },
# FinanceFirst LLC (company_id: 2)
- {"company_id": 2, "quarter": "Q1", "year": 2025, "revenue": 12000000000.0, "profit": 3000000000.0, "expenses": 9000000000.0},
- {"company_id": 2, "quarter": "Q4", "year": 2024, "revenue": 13500000000.0, "profit": 3200000000.0, "expenses": 10300000000.0},
- {"company_id": 2, "quarter": "Q3", "year": 2024, "revenue": 13000000000.0, "profit": 3100000000.0, "expenses": 9900000000.0},
- {"company_id": 2, "quarter": "Q2", "year": 2024, "revenue": 12500000000.0, "profit": 2900000000.0, "expenses": 9600000000.0},
- {"company_id": 2, "quarter": "Q1", "year": 2024, "revenue": 11800000000.0, "profit": 2800000000.0, "expenses": 9000000000.0},
-
+ {
+ "company_id": 2,
+ "quarter": "Q1",
+ "year": 2025,
+ "revenue": 12000000000.0,
+ "profit": 3000000000.0,
+ "expenses": 9000000000.0,
+ },
+ {
+ "company_id": 2,
+ "quarter": "Q4",
+ "year": 2024,
+ "revenue": 13500000000.0,
+ "profit": 3200000000.0,
+ "expenses": 10300000000.0,
+ },
+ {
+ "company_id": 2,
+ "quarter": "Q3",
+ "year": 2024,
+ "revenue": 13000000000.0,
+ "profit": 3100000000.0,
+ "expenses": 9900000000.0,
+ },
+ {
+ "company_id": 2,
+ "quarter": "Q2",
+ "year": 2024,
+ "revenue": 12500000000.0,
+ "profit": 2900000000.0,
+ "expenses": 9600000000.0,
+ },
+ {
+ "company_id": 2,
+ "quarter": "Q1",
+ "year": 2024,
+ "revenue": 11800000000.0,
+ "profit": 2800000000.0,
+ "expenses": 9000000000.0,
+ },
# HealthTech Solutions (company_id: 3)
- {"company_id": 3, "quarter": "Q1", "year": 2025, "revenue": 8000000000.0, "profit": 1200000000.0, "expenses": 6800000000.0},
- {"company_id": 3, "quarter": "Q4", "year": 2024, "revenue": 8500000000.0, "profit": 1300000000.0, "expenses": 7200000000.0},
- {"company_id": 3, "quarter": "Q3", "year": 2024, "revenue": 8200000000.0, "profit": 1250000000.0, "expenses": 6950000000.0},
- {"company_id": 3, "quarter": "Q2", "year": 2024, "revenue": 7800000000.0, "profit": 1150000000.0, "expenses": 6650000000.0},
- {"company_id": 3, "quarter": "Q1", "year": 2024, "revenue": 7500000000.0, "profit": 1100000000.0, "expenses": 6400000000.0},
-
+ {
+ "company_id": 3,
+ "quarter": "Q1",
+ "year": 2025,
+ "revenue": 8000000000.0,
+ "profit": 1200000000.0,
+ "expenses": 6800000000.0,
+ },
+ {
+ "company_id": 3,
+ "quarter": "Q4",
+ "year": 2024,
+ "revenue": 8500000000.0,
+ "profit": 1300000000.0,
+ "expenses": 7200000000.0,
+ },
+ {
+ "company_id": 3,
+ "quarter": "Q3",
+ "year": 2024,
+ "revenue": 8200000000.0,
+ "profit": 1250000000.0,
+ "expenses": 6950000000.0,
+ },
+ {
+ "company_id": 3,
+ "quarter": "Q2",
+ "year": 2024,
+ "revenue": 7800000000.0,
+ "profit": 1150000000.0,
+ "expenses": 6650000000.0,
+ },
+ {
+ "company_id": 3,
+ "quarter": "Q1",
+ "year": 2024,
+ "revenue": 7500000000.0,
+ "profit": 1100000000.0,
+ "expenses": 6400000000.0,
+ },
# Green Energy Corp (company_id: 4)
- {"company_id": 4, "quarter": "Q1", "year": 2025, "revenue": 5500000000.0, "profit": 800000000.0, "expenses": 4700000000.0},
- {"company_id": 4, "quarter": "Q4", "year": 2024, "revenue": 6000000000.0, "profit": 900000000.0, "expenses": 5100000000.0},
- {"company_id": 4, "quarter": "Q3", "year": 2024, "revenue": 5800000000.0, "profit": 850000000.0, "expenses": 4950000000.0},
- {"company_id": 4, "quarter": "Q2", "year": 2024, "revenue": 5200000000.0, "profit": 750000000.0, "expenses": 4450000000.0},
- {"company_id": 4, "quarter": "Q1", "year": 2024, "revenue": 4800000000.0, "profit": 680000000.0, "expenses": 4120000000.0},
-
+ {
+ "company_id": 4,
+ "quarter": "Q1",
+ "year": 2025,
+ "revenue": 5500000000.0,
+ "profit": 800000000.0,
+ "expenses": 4700000000.0,
+ },
+ {
+ "company_id": 4,
+ "quarter": "Q4",
+ "year": 2024,
+ "revenue": 6000000000.0,
+ "profit": 900000000.0,
+ "expenses": 5100000000.0,
+ },
+ {
+ "company_id": 4,
+ "quarter": "Q3",
+ "year": 2024,
+ "revenue": 5800000000.0,
+ "profit": 850000000.0,
+ "expenses": 4950000000.0,
+ },
+ {
+ "company_id": 4,
+ "quarter": "Q2",
+ "year": 2024,
+ "revenue": 5200000000.0,
+ "profit": 750000000.0,
+ "expenses": 4450000000.0,
+ },
+ {
+ "company_id": 4,
+ "quarter": "Q1",
+ "year": 2024,
+ "revenue": 4800000000.0,
+ "profit": 680000000.0,
+ "expenses": 4120000000.0,
+ },
# RetailMax Group (company_id: 5)
- {"company_id": 5, "quarter": "Q1", "year": 2025, "revenue": 18000000000.0, "profit": 1800000000.0, "expenses": 16200000000.0},
- {"company_id": 5, "quarter": "Q4", "year": 2024, "revenue": 22000000000.0, "profit": 2400000000.0, "expenses": 19600000000.0}, # Holiday season
- {"company_id": 5, "quarter": "Q3", "year": 2024, "revenue": 19000000000.0, "profit": 2000000000.0, "expenses": 17000000000.0},
- {"company_id": 5, "quarter": "Q2", "year": 2024, "revenue": 17500000000.0, "profit": 1750000000.0, "expenses": 15750000000.0},
- {"company_id": 5, "quarter": "Q1", "year": 2024, "revenue": 16800000000.0, "profit": 1650000000.0, "expenses": 15150000000.0},
+ {
+ "company_id": 5,
+ "quarter": "Q1",
+ "year": 2025,
+ "revenue": 18000000000.0,
+ "profit": 1800000000.0,
+ "expenses": 16200000000.0,
+ },
+ {
+ "company_id": 5,
+ "quarter": "Q4",
+ "year": 2024,
+ "revenue": 22000000000.0,
+ "profit": 2400000000.0,
+ "expenses": 19600000000.0,
+ }, # Holiday season
+ {
+ "company_id": 5,
+ "quarter": "Q3",
+ "year": 2024,
+ "revenue": 19000000000.0,
+ "profit": 2000000000.0,
+ "expenses": 17000000000.0,
+ },
+ {
+ "company_id": 5,
+ "quarter": "Q2",
+ "year": 2024,
+ "revenue": 17500000000.0,
+ "profit": 1750000000.0,
+ "expenses": 15750000000.0,
+ },
+ {
+ "company_id": 5,
+ "quarter": "Q1",
+ "year": 2024,
+ "revenue": 16800000000.0,
+ "profit": 1650000000.0,
+ "expenses": 15150000000.0,
+ },
]
def get_sample_queries() -> List[str]:
"""
Get list of sample queries for testing and demonstration.
-
+
Returns:
List of sample SQL queries
"""
@@ -219,27 +392,27 @@ def get_sample_queries() -> List[str]:
"SELECT sector, COUNT(*) as company_count FROM companies GROUP BY sector",
"SELECT c.name, f.quarter, f.year, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE c.symbol = 'TECH' ORDER BY f.year DESC, f.quarter DESC",
"SELECT AVG(revenue) as avg_revenue, AVG(profit) as avg_profit FROM financial_records WHERE year = 2024",
- "SELECT c.name, c.market_cap, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025 AND f.quarter = 'Q1' ORDER BY c.market_cap DESC"
+ "SELECT c.name, c.market_cap, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025 AND f.quarter = 'Q1' ORDER BY c.market_cap DESC",
]
def validate_schema_integrity(tables_info: List[Dict[str, Any]]) -> bool:
"""
Validate that the database schema matches expected structure.
-
+
Args:
tables_info: Information about database tables
-
+
Returns:
True if schema is valid, False otherwise
"""
expected_tables = {"companies", "financial_records", "financial_data", "benchmarks"}
actual_tables = {table["name"] for table in tables_info}
-
+
if not expected_tables.issubset(actual_tables):
missing_tables = expected_tables - actual_tables
logger.error("Missing database tables", missing=list(missing_tables))
return False
-
+
logger.info("Database schema validation passed")
- return True
\ No newline at end of file
+ return True
diff --git a/src/monitoring/__init__.py b/src/monitoring/__init__.py
index 6c4fae1..759f49f 100644
--- a/src/monitoring/__init__.py
+++ b/src/monitoring/__init__.py
@@ -7,14 +7,24 @@
# Use try/except to handle both relative and absolute imports
try:
- from .metrics import MetricsCollector, get_metrics_collector, SystemMetrics, ToolExecutionMetric
+ from .metrics import (
+ MetricsCollector,
+ get_metrics_collector,
+ SystemMetrics,
+ ToolExecutionMetric,
+ )
except ImportError:
# Fallback to absolute imports when called from scripts
- from monitoring.metrics import MetricsCollector, get_metrics_collector, SystemMetrics, ToolExecutionMetric
+ from monitoring.metrics import (
+ MetricsCollector,
+ get_metrics_collector,
+ SystemMetrics,
+ ToolExecutionMetric,
+ )
__all__ = [
- 'MetricsCollector',
- 'get_metrics_collector',
- 'SystemMetrics',
- 'ToolExecutionMetric'
-]
\ No newline at end of file
+ "MetricsCollector",
+ "get_metrics_collector",
+ "SystemMetrics",
+ "ToolExecutionMetric",
+]
diff --git a/src/monitoring/metrics.py b/src/monitoring/metrics.py
index 83c01d8..1f3fe52 100644
--- a/src/monitoring/metrics.py
+++ b/src/monitoring/metrics.py
@@ -18,6 +18,7 @@
@dataclass
class ToolExecutionMetric:
"""Represents a single tool execution metric."""
+
tool_name: str
success: bool
execution_time_ms: float
@@ -30,11 +31,12 @@ class ToolExecutionMetric:
@dataclass
class SystemMetrics:
"""Aggregated system metrics."""
+
total_executions: int = 0
successful_executions: int = 0
failed_executions: int = 0
average_execution_time: float = 0.0
- min_execution_time: float = float('inf')
+ min_execution_time: float = float("inf")
max_execution_time: float = 0.0
error_rate: float = 0.0
executions_per_minute: float = 0.0
@@ -45,47 +47,51 @@ class SystemMetrics:
class MetricsCollector:
"""
Collects and aggregates performance metrics for the FACT system.
-
+
Provides real-time metrics collection, aggregation, and reporting
for tool execution performance and system health monitoring.
"""
-
+
def __init__(self, max_history: int = 10000):
"""
Initialize metrics collector.
-
+
Args:
max_history: Maximum number of metrics to keep in memory
"""
self.max_history = max_history
self.metrics_history: deque = deque(maxlen=max_history)
- self.tool_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: {
- 'count': 0,
- 'success_count': 0,
- 'total_time': 0.0,
- 'min_time': float('inf'),
- 'max_time': 0.0,
- 'recent_executions': deque(maxlen=100)
- })
-
+ self.tool_stats: Dict[str, Dict[str, Any]] = defaultdict(
+ lambda: {
+ "count": 0,
+ "success_count": 0,
+ "total_time": 0.0,
+ "min_time": float("inf"),
+ "max_time": 0.0,
+ "recent_executions": deque(maxlen=100),
+ }
+ )
+
# Thread safety
self._lock = threading.RLock()
-
+
# Start time for rate calculations
self._start_time = time.time()
-
+
logger.info("MetricsCollector initialized", max_history=max_history)
-
- def record_tool_execution(self,
- tool_name: str,
- success: bool,
- execution_time: float,
- user_id: Optional[str] = None,
- error_type: Optional[str] = None,
- metadata: Optional[Dict[str, Any]] = None) -> None:
+
+ def record_tool_execution(
+ self,
+ tool_name: str,
+ success: bool,
+ execution_time: float,
+ user_id: Optional[str] = None,
+ error_type: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> None:
"""
Record a tool execution metric.
-
+
Args:
tool_name: Name of the executed tool
success: Whether execution was successful
@@ -96,7 +102,7 @@ def record_tool_execution(self,
"""
with self._lock:
timestamp = time.time()
-
+
# Create metric record
metric = ToolExecutionMetric(
tool_name=tool_name,
@@ -105,81 +111,92 @@ def record_tool_execution(self,
timestamp=timestamp,
user_id=user_id,
error_type=error_type,
- metadata=metadata or {}
+ metadata=metadata or {},
)
-
+
# Add to history
self.metrics_history.append(metric)
-
+
# Update tool-specific stats
tool_stat = self.tool_stats[tool_name]
- tool_stat['count'] += 1
- tool_stat['total_time'] += execution_time
- tool_stat['recent_executions'].append({
- 'timestamp': timestamp,
- 'success': success,
- 'execution_time': execution_time
- })
-
+ tool_stat["count"] += 1
+ tool_stat["total_time"] += execution_time
+ tool_stat["recent_executions"].append(
+ {
+ "timestamp": timestamp,
+ "success": success,
+ "execution_time": execution_time,
+ }
+ )
+
if success:
- tool_stat['success_count'] += 1
-
+ tool_stat["success_count"] += 1
+
# Update min/max times
- tool_stat['min_time'] = min(tool_stat['min_time'], execution_time)
- tool_stat['max_time'] = max(tool_stat['max_time'], execution_time)
-
- logger.debug("Tool execution metric recorded",
- tool_name=tool_name,
- success=success,
- execution_time_ms=execution_time,
- user_id=user_id)
-
+ tool_stat["min_time"] = min(tool_stat["min_time"], execution_time)
+ tool_stat["max_time"] = max(tool_stat["max_time"], execution_time)
+
+ logger.debug(
+ "Tool execution metric recorded",
+ tool_name=tool_name,
+ success=success,
+ execution_time_ms=execution_time,
+ user_id=user_id,
+ )
+
def get_system_metrics(self, time_window_minutes: int = 60) -> SystemMetrics:
"""
Get aggregated system metrics.
-
+
Args:
time_window_minutes: Time window for calculations in minutes
-
+
Returns:
SystemMetrics object with aggregated data
"""
with self._lock:
current_time = time.time()
cutoff_time = current_time - (time_window_minutes * 60)
-
+
# Filter metrics within time window
recent_metrics = [
- metric for metric in self.metrics_history
+ metric
+ for metric in self.metrics_history
if metric.timestamp >= cutoff_time
]
-
+
if not recent_metrics:
return SystemMetrics()
-
+
# Calculate basic stats
total_executions = len(recent_metrics)
successful_executions = sum(1 for m in recent_metrics if m.success)
failed_executions = total_executions - successful_executions
-
+
execution_times = [m.execution_time_ms for m in recent_metrics]
average_execution_time = sum(execution_times) / len(execution_times)
min_execution_time = min(execution_times)
max_execution_time = max(execution_times)
-
- error_rate = (failed_executions / total_executions) * 100 if total_executions > 0 else 0
+
+ error_rate = (
+ (failed_executions / total_executions) * 100
+ if total_executions > 0
+ else 0
+ )
executions_per_minute = total_executions / time_window_minutes
-
+
# Get top tools by usage
tool_counts = defaultdict(int)
for metric in recent_metrics:
tool_counts[metric.tool_name] += 1
-
+
top_tools = [
{"tool_name": tool, "count": count}
- for tool, count in sorted(tool_counts.items(), key=lambda x: x[1], reverse=True)[:10]
+ for tool, count in sorted(
+ tool_counts.items(), key=lambda x: x[1], reverse=True
+ )[:10]
]
-
+
# Get recent errors
recent_errors = [
{
@@ -187,11 +204,14 @@ def get_system_metrics(self, time_window_minutes: int = 60) -> SystemMetrics:
"error_type": m.error_type,
"timestamp": m.timestamp,
"execution_time_ms": m.execution_time_ms,
- "user_id": m.user_id
+ "user_id": m.user_id,
}
- for m in recent_metrics if not m.success
- ][-10:] # Last 10 errors
-
+ for m in recent_metrics
+ if not m.success
+ ][
+ -10:
+ ] # Last 10 errors
+
return SystemMetrics(
total_executions=total_executions,
successful_executions=successful_executions,
@@ -202,16 +222,16 @@ def get_system_metrics(self, time_window_minutes: int = 60) -> SystemMetrics:
error_rate=error_rate,
executions_per_minute=executions_per_minute,
top_tools=top_tools,
- recent_errors=recent_errors
+ recent_errors=recent_errors,
)
-
+
def get_tool_metrics(self, tool_name: str) -> Dict[str, Any]:
"""
Get detailed metrics for a specific tool.
-
+
Args:
tool_name: Name of the tool
-
+
Returns:
Tool-specific metrics
"""
@@ -223,96 +243,114 @@ def get_tool_metrics(self, tool_name: str) -> Dict[str, Any]:
"success_rate": 0.0,
"average_execution_time": 0.0,
"min_execution_time": 0.0,
- "max_execution_time": 0.0
+ "max_execution_time": 0.0,
}
-
+
stats = self.tool_stats[tool_name]
-
+
# Calculate derived metrics
- success_rate = (stats['success_count'] / stats['count']) * 100 if stats['count'] > 0 else 0
- average_time = stats['total_time'] / stats['count'] if stats['count'] > 0 else 0
-
+ success_rate = (
+ (stats["success_count"] / stats["count"]) * 100
+ if stats["count"] > 0
+ else 0
+ )
+ average_time = (
+ stats["total_time"] / stats["count"] if stats["count"] > 0 else 0
+ )
+
# Recent performance (last 24 hours)
current_time = time.time()
cutoff_time = current_time - (24 * 60 * 60) # 24 hours ago
-
+
recent_executions = [
- exec for exec in stats['recent_executions']
- if exec['timestamp'] >= cutoff_time
+ exec
+ for exec in stats["recent_executions"]
+ if exec["timestamp"] >= cutoff_time
]
-
+
recent_success_rate = 0.0
recent_avg_time = 0.0
-
+
if recent_executions:
- recent_successes = sum(1 for exec in recent_executions if exec['success'])
+ recent_successes = sum(
+ 1 for exec in recent_executions if exec["success"]
+ )
recent_success_rate = (recent_successes / len(recent_executions)) * 100
- recent_avg_time = sum(exec['execution_time'] for exec in recent_executions) / len(recent_executions)
-
+ recent_avg_time = sum(
+ exec["execution_time"] for exec in recent_executions
+ ) / len(recent_executions)
+
return {
"tool_name": tool_name,
- "total_executions": stats['count'],
- "successful_executions": stats['success_count'],
- "failed_executions": stats['count'] - stats['success_count'],
+ "total_executions": stats["count"],
+ "successful_executions": stats["success_count"],
+ "failed_executions": stats["count"] - stats["success_count"],
"success_rate": success_rate,
"average_execution_time": average_time,
- "min_execution_time": stats['min_time'] if stats['min_time'] != float('inf') else 0,
- "max_execution_time": stats['max_time'],
+ "min_execution_time": (
+ stats["min_time"] if stats["min_time"] != float("inf") else 0
+ ),
+ "max_execution_time": stats["max_time"],
"recent_24h": {
"executions": len(recent_executions),
"success_rate": recent_success_rate,
- "average_execution_time": recent_avg_time
- }
+ "average_execution_time": recent_avg_time,
+ },
}
-
- def get_user_metrics(self, user_id: str, time_window_minutes: int = 60) -> Dict[str, Any]:
+
+ def get_user_metrics(
+ self, user_id: str, time_window_minutes: int = 60
+ ) -> Dict[str, Any]:
"""
Get metrics for a specific user.
-
+
Args:
user_id: User identifier
time_window_minutes: Time window in minutes
-
+
Returns:
User-specific metrics
"""
with self._lock:
current_time = time.time()
cutoff_time = current_time - (time_window_minutes * 60)
-
+
# Filter metrics for this user within time window
user_metrics = [
- metric for metric in self.metrics_history
+ metric
+ for metric in self.metrics_history
if metric.user_id == user_id and metric.timestamp >= cutoff_time
]
-
+
if not user_metrics:
return {
"user_id": user_id,
"total_executions": 0,
"success_rate": 0.0,
"tools_used": [],
- "average_execution_time": 0.0
+ "average_execution_time": 0.0,
}
-
+
# Calculate user stats
total_executions = len(user_metrics)
successful_executions = sum(1 for m in user_metrics if m.success)
success_rate = (successful_executions / total_executions) * 100
-
+
execution_times = [m.execution_time_ms for m in user_metrics]
average_execution_time = sum(execution_times) / len(execution_times)
-
+
# Tools used by this user
tool_usage = defaultdict(int)
for metric in user_metrics:
tool_usage[metric.tool_name] += 1
-
+
tools_used = [
{"tool_name": tool, "count": count}
- for tool, count in sorted(tool_usage.items(), key=lambda x: x[1], reverse=True)
+ for tool, count in sorted(
+ tool_usage.items(), key=lambda x: x[1], reverse=True
+ )
]
-
+
return {
"user_id": user_id,
"time_window_minutes": time_window_minutes,
@@ -321,17 +359,19 @@ def get_user_metrics(self, user_id: str, time_window_minutes: int = 60) -> Dict[
"failed_executions": total_executions - successful_executions,
"success_rate": success_rate,
"average_execution_time": average_execution_time,
- "tools_used": tools_used
+ "tools_used": tools_used,
}
-
- def get_performance_trends(self, time_window_hours: int = 24, bucket_minutes: int = 60) -> Dict[str, Any]:
+
+ def get_performance_trends(
+ self, time_window_hours: int = 24, bucket_minutes: int = 60
+ ) -> Dict[str, Any]:
"""
Get performance trends over time.
-
+
Args:
time_window_hours: Time window in hours
bucket_minutes: Time bucket size in minutes
-
+
Returns:
Performance trend data
"""
@@ -339,41 +379,42 @@ def get_performance_trends(self, time_window_hours: int = 24, bucket_minutes: in
current_time = time.time()
cutoff_time = current_time - (time_window_hours * 60 * 60)
bucket_size = bucket_minutes * 60
-
+
# Filter metrics within time window
recent_metrics = [
- metric for metric in self.metrics_history
+ metric
+ for metric in self.metrics_history
if metric.timestamp >= cutoff_time
]
-
+
if not recent_metrics:
return {"buckets": [], "summary": {}}
-
+
# Create time buckets
start_time = cutoff_time
buckets = []
-
+
while start_time < current_time:
bucket_end = start_time + bucket_size
bucket_metrics = [
- m for m in recent_metrics
- if start_time <= m.timestamp < bucket_end
+ m for m in recent_metrics if start_time <= m.timestamp < bucket_end
]
-
+
if bucket_metrics:
total = len(bucket_metrics)
successful = sum(1 for m in bucket_metrics if m.success)
execution_times = [m.execution_time_ms for m in bucket_metrics]
-
+
bucket_data = {
"timestamp": start_time,
"total_executions": total,
"successful_executions": successful,
"failed_executions": total - successful,
"success_rate": (successful / total) * 100,
- "average_execution_time": sum(execution_times) / len(execution_times),
+ "average_execution_time": sum(execution_times)
+ / len(execution_times),
"min_execution_time": min(execution_times),
- "max_execution_time": max(execution_times)
+ "max_execution_time": max(execution_times),
}
else:
bucket_data = {
@@ -384,49 +425,47 @@ def get_performance_trends(self, time_window_hours: int = 24, bucket_minutes: in
"success_rate": 0.0,
"average_execution_time": 0.0,
"min_execution_time": 0.0,
- "max_execution_time": 0.0
+ "max_execution_time": 0.0,
}
-
+
buckets.append(bucket_data)
start_time = bucket_end
-
+
# Calculate summary
if recent_metrics:
total_executions = len(recent_metrics)
successful_executions = sum(1 for m in recent_metrics if m.success)
all_execution_times = [m.execution_time_ms for m in recent_metrics]
-
+
summary = {
"time_window_hours": time_window_hours,
"bucket_minutes": bucket_minutes,
"total_executions": total_executions,
"success_rate": (successful_executions / total_executions) * 100,
- "average_execution_time": sum(all_execution_times) / len(all_execution_times),
+ "average_execution_time": sum(all_execution_times)
+ / len(all_execution_times),
"min_execution_time": min(all_execution_times),
- "max_execution_time": max(all_execution_times)
+ "max_execution_time": max(all_execution_times),
}
else:
summary = {}
-
- return {
- "buckets": buckets,
- "summary": summary
- }
-
+
+ return {"buckets": buckets, "summary": summary}
+
def export_metrics(self, format: str = "json") -> str:
"""
Export metrics in specified format.
-
+
Args:
format: Export format ("json" or "csv")
-
+
Returns:
Exported metrics as string
"""
with self._lock:
if format.lower() == "json":
import json
-
+
export_data = {
"system_metrics": self.get_system_metrics().__dict__,
"tool_metrics": {
@@ -435,47 +474,55 @@ def export_metrics(self, format: str = "json") -> str:
},
"performance_trends": self.get_performance_trends(),
"export_timestamp": time.time(),
- "total_metrics_count": len(self.metrics_history)
+ "total_metrics_count": len(self.metrics_history),
}
-
+
return json.dumps(export_data, indent=2)
-
+
elif format.lower() == "csv":
import csv
import io
-
+
output = io.StringIO()
writer = csv.writer(output)
-
+
# Write header
- writer.writerow([
- "timestamp", "tool_name", "success", "execution_time_ms",
- "user_id", "error_type"
- ])
-
+ writer.writerow(
+ [
+ "timestamp",
+ "tool_name",
+ "success",
+ "execution_time_ms",
+ "user_id",
+ "error_type",
+ ]
+ )
+
# Write metrics
for metric in self.metrics_history:
- writer.writerow([
- metric.timestamp,
- metric.tool_name,
- metric.success,
- metric.execution_time_ms,
- metric.user_id or "",
- metric.error_type or ""
- ])
-
+ writer.writerow(
+ [
+ metric.timestamp,
+ metric.tool_name,
+ metric.success,
+ metric.execution_time_ms,
+ metric.user_id or "",
+ metric.error_type or "",
+ ]
+ )
+
return output.getvalue()
-
+
else:
raise ValueError(f"Unsupported export format: {format}")
-
+
def clear_metrics(self, older_than_hours: Optional[int] = None) -> int:
"""
Clear metrics from memory.
-
+
Args:
older_than_hours: Only clear metrics older than this many hours
-
+
Returns:
Number of metrics cleared
"""
@@ -490,17 +537,23 @@ def clear_metrics(self, older_than_hours: Optional[int] = None) -> int:
else:
# Clear only old metrics
cutoff_time = time.time() - (older_than_hours * 60 * 60)
-
+
original_count = len(self.metrics_history)
self.metrics_history = deque(
- (metric for metric in self.metrics_history if metric.timestamp >= cutoff_time),
- maxlen=self.max_history
+ (
+ metric
+ for metric in self.metrics_history
+ if metric.timestamp >= cutoff_time
+ ),
+ maxlen=self.max_history,
)
-
+
cleared_count = original_count - len(self.metrics_history)
- logger.info("Old metrics cleared",
- cleared_count=cleared_count,
- older_than_hours=older_than_hours)
+ logger.info(
+ "Old metrics cleared",
+ cleared_count=cleared_count,
+ older_than_hours=older_than_hours,
+ )
return cleared_count
@@ -510,4 +563,4 @@ def clear_metrics(self, older_than_hours: Optional[int] = None) -> int:
def get_metrics_collector() -> MetricsCollector:
"""Get the global metrics collector instance."""
- return _metrics_collector
\ No newline at end of file
+ return _metrics_collector
diff --git a/src/monitoring/performance_optimizer.py b/src/monitoring/performance_optimizer.py
index a98da57..3ac7f96 100644
--- a/src/monitoring/performance_optimizer.py
+++ b/src/monitoring/performance_optimizer.py
@@ -22,10 +22,11 @@
except ImportError:
import sys
from pathlib import Path
+
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from cache.manager import CacheManager, get_cache_manager
from cache.warming import CacheWarmer, get_cache_warmer
from cache.metrics import MetricsCollector, get_metrics_collector
@@ -36,6 +37,7 @@
class OptimizationStrategy(Enum):
"""Performance optimization strategies."""
+
AGGRESSIVE_WARMING = "aggressive_warming"
CONSERVATIVE_MEMORY = "conservative_memory"
BALANCED = "balanced"
@@ -46,6 +48,7 @@ class OptimizationStrategy(Enum):
@dataclass
class PerformanceTarget:
"""Performance targets and thresholds."""
+
cache_hit_latency_ms: float = 48.0
cache_miss_latency_ms: float = 140.0
cache_hit_rate_percent: float = 60.0
@@ -57,6 +60,7 @@ class PerformanceTarget:
@dataclass
class OptimizationAction:
"""Represents an optimization action to be taken."""
+
action_type: str
parameters: Dict[str, Any]
priority: int
@@ -67,17 +71,19 @@ class OptimizationAction:
class PerformanceOptimizer:
"""
Real-time performance optimizer for the FACT system.
-
+
Monitors system performance and automatically applies optimizations
to meet performance targets.
"""
-
- def __init__(self,
- cache_manager: Optional[CacheManager] = None,
- targets: Optional[PerformanceTarget] = None):
+
+ def __init__(
+ self,
+ cache_manager: Optional[CacheManager] = None,
+ targets: Optional[PerformanceTarget] = None,
+ ):
"""
Initialize performance optimizer.
-
+
Args:
cache_manager: Cache manager instance
targets: Performance targets
@@ -86,48 +92,50 @@ def __init__(self,
self.cache_warmer = get_cache_warmer(self.cache_manager)
self.metrics_collector = get_metrics_collector()
self.benchmark_framework = BenchmarkFramework()
-
+
self.targets = targets or PerformanceTarget()
self.strategy = OptimizationStrategy.BALANCED
-
+
# Optimization state
self.optimization_enabled = True
self.optimization_interval = 300 # 5 minutes
self.last_optimization = 0.0
self.optimization_history: List[OptimizationAction] = []
-
+
# Performance tracking
self.performance_windows = {
- 'short': [], # Last 10 measurements
- 'medium': [], # Last 50 measurements
- 'long': [] # Last 200 measurements
+ "short": [], # Last 10 measurements
+ "medium": [], # Last 50 measurements
+ "long": [], # Last 200 measurements
}
-
+
# Optimization thresholds
self.thresholds = {
- 'hit_rate_critical': 40.0,
- 'latency_critical': 200.0,
- 'memory_critical': 95.0,
- 'cost_efficiency_critical': 50.0
+ "hit_rate_critical": 40.0,
+ "latency_critical": 200.0,
+ "memory_critical": 95.0,
+ "cost_efficiency_critical": 50.0,
}
-
+
# Thread safety
self._lock = threading.RLock()
self._optimization_task: Optional[asyncio.Task] = None
-
- logger.info("Performance optimizer initialized",
- strategy=self.strategy.value,
- targets=self.targets)
-
+
+ logger.info(
+ "Performance optimizer initialized",
+ strategy=self.strategy.value,
+ targets=self.targets,
+ )
+
async def start_optimization_loop(self):
"""Start the continuous optimization loop."""
if self._optimization_task and not self._optimization_task.done():
logger.warning("Optimization loop already running")
return
-
+
self._optimization_task = asyncio.create_task(self._optimization_loop())
logger.info("Performance optimization loop started")
-
+
async def stop_optimization_loop(self):
"""Stop the continuous optimization loop."""
if self._optimization_task:
@@ -137,360 +145,421 @@ async def stop_optimization_loop(self):
except asyncio.CancelledError:
pass
self._optimization_task = None
-
+
logger.info("Performance optimization loop stopped")
-
+
async def _optimization_loop(self):
"""Main optimization loop."""
while True:
try:
await asyncio.sleep(self.optimization_interval)
-
+
if self.optimization_enabled:
await self.optimize_performance()
-
+
except asyncio.CancelledError:
logger.info("Optimization loop cancelled")
break
except Exception as e:
logger.error("Optimization loop error", error=str(e))
await asyncio.sleep(30) # Brief pause before retry
-
+
async def optimize_performance(self) -> List[OptimizationAction]:
"""
Analyze current performance and apply optimizations.
-
+
Returns:
List of optimization actions taken
"""
try:
with self._lock:
logger.info("Starting performance optimization cycle")
-
+
# Collect current metrics
current_metrics = self._collect_performance_metrics()
-
+
# Update performance windows
self._update_performance_windows(current_metrics)
-
+
# Analyze performance issues
issues = self._analyze_performance_issues(current_metrics)
-
+
# Generate optimization actions
actions = self._generate_optimization_actions(issues, current_metrics)
-
+
# Execute high-priority actions
executed_actions = []
for action in sorted(actions, key=lambda x: x.priority, reverse=True):
if await self._execute_optimization_action(action):
executed_actions.append(action)
self.optimization_history.append(action)
-
+
# Log optimization results
self._log_optimization_results(executed_actions, current_metrics)
-
+
self.last_optimization = time.time()
-
+
return executed_actions
-
+
except Exception as e:
logger.error("Performance optimization failed", error=str(e))
return []
-
+
def _collect_performance_metrics(self) -> Dict[str, Any]:
"""Collect comprehensive performance metrics."""
try:
# Cache metrics
cache_metrics = self.cache_manager.get_metrics()
cache_perf_stats = self.cache_manager.get_performance_stats()
-
+
# Cache health
- health_metrics = self.metrics_collector.get_cache_health_score(self.cache_manager)
-
+ health_metrics = self.metrics_collector.get_cache_health_score(
+ self.cache_manager
+ )
+
# Optimization metrics
- opt_metrics = self.metrics_collector.track_optimization_metrics(self.cache_manager)
-
+ opt_metrics = self.metrics_collector.track_optimization_metrics(
+ self.cache_manager
+ )
+
# System metrics
current_time = time.time()
- memory_utilization = (cache_metrics.total_size / self.cache_manager.max_size_bytes * 100)
-
+ memory_utilization = (
+ cache_metrics.total_size / self.cache_manager.max_size_bytes * 100
+ )
+
return {
- 'timestamp': current_time,
- 'cache_metrics': cache_metrics,
- 'performance_stats': cache_perf_stats,
- 'health_metrics': health_metrics,
- 'optimization_metrics': opt_metrics,
- 'memory_utilization': memory_utilization,
- 'hit_rate_percent': cache_metrics.hit_rate,
- 'avg_hit_latency_ms': cache_perf_stats.get('avg_hit_latency_ms', 0),
- 'avg_miss_latency_ms': cache_perf_stats.get('avg_miss_latency_ms', 0)
+ "timestamp": current_time,
+ "cache_metrics": cache_metrics,
+ "performance_stats": cache_perf_stats,
+ "health_metrics": health_metrics,
+ "optimization_metrics": opt_metrics,
+ "memory_utilization": memory_utilization,
+ "hit_rate_percent": cache_metrics.hit_rate,
+ "avg_hit_latency_ms": cache_perf_stats.get("avg_hit_latency_ms", 0),
+ "avg_miss_latency_ms": cache_perf_stats.get("avg_miss_latency_ms", 0),
}
-
+
except Exception as e:
logger.error("Failed to collect performance metrics", error=str(e))
- return {'timestamp': time.time(), 'error': str(e)}
-
+ return {"timestamp": time.time(), "error": str(e)}
+
def _update_performance_windows(self, metrics: Dict[str, Any]):
"""Update rolling performance windows."""
try:
# Add to all windows
for window_name, window in self.performance_windows.items():
window.append(metrics)
-
+
# Trim windows to size
- if len(self.performance_windows['short']) > 10:
- self.performance_windows['short'].pop(0)
- if len(self.performance_windows['medium']) > 50:
- self.performance_windows['medium'].pop(0)
- if len(self.performance_windows['long']) > 200:
- self.performance_windows['long'].pop(0)
-
+ if len(self.performance_windows["short"]) > 10:
+ self.performance_windows["short"].pop(0)
+ if len(self.performance_windows["medium"]) > 50:
+ self.performance_windows["medium"].pop(0)
+ if len(self.performance_windows["long"]) > 200:
+ self.performance_windows["long"].pop(0)
+
except Exception as e:
logger.debug("Failed to update performance windows", error=str(e))
-
+
def _analyze_performance_issues(self, current_metrics: Dict[str, Any]) -> List[str]:
"""Analyze current performance and identify issues."""
issues = []
-
+
try:
# Cache hit rate analysis
- hit_rate = current_metrics.get('hit_rate_percent', 0)
+ hit_rate = current_metrics.get("hit_rate_percent", 0)
if hit_rate < self.targets.cache_hit_rate_percent:
- severity = "critical" if hit_rate < self.thresholds['hit_rate_critical'] else "moderate"
+ severity = (
+ "critical"
+ if hit_rate < self.thresholds["hit_rate_critical"]
+ else "moderate"
+ )
issues.append(f"low_hit_rate_{severity}")
-
+
# Latency analysis
- hit_latency = current_metrics.get('avg_hit_latency_ms', 0)
- miss_latency = current_metrics.get('avg_miss_latency_ms', 0)
-
+ hit_latency = current_metrics.get("avg_hit_latency_ms", 0)
+ miss_latency = current_metrics.get("avg_miss_latency_ms", 0)
+
if hit_latency > self.targets.cache_hit_latency_ms:
issues.append("high_hit_latency")
if miss_latency > self.targets.cache_miss_latency_ms:
issues.append("high_miss_latency")
-
+
# Memory utilization
- memory_util = current_metrics.get('memory_utilization', 0)
+ memory_util = current_metrics.get("memory_utilization", 0)
if memory_util > self.targets.memory_utilization_max:
- severity = "critical" if memory_util > self.thresholds['memory_critical'] else "moderate"
+ severity = (
+ "critical"
+ if memory_util > self.thresholds["memory_critical"]
+ else "moderate"
+ )
issues.append(f"high_memory_usage_{severity}")
-
+
# Cost efficiency
- opt_metrics = current_metrics.get('optimization_metrics', {})
- warming_efficiency = opt_metrics.get('cache_warming_efficiency', 0)
+ opt_metrics = current_metrics.get("optimization_metrics", {})
+ warming_efficiency = opt_metrics.get("cache_warming_efficiency", 0)
if warming_efficiency < 50:
issues.append("low_warming_efficiency")
-
+
# Trend analysis
- if len(self.performance_windows['medium']) >= 10:
- recent_hit_rates = [m.get('hit_rate_percent', 0) for m in self.performance_windows['medium'][-10:]]
+ if len(self.performance_windows["medium"]) >= 10:
+ recent_hit_rates = [
+ m.get("hit_rate_percent", 0)
+ for m in self.performance_windows["medium"][-10:]
+ ]
if len(recent_hit_rates) >= 5:
- trend = (recent_hit_rates[-1] - recent_hit_rates[0]) / len(recent_hit_rates)
+ trend = (recent_hit_rates[-1] - recent_hit_rates[0]) / len(
+ recent_hit_rates
+ )
if trend < -2: # Declining hit rate
issues.append("declining_hit_rate")
-
+
except Exception as e:
logger.error("Performance analysis failed", error=str(e))
issues.append("analysis_error")
-
+
return issues
-
- def _generate_optimization_actions(self, issues: List[str], metrics: Dict[str, Any]) -> List[OptimizationAction]:
+
+ def _generate_optimization_actions(
+ self, issues: List[str], metrics: Dict[str, Any]
+ ) -> List[OptimizationAction]:
"""Generate optimization actions based on identified issues."""
actions = []
-
+
try:
# Cache warming optimizations
if "low_hit_rate_critical" in issues or "declining_hit_rate" in issues:
- actions.append(OptimizationAction(
- action_type="aggressive_cache_warming",
- parameters={"max_queries": 50, "concurrent": True},
- priority=9,
- estimated_impact=15.0
- ))
+ actions.append(
+ OptimizationAction(
+ action_type="aggressive_cache_warming",
+ parameters={"max_queries": 50, "concurrent": True},
+ priority=9,
+ estimated_impact=15.0,
+ )
+ )
elif "low_hit_rate_moderate" in issues:
- actions.append(OptimizationAction(
- action_type="moderate_cache_warming",
- parameters={"max_queries": 30, "concurrent": True},
- priority=7,
- estimated_impact=10.0
- ))
-
+ actions.append(
+ OptimizationAction(
+ action_type="moderate_cache_warming",
+ parameters={"max_queries": 30, "concurrent": True},
+ priority=7,
+ estimated_impact=10.0,
+ )
+ )
+
# Memory optimization
if "high_memory_usage_critical" in issues:
- actions.append(OptimizationAction(
- action_type="emergency_memory_cleanup",
- parameters={"target_utilization": 70.0},
- priority=10,
- estimated_impact=20.0
- ))
+ actions.append(
+ OptimizationAction(
+ action_type="emergency_memory_cleanup",
+ parameters={"target_utilization": 70.0},
+ priority=10,
+ estimated_impact=20.0,
+ )
+ )
elif "high_memory_usage_moderate" in issues:
- actions.append(OptimizationAction(
- action_type="preemptive_cleanup",
- parameters={"target_utilization": 80.0},
- priority=6,
- estimated_impact=8.0
- ))
-
+ actions.append(
+ OptimizationAction(
+ action_type="preemptive_cleanup",
+ parameters={"target_utilization": 80.0},
+ priority=6,
+ estimated_impact=8.0,
+ )
+ )
+
# Latency optimizations
if "high_hit_latency" in issues:
- actions.append(OptimizationAction(
- action_type="optimize_cache_access",
- parameters={"enable_fast_lookup": True},
- priority=8,
- estimated_impact=12.0
- ))
-
+ actions.append(
+ OptimizationAction(
+ action_type="optimize_cache_access",
+ parameters={"enable_fast_lookup": True},
+ priority=8,
+ estimated_impact=12.0,
+ )
+ )
+
# Warming efficiency improvements
if "low_warming_efficiency" in issues:
- actions.append(OptimizationAction(
- action_type="optimize_warming_strategy",
- parameters={"strategy": "intelligent_prioritization"},
- priority=5,
- estimated_impact=7.0
- ))
-
+ actions.append(
+ OptimizationAction(
+ action_type="optimize_warming_strategy",
+ parameters={"strategy": "intelligent_prioritization"},
+ priority=5,
+ estimated_impact=7.0,
+ )
+ )
+
# Strategy adjustments
- current_hit_rate = metrics.get('hit_rate_percent', 0)
- if current_hit_rate < 40 and self.strategy != OptimizationStrategy.AGGRESSIVE_WARMING:
- actions.append(OptimizationAction(
- action_type="change_strategy",
- parameters={"new_strategy": OptimizationStrategy.AGGRESSIVE_WARMING.value},
- priority=4,
- estimated_impact=5.0
- ))
-
+ current_hit_rate = metrics.get("hit_rate_percent", 0)
+ if (
+ current_hit_rate < 40
+ and self.strategy != OptimizationStrategy.AGGRESSIVE_WARMING
+ ):
+ actions.append(
+ OptimizationAction(
+ action_type="change_strategy",
+ parameters={
+ "new_strategy": OptimizationStrategy.AGGRESSIVE_WARMING.value
+ },
+ priority=4,
+ estimated_impact=5.0,
+ )
+ )
+
except Exception as e:
logger.error("Failed to generate optimization actions", error=str(e))
-
+
return actions
-
+
async def _execute_optimization_action(self, action: OptimizationAction) -> bool:
"""Execute a specific optimization action."""
try:
- logger.info("Executing optimization action",
- action_type=action.action_type,
- priority=action.priority,
- parameters=action.parameters)
-
+ logger.info(
+ "Executing optimization action",
+ action_type=action.action_type,
+ priority=action.priority,
+ parameters=action.parameters,
+ )
+
if action.action_type == "aggressive_cache_warming":
await self.cache_warmer.warm_cache_intelligently(
max_queries=action.parameters.get("max_queries", 50)
)
return True
-
+
elif action.action_type == "moderate_cache_warming":
await self.cache_warmer.warm_cache_intelligently(
max_queries=action.parameters.get("max_queries", 30)
)
return True
-
+
elif action.action_type == "emergency_memory_cleanup":
target_util = action.parameters.get("target_utilization", 70.0)
current_size = self.cache_manager._calculate_current_size()
target_size = int(self.cache_manager.max_size_bytes * target_util / 100)
-
+
if current_size > target_size:
space_to_free = current_size - target_size
self.cache_manager._intelligent_eviction(space_to_free)
return True
-
+
elif action.action_type == "preemptive_cleanup":
self.cache_manager._maybe_preemptive_cleanup()
return True
-
+
elif action.action_type == "optimize_cache_access":
# Enable performance optimizations
self.cache_manager.optimization_enabled = True
- self.cache_manager.fast_lookup_enabled = action.parameters.get("enable_fast_lookup", True)
+ self.cache_manager.fast_lookup_enabled = action.parameters.get(
+ "enable_fast_lookup", True
+ )
return True
-
+
elif action.action_type == "optimize_warming_strategy":
# Update warming strategy
- strategy = action.parameters.get("strategy", "intelligent_prioritization")
- self.cache_warmer.optimization_config['adaptive_priorities'] = True
- self.cache_warmer.optimization_config['concurrent_warming'] = True
+ strategy = action.parameters.get(
+ "strategy", "intelligent_prioritization"
+ )
+ self.cache_warmer.optimization_config["adaptive_priorities"] = True
+ self.cache_warmer.optimization_config["concurrent_warming"] = True
return True
-
+
elif action.action_type == "change_strategy":
new_strategy = action.parameters.get("new_strategy")
if new_strategy:
self.strategy = OptimizationStrategy(new_strategy)
logger.info("Strategy changed", new_strategy=new_strategy)
return True
-
+
else:
- logger.warning("Unknown optimization action", action_type=action.action_type)
+ logger.warning(
+ "Unknown optimization action", action_type=action.action_type
+ )
return False
-
+
except Exception as e:
- logger.error("Failed to execute optimization action",
- action_type=action.action_type,
- error=str(e))
+ logger.error(
+ "Failed to execute optimization action",
+ action_type=action.action_type,
+ error=str(e),
+ )
return False
-
- def _log_optimization_results(self, actions: List[OptimizationAction], metrics: Dict[str, Any]):
+
+ def _log_optimization_results(
+ self, actions: List[OptimizationAction], metrics: Dict[str, Any]
+ ):
"""Log optimization results and impact."""
try:
if actions:
action_types = [a.action_type for a in actions]
total_impact = sum(a.estimated_impact for a in actions)
-
- logger.info("Optimization cycle completed",
- actions_taken=len(actions),
- action_types=action_types,
- estimated_total_impact=total_impact,
- current_hit_rate=metrics.get('hit_rate_percent', 0),
- memory_utilization=metrics.get('memory_utilization', 0))
+
+ logger.info(
+ "Optimization cycle completed",
+ actions_taken=len(actions),
+ action_types=action_types,
+ estimated_total_impact=total_impact,
+ current_hit_rate=metrics.get("hit_rate_percent", 0),
+ memory_utilization=metrics.get("memory_utilization", 0),
+ )
else:
- logger.info("Optimization cycle completed - no actions needed",
- current_hit_rate=metrics.get('hit_rate_percent', 0),
- memory_utilization=metrics.get('memory_utilization', 0))
-
+ logger.info(
+ "Optimization cycle completed - no actions needed",
+ current_hit_rate=metrics.get("hit_rate_percent", 0),
+ memory_utilization=metrics.get("memory_utilization", 0),
+ )
+
except Exception as e:
logger.debug("Failed to log optimization results", error=str(e))
-
+
def get_optimization_status(self) -> Dict[str, Any]:
"""Get current optimization status and statistics."""
try:
- recent_actions = self.optimization_history[-10:] if self.optimization_history else []
-
+ recent_actions = (
+ self.optimization_history[-10:] if self.optimization_history else []
+ )
+
return {
- 'optimization_enabled': self.optimization_enabled,
- 'current_strategy': self.strategy.value,
- 'last_optimization': self.last_optimization,
- 'optimization_interval': self.optimization_interval,
- 'total_optimizations': len(self.optimization_history),
- 'recent_actions': [
+ "optimization_enabled": self.optimization_enabled,
+ "current_strategy": self.strategy.value,
+ "last_optimization": self.last_optimization,
+ "optimization_interval": self.optimization_interval,
+ "total_optimizations": len(self.optimization_history),
+ "recent_actions": [
{
- 'action_type': a.action_type,
- 'execution_time': a.execution_time,
- 'estimated_impact': a.estimated_impact
- } for a in recent_actions
+ "action_type": a.action_type,
+ "execution_time": a.execution_time,
+ "estimated_impact": a.estimated_impact,
+ }
+ for a in recent_actions
],
- 'performance_targets': {
- 'cache_hit_latency_ms': self.targets.cache_hit_latency_ms,
- 'cache_miss_latency_ms': self.targets.cache_miss_latency_ms,
- 'cache_hit_rate_percent': self.targets.cache_hit_rate_percent,
- 'cost_reduction_percent': self.targets.cost_reduction_percent
- }
+ "performance_targets": {
+ "cache_hit_latency_ms": self.targets.cache_hit_latency_ms,
+ "cache_miss_latency_ms": self.targets.cache_miss_latency_ms,
+ "cache_hit_rate_percent": self.targets.cache_hit_rate_percent,
+ "cost_reduction_percent": self.targets.cost_reduction_percent,
+ },
}
-
+
except Exception as e:
logger.error("Failed to get optimization status", error=str(e))
- return {'error': str(e)}
+ return {"error": str(e)}
# Global optimizer instance
_performance_optimizer: Optional[PerformanceOptimizer] = None
-def get_performance_optimizer(cache_manager: Optional[CacheManager] = None) -> PerformanceOptimizer:
+def get_performance_optimizer(
+ cache_manager: Optional[CacheManager] = None,
+) -> PerformanceOptimizer:
"""Get or create global performance optimizer instance."""
global _performance_optimizer
-
+
if _performance_optimizer is None:
_performance_optimizer = PerformanceOptimizer(cache_manager)
-
+
return _performance_optimizer
@@ -505,4 +574,4 @@ async def stop_performance_optimization():
"""Stop performance optimization."""
optimizer = get_performance_optimizer()
await optimizer.stop_optimization_loop()
- logger.info("Performance optimization stopped")
\ No newline at end of file
+ logger.info("Performance optimization stopped")
diff --git a/src/security/__init__.py b/src/security/__init__.py
index 1427c7b..b67e033 100644
--- a/src/security/__init__.py
+++ b/src/security/__init__.py
@@ -12,8 +12,4 @@
# Fallback to absolute imports when called from scripts
from security.auth import AuthorizationManager, Authorization, AuthFlow
-__all__ = [
- 'AuthorizationManager',
- 'Authorization',
- 'AuthFlow'
-]
\ No newline at end of file
+__all__ = ["AuthorizationManager", "Authorization", "AuthFlow"]
diff --git a/src/security/auth.py b/src/security/auth.py
index 08cf06c..0b53356 100644
--- a/src/security/auth.py
+++ b/src/security/auth.py
@@ -13,25 +13,18 @@
try:
# Try relative imports first (when used as package)
- from ..core.errors import (
- AuthenticationError,
- UnauthorizedError,
- ValidationError
- )
+ from ..core.errors import AuthenticationError, UnauthorizedError, ValidationError
except ImportError:
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
- from core.errors import (
- AuthenticationError,
- UnauthorizedError,
- ValidationError
- )
+
+ from core.errors import AuthenticationError, UnauthorizedError, ValidationError
logger = structlog.get_logger(__name__)
@@ -40,6 +33,7 @@
@dataclass
class Authorization:
"""Represents an authorization grant for a user and tool."""
+
user_id: str
tool_name: str
scopes: List[str]
@@ -47,21 +41,21 @@ class Authorization:
refresh_token: Optional[str] = None
expires_at: Optional[float] = None
created_at: float = None
-
+
def __post_init__(self):
if self.created_at is None:
self.created_at = time.time()
-
+
def is_expired(self) -> bool:
"""Check if authorization is expired."""
if self.expires_at is None:
return False
return time.time() > self.expires_at
-
+
def has_scope(self, scope: str) -> bool:
"""Check if authorization includes a specific scope."""
return scope in self.scopes
-
+
def has_all_scopes(self, required_scopes: List[str]) -> bool:
"""Check if authorization includes all required scopes."""
return all(scope in self.scopes for scope in required_scopes)
@@ -70,6 +64,7 @@ def has_all_scopes(self, required_scopes: List[str]) -> bool:
@dataclass
class AuthFlow:
"""Represents an ongoing OAuth authorization flow."""
+
flow_id: str
user_id: str
tool_name: str
@@ -78,13 +73,13 @@ class AuthFlow:
status: str = "pending"
created_at: float = None
expires_at: float = None
-
+
def __post_init__(self):
if self.created_at is None:
self.created_at = time.time()
if self.expires_at is None:
self.expires_at = self.created_at + 600 # 10 minutes
-
+
def is_expired(self) -> bool:
"""Check if authorization flow is expired."""
return time.time() > self.expires_at
@@ -93,18 +88,20 @@ def is_expired(self) -> bool:
class AuthorizationManager:
"""
Manages user authorization for tool access.
-
+
Provides OAuth integration, scope validation, and session management
for secure tool execution.
"""
-
- def __init__(self,
- oauth_client_id: Optional[str] = None,
- oauth_client_secret: Optional[str] = None,
- oauth_redirect_uri: Optional[str] = None):
+
+ def __init__(
+ self,
+ oauth_client_id: Optional[str] = None,
+ oauth_client_secret: Optional[str] = None,
+ oauth_redirect_uri: Optional[str] = None,
+ ):
"""
Initialize authorization manager.
-
+
Args:
oauth_client_id: OAuth client ID
oauth_client_secret: OAuth client secret
@@ -113,11 +110,11 @@ def __init__(self,
self.oauth_client_id = oauth_client_id
self.oauth_client_secret = oauth_client_secret
self.oauth_redirect_uri = oauth_redirect_uri
-
+
# In-memory storage for demo (use persistent storage in production)
self.active_authorizations: Dict[str, Authorization] = {}
self.pending_flows: Dict[str, AuthFlow] = {}
-
+
# Scope definitions
self.available_scopes = {
"read": "Read access to data and tools",
@@ -125,109 +122,130 @@ def __init__(self,
"admin": "Administrative access to system functions",
"sql": "Execute SQL queries",
"file": "Access file system operations",
- "web": "Make external web requests"
+ "web": "Make external web requests",
}
-
- logger.info("AuthorizationManager initialized",
- oauth_enabled=bool(oauth_client_id))
-
- async def validate_authorization(self,
- user_id: str,
- tool_name: str,
- required_scopes: Optional[List[str]] = None) -> Authorization:
+
+ logger.info(
+ "AuthorizationManager initialized", oauth_enabled=bool(oauth_client_id)
+ )
+
+ async def validate_authorization(
+ self, user_id: str, tool_name: str, required_scopes: Optional[List[str]] = None
+ ) -> Authorization:
"""
Validate user authorization for tool access.
-
+
Args:
user_id: User identifier
tool_name: Tool name requiring authorization
required_scopes: Optional list of required scopes
-
+
Returns:
Valid Authorization object
-
+
Raises:
UnauthorizedError: If authorization is invalid or insufficient
"""
try:
auth_key = f"{user_id}:{tool_name}"
-
+
# Check if authorization exists
if auth_key not in self.active_authorizations:
- logger.warning("No authorization found",
- user_id=user_id,
- tool_name=tool_name)
- raise UnauthorizedError(f"No authorization found for user {user_id} and tool {tool_name}")
-
+ logger.warning(
+ "No authorization found", user_id=user_id, tool_name=tool_name
+ )
+ raise UnauthorizedError(
+ f"No authorization found for user {user_id} and tool {tool_name}"
+ )
+
authorization = self.active_authorizations[auth_key]
-
+
# Check if authorization is expired
if authorization.is_expired():
- logger.warning("Authorization expired",
- user_id=user_id,
- tool_name=tool_name,
- expires_at=authorization.expires_at)
-
+ logger.warning(
+ "Authorization expired",
+ user_id=user_id,
+ tool_name=tool_name,
+ expires_at=authorization.expires_at,
+ )
+
# Try to refresh token if available
if authorization.refresh_token:
try:
- refreshed_auth = await self._refresh_authorization(authorization)
+ refreshed_auth = await self._refresh_authorization(
+ authorization
+ )
self.active_authorizations[auth_key] = refreshed_auth
authorization = refreshed_auth
except Exception as e:
- logger.error("Failed to refresh authorization",
- user_id=user_id,
- tool_name=tool_name,
- error=str(e))
+ logger.error(
+ "Failed to refresh authorization",
+ user_id=user_id,
+ tool_name=tool_name,
+ error=str(e),
+ )
del self.active_authorizations[auth_key]
- raise UnauthorizedError("Authorization expired and refresh failed")
+ raise UnauthorizedError(
+ "Authorization expired and refresh failed"
+ )
else:
del self.active_authorizations[auth_key]
raise UnauthorizedError("Authorization expired")
-
+
# Check required scopes
if required_scopes:
if not authorization.has_all_scopes(required_scopes):
- missing_scopes = [scope for scope in required_scopes if not authorization.has_scope(scope)]
- logger.warning("Insufficient scopes",
- user_id=user_id,
- tool_name=tool_name,
- required_scopes=required_scopes,
- user_scopes=authorization.scopes,
- missing_scopes=missing_scopes)
- raise UnauthorizedError(f"Insufficient permissions. Missing scopes: {missing_scopes}")
-
- logger.debug("Authorization validated successfully",
+ missing_scopes = [
+ scope
+ for scope in required_scopes
+ if not authorization.has_scope(scope)
+ ]
+ logger.warning(
+ "Insufficient scopes",
user_id=user_id,
tool_name=tool_name,
- scopes=authorization.scopes)
-
+ required_scopes=required_scopes,
+ user_scopes=authorization.scopes,
+ missing_scopes=missing_scopes,
+ )
+ raise UnauthorizedError(
+ f"Insufficient permissions. Missing scopes: {missing_scopes}"
+ )
+
+ logger.debug(
+ "Authorization validated successfully",
+ user_id=user_id,
+ tool_name=tool_name,
+ scopes=authorization.scopes,
+ )
+
return authorization
-
+
except UnauthorizedError:
raise
except Exception as e:
- logger.error("Authorization validation failed",
- user_id=user_id,
- tool_name=tool_name,
- error=str(e))
+ logger.error(
+ "Authorization validation failed",
+ user_id=user_id,
+ tool_name=tool_name,
+ error=str(e),
+ )
raise UnauthorizedError(f"Authorization validation failed: {str(e)}")
-
- def initiate_authorization(self,
- user_id: str,
- tool_name: str,
- scopes: List[str]) -> AuthFlow:
+
+ def initiate_authorization(
+ self, user_id: str, tool_name: str, scopes: List[str]
+ ) -> AuthFlow:
"""
Initiate OAuth authorization flow.
-
+
Args:
user_id: User identifier
tool_name: Tool name requiring authorization
scopes: List of requested scopes
-
+
Returns:
AuthFlow instance with authorization URL
-
+
Raises:
ValidationError: If parameters are invalid
"""
@@ -235,64 +253,71 @@ def initiate_authorization(self,
# Validate inputs
if not user_id or not tool_name:
raise ValidationError("User ID and tool name are required")
-
+
if not scopes:
raise ValidationError("At least one scope is required")
-
+
# Validate scopes
- invalid_scopes = [scope for scope in scopes if scope not in self.available_scopes]
+ invalid_scopes = [
+ scope for scope in scopes if scope not in self.available_scopes
+ ]
if invalid_scopes:
raise ValidationError(f"Invalid scopes: {invalid_scopes}")
-
+
# Generate flow ID
import uuid
+
flow_id = str(uuid.uuid4())
-
+
# Create authorization URL
auth_url = self._create_auth_url(flow_id, scopes)
-
+
# Create auth flow
auth_flow = AuthFlow(
flow_id=flow_id,
user_id=user_id,
tool_name=tool_name,
scopes=scopes,
- auth_url=auth_url
+ auth_url=auth_url,
)
-
+
# Store pending flow
self.pending_flows[flow_id] = auth_flow
-
- logger.info("Authorization flow initiated",
- flow_id=flow_id,
- user_id=user_id,
- tool_name=tool_name,
- scopes=scopes)
-
+
+ logger.info(
+ "Authorization flow initiated",
+ flow_id=flow_id,
+ user_id=user_id,
+ tool_name=tool_name,
+ scopes=scopes,
+ )
+
return auth_flow
-
+
except ValidationError:
raise
except Exception as e:
- logger.error("Failed to initiate authorization",
- user_id=user_id,
- tool_name=tool_name,
- error=str(e))
+ logger.error(
+ "Failed to initiate authorization",
+ user_id=user_id,
+ tool_name=tool_name,
+ error=str(e),
+ )
raise ValidationError(f"Failed to initiate authorization: {str(e)}")
-
- async def complete_authorization(self,
- flow_id: str,
- authorization_code: Optional[str] = None) -> Authorization:
+
+ async def complete_authorization(
+ self, flow_id: str, authorization_code: Optional[str] = None
+ ) -> Authorization:
"""
Complete OAuth authorization flow.
-
+
Args:
flow_id: Authorization flow identifier
authorization_code: OAuth authorization code
-
+
Returns:
Authorization object
-
+
Raises:
UnauthorizedError: If flow is invalid or expired
"""
@@ -300,14 +325,14 @@ async def complete_authorization(self,
# Get pending flow
if flow_id not in self.pending_flows:
raise UnauthorizedError("Invalid or expired authorization flow")
-
+
auth_flow = self.pending_flows[flow_id]
-
+
# Check if flow is expired
if auth_flow.is_expired():
del self.pending_flows[flow_id]
raise UnauthorizedError("Authorization flow expired")
-
+
# For demo purposes, create authorization without actual OAuth
# In production, exchange authorization_code for access token
authorization = Authorization(
@@ -315,123 +340,127 @@ async def complete_authorization(self,
tool_name=auth_flow.tool_name,
scopes=auth_flow.scopes,
access_token=f"demo_token_{flow_id}",
- expires_at=time.time() + 3600 # 1 hour
+ expires_at=time.time() + 3600, # 1 hour
)
-
+
# Store authorization
auth_key = f"{auth_flow.user_id}:{auth_flow.tool_name}"
self.active_authorizations[auth_key] = authorization
-
+
# Clean up pending flow
del self.pending_flows[flow_id]
-
- logger.info("Authorization completed successfully",
- flow_id=flow_id,
- user_id=auth_flow.user_id,
- tool_name=auth_flow.tool_name)
-
+
+ logger.info(
+ "Authorization completed successfully",
+ flow_id=flow_id,
+ user_id=auth_flow.user_id,
+ tool_name=auth_flow.tool_name,
+ )
+
return authorization
-
+
except UnauthorizedError:
raise
except Exception as e:
- logger.error("Failed to complete authorization",
- flow_id=flow_id,
- error=str(e))
+ logger.error(
+ "Failed to complete authorization", flow_id=flow_id, error=str(e)
+ )
raise UnauthorizedError(f"Failed to complete authorization: {str(e)}")
-
+
def revoke_authorization(self, user_id: str, tool_name: str) -> bool:
"""
Revoke user authorization for a tool.
-
+
Args:
user_id: User identifier
tool_name: Tool name
-
+
Returns:
True if authorization was revoked
"""
auth_key = f"{user_id}:{tool_name}"
-
+
if auth_key in self.active_authorizations:
del self.active_authorizations[auth_key]
- logger.info("Authorization revoked",
- user_id=user_id,
- tool_name=tool_name)
+ logger.info("Authorization revoked", user_id=user_id, tool_name=tool_name)
return True
-
+
return False
-
+
def get_user_authorizations(self, user_id: str) -> List[Authorization]:
"""
Get all authorizations for a user.
-
+
Args:
user_id: User identifier
-
+
Returns:
List of user authorizations
"""
user_auths = []
-
+
for auth_key, authorization in self.active_authorizations.items():
if authorization.user_id == user_id:
user_auths.append(authorization)
-
+
return user_auths
-
+
def cleanup_expired_authorizations(self) -> int:
"""
Clean up expired authorizations and flows.
-
+
Returns:
Number of items cleaned up
"""
cleaned_count = 0
current_time = time.time()
-
+
# Clean up expired authorizations
expired_auths = []
for auth_key, authorization in self.active_authorizations.items():
if authorization.is_expired():
expired_auths.append(auth_key)
-
+
for auth_key in expired_auths:
del self.active_authorizations[auth_key]
cleaned_count += 1
-
+
# Clean up expired flows
expired_flows = []
for flow_id, auth_flow in self.pending_flows.items():
if auth_flow.is_expired():
expired_flows.append(flow_id)
-
+
for flow_id in expired_flows:
del self.pending_flows[flow_id]
cleaned_count += 1
-
+
if cleaned_count > 0:
- logger.info("Cleaned up expired items",
- count=cleaned_count,
- authorizations=len(expired_auths),
- flows=len(expired_flows))
-
+ logger.info(
+ "Cleaned up expired items",
+ count=cleaned_count,
+ authorizations=len(expired_auths),
+ flows=len(expired_flows),
+ )
+
return cleaned_count
-
+
def get_available_scopes(self) -> Dict[str, str]:
"""Get available authorization scopes."""
return self.available_scopes.copy()
-
- async def _refresh_authorization(self, authorization: Authorization) -> Authorization:
+
+ async def _refresh_authorization(
+ self, authorization: Authorization
+ ) -> Authorization:
"""
Refresh an expired authorization using refresh token.
-
+
Args:
authorization: Authorization to refresh
-
+
Returns:
Refreshed authorization
-
+
Raises:
UnauthorizedError: If refresh fails
"""
@@ -445,28 +474,32 @@ async def _refresh_authorization(self, authorization: Authorization) -> Authoriz
access_token=f"refreshed_{authorization.access_token}",
refresh_token=authorization.refresh_token,
expires_at=time.time() + 3600, # 1 hour
- created_at=authorization.created_at
+ created_at=authorization.created_at,
+ )
+
+ logger.info(
+ "Authorization refreshed",
+ user_id=authorization.user_id,
+ tool_name=authorization.tool_name,
)
-
- logger.info("Authorization refreshed",
- user_id=authorization.user_id,
- tool_name=authorization.tool_name)
-
+
return refreshed_auth
-
+
except Exception as e:
- logger.error("Failed to refresh authorization",
- user_id=authorization.user_id,
- tool_name=authorization.tool_name,
- error=str(e))
+ logger.error(
+ "Failed to refresh authorization",
+ user_id=authorization.user_id,
+ tool_name=authorization.tool_name,
+ error=str(e),
+ )
raise UnauthorizedError(f"Failed to refresh authorization: {str(e)}")
-
+
def _create_auth_url(self, flow_id: str, scopes: List[str]) -> str:
"""Create OAuth authorization URL."""
if not self.oauth_client_id or not self.oauth_redirect_uri:
# Return demo URL if OAuth not configured
return f"https://demo-oauth.example.com/authorize?flow_id={flow_id}&scopes={','.join(scopes)}"
-
+
# In production, create actual OAuth URL
scope_string = " ".join(scopes)
auth_url = (
@@ -477,41 +510,44 @@ def _create_auth_url(self, flow_id: str, scopes: List[str]) -> str:
f"&state={flow_id}"
f"&response_type=code"
)
-
+
return auth_url
# Utility functions for authorization
-def create_authorization_manager(oauth_client_id: Optional[str] = None,
- oauth_client_secret: Optional[str] = None,
- oauth_redirect_uri: Optional[str] = None) -> AuthorizationManager:
+
+def create_authorization_manager(
+ oauth_client_id: Optional[str] = None,
+ oauth_client_secret: Optional[str] = None,
+ oauth_redirect_uri: Optional[str] = None,
+) -> AuthorizationManager:
"""
Create and configure an authorization manager.
-
+
Args:
oauth_client_id: OAuth client ID
oauth_client_secret: OAuth client secret
oauth_redirect_uri: OAuth redirect URI
-
+
Returns:
Configured AuthorizationManager instance
"""
return AuthorizationManager(
oauth_client_id=oauth_client_id,
oauth_client_secret=oauth_client_secret,
- oauth_redirect_uri=oauth_redirect_uri
+ oauth_redirect_uri=oauth_redirect_uri,
)
def extract_required_scopes(tool_definition: Dict[str, Any]) -> List[str]:
"""
Extract required scopes from tool definition.
-
+
Args:
tool_definition: Tool definition dictionary
-
+
Returns:
List of required scopes
"""
- return tool_definition.get("required_scopes", [])
\ No newline at end of file
+ return tool_definition.get("required_scopes", [])
diff --git a/src/security/cache_encryption.py b/src/security/cache_encryption.py
index c7079db..810491e 100644
--- a/src/security/cache_encryption.py
+++ b/src/security/cache_encryption.py
@@ -20,52 +20,53 @@
from core.errors import CacheError, SecurityError
-
logger = structlog.get_logger(__name__)
class CacheEncryption:
"""
Secure encryption for cache data with integrity protection.
-
+
Provides encryption, decryption, and integrity validation for cached
content to prevent data tampering and unauthorized access.
"""
-
+
def __init__(self, encryption_key: Optional[bytes] = None):
"""
Initialize cache encryption.
-
+
Args:
encryption_key: Optional encryption key (generated if not provided)
"""
self.encryption_key = encryption_key or self._generate_encryption_key()
self.cipher_suite = Fernet(self.encryption_key)
-
+
# HMAC key for integrity verification (derived from encryption key)
self.hmac_key = self._derive_hmac_key(self.encryption_key)
-
+
# Metadata for cache entries
self.version = "1.0"
self.algorithm = "Fernet"
-
- logger.info("Cache encryption initialized",
- algorithm=self.algorithm,
- version=self.version)
-
- def encrypt_cache_entry(self,
- content: str,
- metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+
+ logger.info(
+ "Cache encryption initialized",
+ algorithm=self.algorithm,
+ version=self.version,
+ )
+
+ def encrypt_cache_entry(
+ self, content: str, metadata: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
"""
Encrypt cache content with metadata and integrity protection.
-
+
Args:
content: Content to encrypt
metadata: Optional metadata to include
-
+
Returns:
Encrypted cache entry with metadata
-
+
Raises:
CacheError: If encryption fails
"""
@@ -73,181 +74,200 @@ def encrypt_cache_entry(self,
# Validate input
if not isinstance(content, str):
raise CacheError("Content must be a string")
-
+
if len(content) > 10 * 1024 * 1024: # 10MB limit
raise CacheError("Content too large for encryption")
-
+
# Prepare cache entry data
cache_data = {
"content": content,
"metadata": metadata or {},
"encrypted_at": time.time(),
- "version": self.version
+ "version": self.version,
}
-
+
# Serialize data
- data_bytes = json.dumps(cache_data).encode('utf-8')
-
+ data_bytes = json.dumps(cache_data).encode("utf-8")
+
# Encrypt data
encrypted_data = self.cipher_suite.encrypt(data_bytes)
-
+
# Create integrity hash
integrity_hash = self._create_integrity_hash(encrypted_data)
-
+
# Prepare final encrypted entry
encrypted_entry = {
- "encrypted_content": base64.urlsafe_b64encode(encrypted_data).decode('utf-8'),
+ "encrypted_content": base64.urlsafe_b64encode(encrypted_data).decode(
+ "utf-8"
+ ),
"integrity_hash": integrity_hash,
"algorithm": self.algorithm,
"version": self.version,
- "created_at": time.time()
+ "created_at": time.time(),
}
-
- logger.debug("Cache entry encrypted successfully",
- content_length=len(content),
- metadata_keys=list(metadata.keys()) if metadata else [])
-
+
+ logger.debug(
+ "Cache entry encrypted successfully",
+ content_length=len(content),
+ metadata_keys=list(metadata.keys()) if metadata else [],
+ )
+
return encrypted_entry
-
+
except Exception as e:
logger.error("Cache encryption failed", error=str(e))
raise CacheError(f"Failed to encrypt cache entry: {str(e)}")
-
+
def decrypt_cache_entry(self, encrypted_entry: Dict[str, Any]) -> Dict[str, Any]:
"""
Decrypt and validate cache entry.
-
+
Args:
encrypted_entry: Encrypted cache entry
-
+
Returns:
Decrypted cache data with content and metadata
-
+
Raises:
CacheError: If decryption or validation fails
"""
try:
# Validate encrypted entry structure
- required_fields = ["encrypted_content", "integrity_hash", "algorithm", "version"]
+ required_fields = [
+ "encrypted_content",
+ "integrity_hash",
+ "algorithm",
+ "version",
+ ]
for field in required_fields:
if field not in encrypted_entry:
raise CacheError(f"Missing required field: {field}")
-
+
# Check version compatibility
if encrypted_entry["version"] != self.version:
- raise CacheError(f"Incompatible cache entry version: {encrypted_entry['version']}")
-
+ raise CacheError(
+ f"Incompatible cache entry version: {encrypted_entry['version']}"
+ )
+
# Check algorithm
if encrypted_entry["algorithm"] != self.algorithm:
- raise CacheError(f"Unsupported encryption algorithm: {encrypted_entry['algorithm']}")
-
+ raise CacheError(
+ f"Unsupported encryption algorithm: {encrypted_entry['algorithm']}"
+ )
+
# Decode encrypted content
try:
encrypted_data = base64.urlsafe_b64decode(
- encrypted_entry["encrypted_content"].encode('utf-8')
+ encrypted_entry["encrypted_content"].encode("utf-8")
)
except Exception:
raise CacheError("Invalid encrypted content encoding")
-
+
# Verify integrity
expected_hash = encrypted_entry["integrity_hash"]
actual_hash = self._create_integrity_hash(encrypted_data)
-
+
if not hmac.compare_digest(expected_hash, actual_hash):
- raise CacheError("Cache entry integrity check failed - possible tampering")
-
+ raise CacheError(
+ "Cache entry integrity check failed - possible tampering"
+ )
+
# Decrypt data
try:
decrypted_data = self.cipher_suite.decrypt(encrypted_data)
except Exception:
- raise CacheError("Failed to decrypt cache entry - invalid key or corrupted data")
-
+ raise CacheError(
+ "Failed to decrypt cache entry - invalid key or corrupted data"
+ )
+
# Parse decrypted content
try:
- cache_data = json.loads(decrypted_data.decode('utf-8'))
+ cache_data = json.loads(decrypted_data.decode("utf-8"))
except Exception:
raise CacheError("Failed to parse decrypted cache data")
-
+
# Validate decrypted data structure
if not isinstance(cache_data, dict) or "content" not in cache_data:
raise CacheError("Invalid decrypted cache data structure")
-
- logger.debug("Cache entry decrypted successfully",
- content_length=len(cache_data.get("content", "")),
- has_metadata=bool(cache_data.get("metadata")))
-
+
+ logger.debug(
+ "Cache entry decrypted successfully",
+ content_length=len(cache_data.get("content", "")),
+ has_metadata=bool(cache_data.get("metadata")),
+ )
+
return cache_data
-
+
except CacheError:
raise
except Exception as e:
logger.error("Cache decryption failed", error=str(e))
raise CacheError(f"Failed to decrypt cache entry: {str(e)}")
-
- def encrypt_sensitive_fields(self,
- data: Dict[str, Any],
- sensitive_fields: set) -> Dict[str, Any]:
+
+ def encrypt_sensitive_fields(
+ self, data: Dict[str, Any], sensitive_fields: set
+ ) -> Dict[str, Any]:
"""
Encrypt specific sensitive fields in a dictionary.
-
+
Args:
data: Dictionary containing data
sensitive_fields: Set of field names to encrypt
-
+
Returns:
Dictionary with sensitive fields encrypted
"""
if not isinstance(data, dict):
raise CacheError("Data must be a dictionary")
-
+
encrypted_data = data.copy()
-
+
for field_name in sensitive_fields:
if field_name in data:
field_value = data[field_name]
-
+
if isinstance(field_value, str):
# Encrypt string field
encrypted_field = self._encrypt_field(field_value)
encrypted_data[f"_encrypted_{field_name}"] = encrypted_field
del encrypted_data[field_name]
-
+
elif isinstance(field_value, dict):
# Recursively encrypt nested dictionaries
encrypted_field = self._encrypt_field(json.dumps(field_value))
encrypted_data[f"_encrypted_{field_name}"] = encrypted_field
del encrypted_data[field_name]
-
+
return encrypted_data
-
- def decrypt_sensitive_fields(self,
- data: Dict[str, Any],
- sensitive_fields: set) -> Dict[str, Any]:
+
+ def decrypt_sensitive_fields(
+ self, data: Dict[str, Any], sensitive_fields: set
+ ) -> Dict[str, Any]:
"""
Decrypt sensitive fields in a dictionary.
-
+
Args:
data: Dictionary with encrypted fields
sensitive_fields: Set of original field names
-
+
Returns:
Dictionary with sensitive fields decrypted
"""
if not isinstance(data, dict):
raise CacheError("Data must be a dictionary")
-
+
decrypted_data = data.copy()
-
+
for field_name in sensitive_fields:
encrypted_field_name = f"_encrypted_{field_name}"
-
+
if encrypted_field_name in data:
encrypted_field = data[encrypted_field_name]
-
+
try:
# Decrypt field
decrypted_value = self._decrypt_field(encrypted_field)
-
+
# Try to parse as JSON for complex types
try:
parsed_value = json.loads(decrypted_value)
@@ -255,68 +275,77 @@ def decrypt_sensitive_fields(self,
except json.JSONDecodeError:
# Keep as string if not valid JSON
decrypted_data[field_name] = decrypted_value
-
+
# Remove encrypted field
del decrypted_data[encrypted_field_name]
-
+
except Exception as e:
- logger.warning("Failed to decrypt sensitive field",
- field_name=field_name,
- error=str(e))
+ logger.warning(
+ "Failed to decrypt sensitive field",
+ field_name=field_name,
+ error=str(e),
+ )
# Keep encrypted field if decryption fails
-
+
return decrypted_data
-
+
def is_cache_entry_valid(self, encrypted_entry: Dict[str, Any]) -> bool:
"""
Check if a cache entry is valid without decrypting it.
-
+
Args:
encrypted_entry: Encrypted cache entry to validate
-
+
Returns:
True if entry appears valid
"""
try:
# Check required fields
- required_fields = ["encrypted_content", "integrity_hash", "algorithm", "version"]
+ required_fields = [
+ "encrypted_content",
+ "integrity_hash",
+ "algorithm",
+ "version",
+ ]
for field in required_fields:
if field not in encrypted_entry:
return False
-
+
# Check version and algorithm
- if (encrypted_entry["version"] != self.version or
- encrypted_entry["algorithm"] != self.algorithm):
+ if (
+ encrypted_entry["version"] != self.version
+ or encrypted_entry["algorithm"] != self.algorithm
+ ):
return False
-
+
# Validate base64 encoding
try:
base64.urlsafe_b64decode(encrypted_entry["encrypted_content"])
except Exception:
return False
-
+
return True
-
+
except Exception:
return False
-
+
def _generate_encryption_key(self) -> bytes:
"""Generate a secure encryption key."""
# In production, this should be loaded from secure key management
password = os.urandom(32)
salt = os.urandom(16)
-
+
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
- backend=default_backend()
+ backend=default_backend(),
)
-
+
key = base64.urlsafe_b64encode(kdf.derive(password))
return key
-
+
def _derive_hmac_key(self, encryption_key: bytes) -> bytes:
"""Derive HMAC key from encryption key."""
# Use PBKDF2 to derive a separate HMAC key
@@ -325,40 +354,36 @@ def _derive_hmac_key(self, encryption_key: bytes) -> bytes:
length=32,
salt=b"hmac_derivation_salt",
iterations=10000,
- backend=default_backend()
+ backend=default_backend(),
)
-
+
return kdf.derive(encryption_key)
-
+
def _create_integrity_hash(self, data: bytes) -> str:
"""Create HMAC hash for integrity verification."""
- signature = hmac.new(
- self.hmac_key,
- data,
- hashlib.sha256
- ).hexdigest()
-
+ signature = hmac.new(self.hmac_key, data, hashlib.sha256).hexdigest()
+
return signature
-
+
def _encrypt_field(self, value: str) -> str:
"""Encrypt a single field value."""
- encrypted_data = self.cipher_suite.encrypt(value.encode('utf-8'))
- return base64.urlsafe_b64encode(encrypted_data).decode('utf-8')
-
+ encrypted_data = self.cipher_suite.encrypt(value.encode("utf-8"))
+ return base64.urlsafe_b64encode(encrypted_data).decode("utf-8")
+
def _decrypt_field(self, encrypted_value: str) -> str:
"""Decrypt a single field value."""
- encrypted_data = base64.urlsafe_b64decode(encrypted_value.encode('utf-8'))
+ encrypted_data = base64.urlsafe_b64decode(encrypted_value.encode("utf-8"))
decrypted_data = self.cipher_suite.decrypt(encrypted_data)
- return decrypted_data.decode('utf-8')
+ return decrypted_data.decode("utf-8")
def create_cache_encryption(encryption_key: Optional[bytes] = None) -> CacheEncryption:
"""
Create a configured cache encryption instance.
-
+
Args:
encryption_key: Optional encryption key
-
+
Returns:
CacheEncryption instance
"""
@@ -372,22 +397,23 @@ def create_cache_encryption(encryption_key: Optional[bytes] = None) -> CacheEncr
def get_cache_encryption() -> CacheEncryption:
"""Get global cache encryption instance."""
global _cache_encryption_instance
-
+
if _cache_encryption_instance is None:
_cache_encryption_instance = create_cache_encryption()
-
+
return _cache_encryption_instance
-def encrypt_cache_data(content: str,
- metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+def encrypt_cache_data(
+ content: str, metadata: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
"""
Convenience function to encrypt cache data.
-
+
Args:
content: Content to encrypt
metadata: Optional metadata
-
+
Returns:
Encrypted cache entry
"""
@@ -398,12 +424,12 @@ def encrypt_cache_data(content: str,
def decrypt_cache_data(encrypted_entry: Dict[str, Any]) -> Dict[str, Any]:
"""
Convenience function to decrypt cache data.
-
+
Args:
encrypted_entry: Encrypted cache entry
-
+
Returns:
Decrypted cache data
"""
cache_encryption = get_cache_encryption()
- return cache_encryption.decrypt_cache_entry(encrypted_entry)
\ No newline at end of file
+ return cache_encryption.decrypt_cache_entry(encrypted_entry)
diff --git a/src/security/config.py b/src/security/config.py
index b31613b..0bba6f4 100644
--- a/src/security/config.py
+++ b/src/security/config.py
@@ -14,29 +14,28 @@
from core.errors import ConfigurationError, SecurityError
-
logger = structlog.get_logger(__name__)
@dataclass
class AuthConfig:
"""Authentication and authorization configuration."""
-
+
# Token settings
token_lifetime_seconds: int = 3600 # 1 hour
refresh_token_lifetime_seconds: int = 86400 # 24 hours
token_entropy_bytes: int = 32 # 256 bits
max_token_per_user: int = 10
-
+
# OAuth settings
oauth_client_id: Optional[str] = None
oauth_client_secret: Optional[str] = None
oauth_redirect_uri: Optional[str] = None
-
+
# Session settings
session_timeout_seconds: int = 1800 # 30 minutes
max_sessions_per_user: int = 5
-
+
# Security settings
require_https: bool = True
enforce_csrf_protection: bool = True
@@ -46,24 +45,24 @@ class AuthConfig:
@dataclass
class ValidationConfig:
"""Input validation and sanitization configuration."""
-
+
# String limits
max_string_length: int = 10000
max_url_length: int = 2000
max_email_length: int = 254
max_filename_length: int = 255
max_path_length: int = 4096
-
+
# List and object limits
max_list_items: int = 1000
max_object_properties: int = 100
max_nesting_depth: int = 10
-
+
# Content validation
allow_html: bool = False
allow_javascript: bool = False
allow_file_uploads: bool = False
-
+
# Encoding settings
default_encoding: str = "utf-8"
sanitize_unicode: bool = True
@@ -72,47 +71,54 @@ class ValidationConfig:
@dataclass
class EncryptionConfig:
"""Encryption configuration for data protection."""
-
+
# Algorithm settings
encryption_algorithm: str = "Fernet"
key_derivation_iterations: int = 100000
hmac_algorithm: str = "sha256"
-
+
# Key management
encryption_key: Optional[bytes] = None
key_rotation_interval_days: int = 90
-
+
# Cache encryption
encrypt_cache_content: bool = True
cache_encryption_key: Optional[bytes] = None
-
+
# Database encryption
encrypt_sensitive_fields: bool = True
- sensitive_field_names: List[str] = field(default_factory=lambda: [
- 'password', 'token', 'secret', 'key', 'auth', 'credential'
- ])
+ sensitive_field_names: List[str] = field(
+ default_factory=lambda: [
+ "password",
+ "token",
+ "secret",
+ "key",
+ "auth",
+ "credential",
+ ]
+ )
@dataclass
class RateLimitConfig:
"""Rate limiting configuration."""
-
+
# Global rate limits
global_requests_per_minute: int = 1000
global_requests_per_hour: int = 10000
-
+
# Per-user rate limits
user_requests_per_minute: int = 60
user_requests_per_hour: int = 1000
-
+
# Tool-specific rate limits
tool_requests_per_minute: int = 30
tool_requests_per_hour: int = 500
-
+
# Authentication rate limits
auth_attempts_per_minute: int = 5
auth_attempts_per_hour: int = 20
-
+
# Rate limit enforcement
enable_rate_limiting: bool = True
rate_limit_storage: str = "memory" # "memory" or "redis"
@@ -122,23 +128,23 @@ class RateLimitConfig:
@dataclass
class SecurityMonitoringConfig:
"""Security monitoring and alerting configuration."""
-
+
# Logging settings
log_security_events: bool = True
log_level: str = "INFO"
log_sensitive_data: bool = False
-
+
# Alert thresholds
failed_auth_threshold: int = 10
rate_limit_violation_threshold: int = 5
suspicious_activity_threshold: int = 3
-
+
# Monitoring features
enable_intrusion_detection: bool = True
enable_anomaly_detection: bool = False
monitor_file_access: bool = True
monitor_network_requests: bool = True
-
+
# Data retention
security_log_retention_days: int = 90
audit_log_retention_days: int = 365
@@ -147,22 +153,26 @@ class SecurityMonitoringConfig:
@dataclass
class NetworkSecurityConfig:
"""Network security configuration."""
-
+
# HTTPS settings
enforce_https: bool = True
hsts_max_age: int = 31536000 # 1 year
-
+
# CORS settings
cors_enabled: bool = True
- cors_allowed_origins: List[str] = field(default_factory=lambda: ["https://*.example.com"])
- cors_allowed_methods: List[str] = field(default_factory=lambda: ["GET", "POST", "PUT", "DELETE"])
+ cors_allowed_origins: List[str] = field(
+ default_factory=lambda: ["https://*.example.com"]
+ )
+ cors_allowed_methods: List[str] = field(
+ default_factory=lambda: ["GET", "POST", "PUT", "DELETE"]
+ )
cors_max_age: int = 86400 # 24 hours
-
+
# Request filtering
block_private_ips: bool = True
block_localhost: bool = True
allowed_schemes: List[str] = field(default_factory=lambda: ["https"])
-
+
# Timeout settings
request_timeout: int = 30
connection_timeout: int = 10
@@ -172,57 +182,62 @@ class NetworkSecurityConfig:
@dataclass
class SecurityConfig:
"""Main security configuration container."""
-
+
auth: AuthConfig = field(default_factory=AuthConfig)
validation: ValidationConfig = field(default_factory=ValidationConfig)
encryption: EncryptionConfig = field(default_factory=EncryptionConfig)
rate_limiting: RateLimitConfig = field(default_factory=RateLimitConfig)
- monitoring: SecurityMonitoringConfig = field(default_factory=SecurityMonitoringConfig)
+ monitoring: SecurityMonitoringConfig = field(
+ default_factory=SecurityMonitoringConfig
+ )
network: NetworkSecurityConfig = field(default_factory=NetworkSecurityConfig)
-
+
# Global security settings
debug_mode: bool = False
strict_mode: bool = True
security_headers_enabled: bool = True
-
+
def __post_init__(self):
"""Validate configuration after initialization."""
self._validate_config()
self._setup_encryption_keys()
-
+
def _validate_config(self):
"""Validate security configuration."""
# Validate auth config
if self.auth.token_lifetime_seconds < 300: # 5 minutes minimum
raise ConfigurationError("Token lifetime must be at least 5 minutes")
-
+
if self.auth.token_lifetime_seconds > 86400: # 24 hours maximum
raise ConfigurationError("Token lifetime cannot exceed 24 hours")
-
+
# Validate validation config
if self.validation.max_string_length > 100 * 1024 * 1024: # 100MB
raise ConfigurationError("Maximum string length too large")
-
+
# Validate rate limiting
- if self.rate_limiting.user_requests_per_minute > self.rate_limiting.global_requests_per_minute:
+ if (
+ self.rate_limiting.user_requests_per_minute
+ > self.rate_limiting.global_requests_per_minute
+ ):
raise ConfigurationError("User rate limit cannot exceed global rate limit")
-
+
# Validate encryption
if self.encryption.key_derivation_iterations < 10000:
raise ConfigurationError("Key derivation iterations too low")
-
+
logger.info("Security configuration validated successfully")
-
+
def _setup_encryption_keys(self):
"""Setup encryption keys if not provided."""
if self.encryption.encryption_key is None:
self.encryption.encryption_key = Fernet.generate_key()
logger.info("Generated new encryption key")
-
+
if self.encryption.cache_encryption_key is None:
self.encryption.cache_encryption_key = Fernet.generate_key()
logger.info("Generated new cache encryption key")
-
+
def to_dict(self) -> Dict[str, Any]:
"""Convert configuration to dictionary (excluding sensitive data)."""
config_dict = {
@@ -232,51 +247,51 @@ def to_dict(self) -> Dict[str, Any]:
"require_https": self.auth.require_https,
"enforce_csrf_protection": self.auth.enforce_csrf_protection,
"min_password_length": self.auth.min_password_length,
- "oauth_configured": bool(self.auth.oauth_client_id)
+ "oauth_configured": bool(self.auth.oauth_client_id),
},
"validation": {
"max_string_length": self.validation.max_string_length,
"max_list_items": self.validation.max_list_items,
"max_nesting_depth": self.validation.max_nesting_depth,
"allow_html": self.validation.allow_html,
- "allow_javascript": self.validation.allow_javascript
+ "allow_javascript": self.validation.allow_javascript,
},
"encryption": {
"encryption_algorithm": self.encryption.encryption_algorithm,
"encrypt_cache_content": self.encryption.encrypt_cache_content,
"encrypt_sensitive_fields": self.encryption.encrypt_sensitive_fields,
- "key_rotation_interval_days": self.encryption.key_rotation_interval_days
+ "key_rotation_interval_days": self.encryption.key_rotation_interval_days,
},
"rate_limiting": {
"enable_rate_limiting": self.rate_limiting.enable_rate_limiting,
"user_requests_per_minute": self.rate_limiting.user_requests_per_minute,
- "tool_requests_per_minute": self.rate_limiting.tool_requests_per_minute
+ "tool_requests_per_minute": self.rate_limiting.tool_requests_per_minute,
},
"monitoring": {
"log_security_events": self.monitoring.log_security_events,
"enable_intrusion_detection": self.monitoring.enable_intrusion_detection,
- "security_log_retention_days": self.monitoring.security_log_retention_days
+ "security_log_retention_days": self.monitoring.security_log_retention_days,
},
"network": {
"enforce_https": self.network.enforce_https,
"cors_enabled": self.network.cors_enabled,
"block_private_ips": self.network.block_private_ips,
- "request_timeout": self.network.request_timeout
+ "request_timeout": self.network.request_timeout,
},
"global": {
"debug_mode": self.debug_mode,
"strict_mode": self.strict_mode,
- "security_headers_enabled": self.security_headers_enabled
- }
+ "security_headers_enabled": self.security_headers_enabled,
+ },
}
-
+
return config_dict
def load_security_config_from_env() -> SecurityConfig:
"""Load security configuration from environment variables."""
config = SecurityConfig()
-
+
# Auth configuration
config.auth.token_lifetime_seconds = int(
os.getenv("AUTH_TOKEN_LIFETIME", config.auth.token_lifetime_seconds)
@@ -284,45 +299,54 @@ def load_security_config_from_env() -> SecurityConfig:
config.auth.oauth_client_id = os.getenv("OAUTH_CLIENT_ID")
config.auth.oauth_client_secret = os.getenv("OAUTH_CLIENT_SECRET")
config.auth.oauth_redirect_uri = os.getenv("OAUTH_REDIRECT_URI")
-
+
# Validation configuration
config.validation.max_string_length = int(
os.getenv("VALIDATION_MAX_STRING_LENGTH", config.validation.max_string_length)
)
- config.validation.allow_html = os.getenv("VALIDATION_ALLOW_HTML", "false").lower() == "true"
-
+ config.validation.allow_html = (
+ os.getenv("VALIDATION_ALLOW_HTML", "false").lower() == "true"
+ )
+
# Encryption configuration
encryption_key_b64 = os.getenv("ENCRYPTION_KEY")
if encryption_key_b64:
try:
import base64
- config.encryption.encryption_key = base64.urlsafe_b64decode(encryption_key_b64)
+
+ config.encryption.encryption_key = base64.urlsafe_b64decode(
+ encryption_key_b64
+ )
except Exception as e:
- logger.warning("Failed to decode encryption key from environment", error=str(e))
-
+ logger.warning(
+ "Failed to decode encryption key from environment", error=str(e)
+ )
+
# Rate limiting configuration
config.rate_limiting.enable_rate_limiting = (
os.getenv("RATE_LIMITING_ENABLED", "true").lower() == "true"
)
config.rate_limiting.user_requests_per_minute = int(
- os.getenv("RATE_LIMIT_USER_PER_MINUTE", config.rate_limiting.user_requests_per_minute)
+ os.getenv(
+ "RATE_LIMIT_USER_PER_MINUTE", config.rate_limiting.user_requests_per_minute
+ )
)
-
+
# Monitoring configuration
config.monitoring.log_security_events = (
os.getenv("LOG_SECURITY_EVENTS", "true").lower() == "true"
)
- config.monitoring.log_level = os.getenv("SECURITY_LOG_LEVEL", config.monitoring.log_level)
-
- # Network configuration
- config.network.enforce_https = (
- os.getenv("ENFORCE_HTTPS", "true").lower() == "true"
+ config.monitoring.log_level = os.getenv(
+ "SECURITY_LOG_LEVEL", config.monitoring.log_level
)
-
+
+ # Network configuration
+ config.network.enforce_https = os.getenv("ENFORCE_HTTPS", "true").lower() == "true"
+
# Global settings
config.debug_mode = os.getenv("DEBUG_MODE", "false").lower() == "true"
config.strict_mode = os.getenv("STRICT_MODE", "true").lower() == "true"
-
+
logger.info("Security configuration loaded from environment")
return config
@@ -330,23 +354,23 @@ def load_security_config_from_env() -> SecurityConfig:
def create_security_config(**overrides) -> SecurityConfig:
"""
Create security configuration with optional overrides.
-
+
Args:
**overrides: Configuration overrides
-
+
Returns:
SecurityConfig instance
"""
# Start with environment-based config
config = load_security_config_from_env()
-
+
# Apply overrides
for key, value in overrides.items():
if hasattr(config, key):
setattr(config, key, value)
else:
logger.warning("Unknown security config override", key=key)
-
+
return config
@@ -357,10 +381,10 @@ def create_security_config(**overrides) -> SecurityConfig:
def get_security_config() -> SecurityConfig:
"""Get global security configuration instance."""
global _security_config_instance
-
+
if _security_config_instance is None:
_security_config_instance = load_security_config_from_env()
-
+
return _security_config_instance
@@ -374,7 +398,7 @@ def update_security_config(config: SecurityConfig) -> None:
def get_security_headers() -> Dict[str, str]:
"""Get recommended security headers."""
config = get_security_config()
-
+
headers = {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
@@ -390,54 +414,56 @@ def get_security_headers() -> Dict[str, str]:
"object-src 'none'; "
"media-src 'self'; "
"frame-src 'none';"
- )
+ ),
}
-
+
if config.network.enforce_https:
- headers["Strict-Transport-Security"] = f"max-age={config.network.hsts_max_age}; includeSubDomains"
-
+ headers["Strict-Transport-Security"] = (
+ f"max-age={config.network.hsts_max_age}; includeSubDomains"
+ )
+
return headers
def validate_security_config(config: SecurityConfig) -> List[str]:
"""
Validate security configuration and return warnings.
-
+
Args:
config: Security configuration to validate
-
+
Returns:
List of validation warnings
"""
warnings = []
-
+
# Check for insecure settings
if config.debug_mode:
warnings.append("Debug mode is enabled - disable in production")
-
+
if not config.auth.require_https:
warnings.append("HTTPS is not required - enable for production")
-
+
if not config.rate_limiting.enable_rate_limiting:
warnings.append("Rate limiting is disabled - enable for production")
-
+
if not config.encryption.encrypt_cache_content:
warnings.append("Cache encryption is disabled")
-
+
if config.validation.allow_html:
warnings.append("HTML input is allowed - potential XSS risk")
-
+
if config.validation.allow_javascript:
warnings.append("JavaScript input is allowed - high XSS risk")
-
+
# Check for weak settings
if config.auth.token_lifetime_seconds > 7200: # 2 hours
warnings.append("Token lifetime is longer than recommended (>2 hours)")
-
+
if config.rate_limiting.user_requests_per_minute > 120:
warnings.append("User rate limit is higher than recommended (>120/min)")
-
+
if config.validation.max_string_length > 1024 * 1024: # 1MB
warnings.append("Maximum string length is very large (>1MB)")
-
- return warnings
\ No newline at end of file
+
+ return warnings
diff --git a/src/security/error_handler.py b/src/security/error_handler.py
index ab7b701..dc047f1 100644
--- a/src/security/error_handler.py
+++ b/src/security/error_handler.py
@@ -12,22 +12,21 @@
import structlog
from core.errors import (
- SecurityError,
- AuthenticationError,
+ SecurityError,
+ AuthenticationError,
ValidationError,
DatabaseError,
ToolExecutionError,
- CacheError
+ CacheError,
)
-
logger = structlog.get_logger(__name__)
@dataclass
class SecureErrorResponse:
"""Represents a sanitized error response."""
-
+
error_code: str
message: str
status_code: int
@@ -40,16 +39,16 @@ class SecureErrorHandler:
Secure error handler that sanitizes error responses to prevent
information disclosure while maintaining debugging capabilities.
"""
-
+
def __init__(self, debug_mode: bool = False):
"""
Initialize secure error handler.
-
+
Args:
debug_mode: Whether to include debug information in responses
"""
self.debug_mode = debug_mode
-
+
# Error code mappings for different exception types
self.error_mappings = {
SecurityError: ("SECURITY_ERROR", 403),
@@ -66,7 +65,7 @@ def __init__(self, debug_mode: bool = False):
ConnectionError: ("CONNECTION_ERROR", 503),
TimeoutError: ("TIMEOUT_ERROR", 504),
}
-
+
# Safe error messages that don't reveal system internals
self.safe_messages = {
"SECURITY_ERROR": "Access denied due to security policy",
@@ -82,105 +81,107 @@ def __init__(self, debug_mode: bool = False):
"MISSING_KEY": "Required parameter missing",
"CONNECTION_ERROR": "Service temporarily unavailable",
"TIMEOUT_ERROR": "Request timed out",
- "INTERNAL_ERROR": "An internal error occurred"
+ "INTERNAL_ERROR": "An internal error occurred",
}
-
+
# Patterns to sanitize from error messages
self.sensitive_patterns = [
- r'password[=:]\s*\S+',
- r'api[_-]?key[=:]\s*\S+',
- r'token[=:]\s*\S+',
- r'secret[=:]\s*\S+',
- r'auth[=:]\s*\S+',
- r'/[a-zA-Z]:/.*', # Windows paths
- r'/home/[^/\s]+', # Unix home paths
- r'/etc/[^/\s]+', # System config paths
- r'localhost:\d+', # Local endpoints
- r'127\.0\.0\.1:\d+', # Local IP endpoints
- r'(\d{1,3}\.){3}\d{1,3}:\d+', # IP:port combinations
+ r"password[=:]\s*\S+",
+ r"api[_-]?key[=:]\s*\S+",
+ r"token[=:]\s*\S+",
+ r"secret[=:]\s*\S+",
+ r"auth[=:]\s*\S+",
+ r"/[a-zA-Z]:/.*", # Windows paths
+ r"/home/[^/\s]+", # Unix home paths
+ r"/etc/[^/\s]+", # System config paths
+ r"localhost:\d+", # Local endpoints
+ r"127\.0\.0\.1:\d+", # Local IP endpoints
+ r"(\d{1,3}\.){3}\d{1,3}:\d+", # IP:port combinations
]
-
- def handle_exception(self,
- exception: Exception,
- context: Optional[Dict[str, Any]] = None) -> SecureErrorResponse:
+
+ def handle_exception(
+ self, exception: Exception, context: Optional[Dict[str, Any]] = None
+ ) -> SecureErrorResponse:
"""
Handle an exception and return a secure error response.
-
+
Args:
exception: The exception to handle
context: Optional context information
-
+
Returns:
SecureErrorResponse with sanitized error information
"""
# Generate unique error ID for tracking
import uuid
+
error_id = str(uuid.uuid4())
-
+
# Get error code and status code
error_code, status_code = self._get_error_info(exception)
-
+
# Get safe message
safe_message = self._get_safe_message(exception, error_code)
-
+
# Log full error details for debugging
self._log_error_details(exception, error_id, context)
-
+
# Prepare response details
details = None
if self.debug_mode:
details = self._get_debug_details(exception, context)
-
+
return SecureErrorResponse(
error_code=error_code,
message=safe_message,
status_code=status_code,
details=details,
- error_id=error_id
+ error_id=error_id,
)
-
+
def sanitize_error_message(self, message: str) -> str:
"""
Sanitize an error message to remove sensitive information.
-
+
Args:
message: Original error message
-
+
Returns:
Sanitized error message
"""
import re
-
+
sanitized = message
-
+
# Remove sensitive patterns
for pattern in self.sensitive_patterns:
- sanitized = re.sub(pattern, '[REDACTED]', sanitized, flags=re.IGNORECASE)
-
+ sanitized = re.sub(pattern, "[REDACTED]", sanitized, flags=re.IGNORECASE)
+
# Remove file paths
- sanitized = re.sub(r'/[\w/.]+\.py', '[FILE_PATH]', sanitized)
- sanitized = re.sub(r'[A-Z]:\\[\w\\/.]+\.py', '[FILE_PATH]', sanitized)
-
+ sanitized = re.sub(r"/[\w/.]+\.py", "[FILE_PATH]", sanitized)
+ sanitized = re.sub(r"[A-Z]:\\[\w\\/.]+\.py", "[FILE_PATH]", sanitized)
+
# Remove line numbers from stack traces
- sanitized = re.sub(r'line \d+', 'line [NUM]', sanitized)
-
+ sanitized = re.sub(r"line \d+", "line [NUM]", sanitized)
+
# Remove memory addresses
- sanitized = re.sub(r'0x[0-9a-fA-F]+', '[MEMORY_ADDR]', sanitized)
-
+ sanitized = re.sub(r"0x[0-9a-fA-F]+", "[MEMORY_ADDR]", sanitized)
+
# Truncate very long messages
if len(sanitized) > 500:
- sanitized = sanitized[:497] + '...'
-
+ sanitized = sanitized[:497] + "..."
+
return sanitized
-
- def format_error_response(self,
- error_response: SecureErrorResponse) -> Dict[str, Any]:
+
+ def format_error_response(
+ self, error_response: SecureErrorResponse
+ ) -> Dict[str, Any]:
"""
Format error response for API consumption.
-
+
Args:
error_response: SecureErrorResponse to format
-
+
Returns:
Formatted error response dictionary
"""
@@ -188,66 +189,66 @@ def format_error_response(self,
"error": True,
"error_code": error_response.error_code,
"message": error_response.message,
- "error_id": error_response.error_id
+ "error_id": error_response.error_id,
}
-
+
if error_response.details and self.debug_mode:
response["details"] = error_response.details
-
+
return response
-
+
def _get_error_info(self, exception: Exception) -> Tuple[str, int]:
"""Get error code and HTTP status code for exception."""
exception_type = type(exception)
-
+
# Check for exact type match
if exception_type in self.error_mappings:
error_code, status_code = self.error_mappings[exception_type]
return error_code, status_code
-
+
# Check for inheritance
for exc_type, (error_code, status_code) in self.error_mappings.items():
if isinstance(exception, exc_type):
return error_code, status_code
-
+
# Default for unknown exceptions
return "INTERNAL_ERROR", 500
-
+
def _get_safe_message(self, exception: Exception, error_code: str) -> str:
"""Get a safe error message that doesn't reveal system internals."""
# First try to get the safe message from our mappings
safe_message = self.safe_messages.get(error_code, "An error occurred")
-
+
# For known security-related exceptions, we can include some details
if isinstance(exception, (ValidationError, SecurityError)):
# These exceptions are designed to be user-facing
original_message = str(exception)
sanitized_message = self.sanitize_error_message(original_message)
-
+
# Only use the original message if it's reasonably safe
- if (len(sanitized_message) < 200 and
- not any(keyword in sanitized_message.lower()
- for keyword in ['password', 'token', 'key', 'secret', 'internal'])):
+ if len(sanitized_message) < 200 and not any(
+ keyword in sanitized_message.lower()
+ for keyword in ["password", "token", "key", "secret", "internal"]
+ ):
return sanitized_message
-
+
return safe_message
-
- def _log_error_details(self,
- exception: Exception,
- error_id: str,
- context: Optional[Dict[str, Any]]) -> None:
+
+ def _log_error_details(
+ self, exception: Exception, error_id: str, context: Optional[Dict[str, Any]]
+ ) -> None:
"""Log full error details for debugging purposes."""
error_details = {
"error_id": error_id,
"exception_type": type(exception).__name__,
"exception_message": str(exception),
- "context": context or {}
+ "context": context or {},
}
-
+
# Add stack trace in debug mode
if self.debug_mode:
error_details["stack_trace"] = traceback.format_exc()
-
+
# Log based on severity
if isinstance(exception, (SecurityError, AuthenticationError)):
logger.warning("Security-related error occurred", **error_details)
@@ -255,19 +256,19 @@ def _log_error_details(self,
logger.info("Input validation error occurred", **error_details)
else:
logger.error("System error occurred", **error_details)
-
- def _get_debug_details(self,
- exception: Exception,
- context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+
+ def _get_debug_details(
+ self, exception: Exception, context: Optional[Dict[str, Any]]
+ ) -> Dict[str, Any]:
"""Get debug details for development environment."""
if not self.debug_mode:
return {}
-
+
details = {
"exception_type": type(exception).__name__,
- "exception_message": self.sanitize_error_message(str(exception))
+ "exception_message": self.sanitize_error_message(str(exception)),
}
-
+
if context:
# Sanitize context information
sanitized_context = {}
@@ -278,9 +279,9 @@ def _get_debug_details(self,
sanitized_context[key] = value
else:
sanitized_context[key] = str(type(value))
-
+
details["context"] = sanitized_context
-
+
return details
@@ -291,32 +292,34 @@ def _get_debug_details(self,
def get_error_handler(debug_mode: bool = False) -> SecureErrorHandler:
"""
Get or create global error handler instance.
-
+
Args:
debug_mode: Whether to enable debug mode
-
+
Returns:
SecureErrorHandler instance
"""
global _error_handler_instance
-
+
if _error_handler_instance is None:
_error_handler_instance = SecureErrorHandler(debug_mode)
-
+
return _error_handler_instance
-def handle_error(exception: Exception,
- context: Optional[Dict[str, Any]] = None,
- debug_mode: bool = False) -> Dict[str, Any]:
+def handle_error(
+ exception: Exception,
+ context: Optional[Dict[str, Any]] = None,
+ debug_mode: bool = False,
+) -> Dict[str, Any]:
"""
Convenience function to handle an error and return a formatted response.
-
+
Args:
exception: Exception to handle
context: Optional context information
debug_mode: Whether to include debug information
-
+
Returns:
Formatted error response dictionary
"""
@@ -328,12 +331,12 @@ def handle_error(exception: Exception,
def sanitize_message(message: str) -> str:
"""
Convenience function to sanitize an error message.
-
+
Args:
message: Message to sanitize
-
+
Returns:
Sanitized message
"""
error_handler = get_error_handler()
- return error_handler.sanitize_error_message(message)
\ No newline at end of file
+ return error_handler.sanitize_error_message(message)
diff --git a/src/security/input_sanitizer.py b/src/security/input_sanitizer.py
index 934b3a5..78c6fa6 100644
--- a/src/security/input_sanitizer.py
+++ b/src/security/input_sanitizer.py
@@ -18,11 +18,12 @@
# Fall back to absolute imports (when run as script)
import sys
from pathlib import Path
+
# Add src to path if not already there
src_path = str(Path(__file__).parent.parent)
if src_path not in sys.path:
sys.path.insert(0, src_path)
-
+
from core.errors import SecurityError, ValidationError
@@ -32,264 +33,274 @@
class InputSanitizer:
"""
Comprehensive input sanitization for security protection.
-
+
Provides sanitization against XSS, SQL injection, command injection,
and other common attack vectors.
"""
-
+
def __init__(self):
"""Initialize input sanitizer with security patterns."""
-
+
# XSS prevention patterns
self.xss_patterns = [
- r'"
with pytest.raises(Exception): # Should raise security exception
security_manager.sanitize_input(malicious_query)
-
- @pytest.mark.asyncio
+
+ @pytest.mark.asyncio
async def test_cache_system_integration(self, valid_test_env):
"""Test cache system integration with configuration."""
with patch.dict(os.environ, valid_test_env, clear=True):
config = Config()
-
+
# Cache configuration should be loaded from environment
cache_config = config.cache_config
assert cache_config["max_size"] == "1000" # From test env
assert cache_config["prefix"] == "fact_v1"
-
+
# Cache validator should initialize
validator = CacheValidator()
assert validator is not None
-
+
@pytest.mark.asyncio
- async def test_driver_initialization_integration(self, test_database, valid_test_env):
+ async def test_driver_initialization_integration(
+ self, test_database, valid_test_env
+ ):
"""Test complete driver initialization with all components."""
with patch.dict(os.environ, valid_test_env, clear=True):
- with patch('src.core.driver.FACTDriver._test_connections', return_value=None):
+ with patch(
+ "src.core.driver.FACTDriver._test_connections", return_value=None
+ ):
# Driver should initialize successfully
driver = await get_driver()
-
+
assert driver is not None
assert driver.config is not None
assert driver.tool_registry is not None
-
+
# Should have tools registered
tools = driver.tool_registry.list_tools()
assert len(tools) > 0
-
+
# Should have metrics
metrics = driver.get_metrics()
assert metrics["initialized"] is True
-
+
await driver.shutdown()
-
+
@pytest.mark.asyncio
async def test_benchmark_framework_integration(self, valid_test_env):
"""Test benchmark framework integration with system components."""
with patch.dict(os.environ, valid_test_env, clear=True):
framework = BenchmarkFramework()
-
+
# Should initialize with configuration
assert framework.config is not None
assert framework.metrics_collector is not None
-
+
# Should handle empty query list
summary = await framework.run_benchmark_suite([])
assert summary.total_queries == 0
assert summary.successful_queries == 0
-
+
@pytest.mark.asyncio
async def test_end_to_end_query_processing(self, test_database, valid_test_env):
"""Test complete end-to-end query processing through all system layers."""
with patch.dict(os.environ, valid_test_env, clear=True):
- with patch('src.core.driver.FACTDriver._test_connections', return_value=None):
- with patch('anthropic.AsyncAnthropic') as mock_anthropic:
+ with patch(
+ "src.core.driver.FACTDriver._test_connections", return_value=None
+ ):
+ with patch("anthropic.AsyncAnthropic") as mock_anthropic:
# Mock successful API response
mock_response = MagicMock()
- mock_response.content = [MagicMock(text="TechCorp's Q1 2025 revenue was $50M")]
- mock_anthropic.return_value.messages.create.return_value = mock_response
-
+ mock_response.content = [
+ MagicMock(text="TechCorp's Q1 2025 revenue was $50M")
+ ]
+ mock_anthropic.return_value.messages.create.return_value = (
+ mock_response
+ )
+
# Initialize driver
driver = await get_driver()
-
+
# Process a query end-to-end
- response = await driver.process_query("What was TechCorp's Q1 2025 revenue?")
-
+ response = await driver.process_query(
+ "What was TechCorp's Q1 2025 revenue?"
+ )
+
assert response is not None
assert "TechCorp" in response
assert "revenue" in response
-
+
await driver.shutdown()
-
+
@pytest.mark.asyncio
async def test_performance_monitoring_integration(self, valid_test_env):
"""Test performance monitoring integration."""
with patch.dict(os.environ, valid_test_env, clear=True):
metrics_collector = MetricsCollector()
-
+
# Should record metrics
await metrics_collector.record_query_metrics(
- query="test query",
- response_time_ms=50.0,
- success=True,
- cache_hit=False
+ query="test query", response_time_ms=50.0, success=True, cache_hit=False
)
-
+
# Should provide summary
summary = await metrics_collector.get_performance_summary()
assert summary is not None
assert "total_queries" in summary
-
+
def test_configuration_validation_integration(self, valid_test_env):
"""Test that configuration validation works across all components."""
# Test with valid configuration
with patch.dict(os.environ, valid_test_env, clear=True):
config = Config()
assert config is not None
-
+
# Test with missing keys
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(ConfigurationError) as exc_info:
@@ -220,39 +232,42 @@ def test_configuration_validation_integration(self, valid_test_env):
class TestSystemErrorHandling:
"""Test error handling across integrated components."""
-
+
def test_graceful_degradation_with_missing_api_keys(self):
"""Test system behavior when API keys are missing."""
with patch.dict(os.environ, {}, clear=True):
# Configuration should fail fast
with pytest.raises(ConfigurationError):
Config()
-
+
def test_database_error_handling(self):
"""Test database error handling in integrated system."""
test_env = {
"ANTHROPIC_API_KEY": "sk-test-key",
"ARCADE_API_KEY": "ak-test-key",
- "DATABASE_PATH": "/invalid/path/database.db"
+ "DATABASE_PATH": "/invalid/path/database.db",
}
-
+
with patch.dict(os.environ, test_env, clear=True):
config = Config()
-
+
# Should handle invalid database path gracefully
db_manager = DatabaseManager(config.database_path)
# Database operations should fail with clear error messages
-
+
@pytest.mark.asyncio
async def test_api_connectivity_error_handling(self):
"""Test API connectivity error handling."""
test_env = {
"ANTHROPIC_API_KEY": "sk-invalid-key",
- "ARCADE_API_KEY": "ak-invalid-key"
+ "ARCADE_API_KEY": "ak-invalid-key",
}
-
+
with patch.dict(os.environ, test_env, clear=True):
- with patch('src.core.driver.FACTDriver._test_connections', side_effect=Exception("API Error")):
+ with patch(
+ "src.core.driver.FACTDriver._test_connections",
+ side_effect=Exception("API Error"),
+ ):
# Should handle API errors gracefully
with pytest.raises(Exception):
await get_driver()
@@ -260,86 +275,90 @@ async def test_api_connectivity_error_handling(self):
class TestSystemPerformance:
"""Test system performance characteristics."""
-
+
@pytest.mark.asyncio
async def test_concurrent_query_handling(self, valid_test_env=None):
"""Test system handling of concurrent queries."""
if valid_test_env is None:
valid_test_env = {
"ANTHROPIC_API_KEY": "sk-test-key",
- "ARCADE_API_KEY": "ak-test-key"
+ "ARCADE_API_KEY": "ak-test-key",
}
-
+
with patch.dict(os.environ, valid_test_env, clear=True):
- with patch('src.core.driver.FACTDriver._test_connections', return_value=None):
+ with patch(
+ "src.core.driver.FACTDriver._test_connections", return_value=None
+ ):
framework = BenchmarkFramework()
-
+
# Test concurrent query processing
queries = ["Query 1", "Query 2", "Query 3"]
-
- with patch('src.benchmarking.framework.process_user_query', return_value="Mock response"):
+
+ with patch(
+ "src.benchmarking.framework.process_user_query",
+ return_value="Mock response",
+ ):
summary = await framework.run_benchmark_suite(queries)
-
+
# Should handle all queries
expected_total = len(queries) * framework.config.iterations
assert summary.total_queries == expected_total
-
+
@pytest.mark.asyncio
async def test_cache_performance_validation(self):
"""Test cache performance meets targets."""
test_env = {
- "ANTHROPIC_API_KEY": "sk-test-key",
+ "ANTHROPIC_API_KEY": "sk-test-key",
"ARCADE_API_KEY": "ak-test-key",
"CACHE_HIT_TARGET_MS": "30",
- "CACHE_MISS_TARGET_MS": "120"
+ "CACHE_MISS_TARGET_MS": "120",
}
-
+
with patch.dict(os.environ, test_env, clear=True):
validator = CacheValidator()
-
+
# Validate cache performance targets
# This test validates the cache system meets performance requirements
class TestSystemRecovery:
"""Test system recovery and resilience."""
-
+
@pytest.mark.asyncio
async def test_system_recovery_after_failure(self):
"""Test system recovery after component failure."""
- test_env = {
- "ANTHROPIC_API_KEY": "sk-test-key",
- "ARCADE_API_KEY": "ak-test-key"
- }
-
+ test_env = {"ANTHROPIC_API_KEY": "sk-test-key", "ARCADE_API_KEY": "ak-test-key"}
+
with patch.dict(os.environ, test_env, clear=True):
# Simulate system failure and recovery
- with patch('src.core.driver.FACTDriver._test_connections', return_value=None):
+ with patch(
+ "src.core.driver.FACTDriver._test_connections", return_value=None
+ ):
driver = await get_driver()
-
+
# System should be operational
assert driver is not None
-
+
# Simulate graceful shutdown
await driver.shutdown()
-
+
def test_configuration_reload_capability(self):
"""Test system ability to reload configuration."""
# Test configuration reloading without restart
test_env_1 = {
- "ANTHROPIC_API_KEY": "sk-test-key-1",
- "ARCADE_API_KEY": "ak-test-key-1"
+ "ANTHROPIC_API_KEY": "sk-test-key-1",
+ "ARCADE_API_KEY": "ak-test-key-1",
}
-
+
test_env_2 = {
"ANTHROPIC_API_KEY": "sk-test-key-2",
- "ARCADE_API_KEY": "ak-test-key-2"
+ "ARCADE_API_KEY": "ak-test-key-2",
}
-
+
with patch.dict(os.environ, test_env_1, clear=True):
config1 = Config()
assert config1.anthropic_api_key == "sk-test-key-1"
-
+
with patch.dict(os.environ, test_env_2, clear=True):
config2 = Config()
- assert config2.anthropic_api_key == "sk-test-key-2"
\ No newline at end of file
+ assert config2.anthropic_api_key == "sk-test-key-2"
diff --git a/tests/integration/test_system_integration.py b/tests/integration/test_system_integration.py
index a7bdaf3..0ac64da 100644
--- a/tests/integration/test_system_integration.py
+++ b/tests/integration/test_system_integration.py
@@ -17,94 +17,111 @@
class TestEndToEndWorkflow:
"""Integration tests for complete FACT workflow."""
-
+
@pytest.mark.integration
- async def test_complete_query_processing_workflow(self, test_environment, benchmark_queries):
+ async def test_complete_query_processing_workflow(
+ self, test_environment, benchmark_queries
+ ):
"""TEST: Complete query from user input to final response"""
# Arrange
query = "What was Q1-2025 revenue?"
expected_revenue = 1234567.89
-
+
# Act
with test_environment:
response = await process_user_query(query)
-
+
# Assert
assert response is not None
assert str(expected_revenue) in response or "1,234,567.89" in response
assert "Q1-2025" in response
-
+
# Verify response contains structured financial data
- assert any(keyword in response.lower() for keyword in ["revenue", "quarter", "financial"])
-
+ assert any(
+ keyword in response.lower()
+ for keyword in ["revenue", "quarter", "financial"]
+ )
+
@pytest.mark.integration
- async def test_cache_miss_to_tool_execution_workflow(self, test_environment, test_database):
+ async def test_cache_miss_to_tool_execution_workflow(
+ self, test_environment, test_database
+ ):
"""TEST: Cache miss triggers tool execution and response generation"""
# Arrange
unique_query = f"What is the total revenue for 2024? (timestamp: {time.time()})"
-
+
# Configure mock to simulate tool call
mock_tool_call = Mock()
mock_tool_call.name = "SQL.QueryReadonly"
mock_tool_call.id = "test-call-123"
- mock_tool_call.arguments = json.dumps({
- "statement": "SELECT SUM(value) as total FROM revenue WHERE quarter LIKE '%2024%'"
- })
-
+ mock_tool_call.arguments = json.dumps(
+ {
+ "statement": "SELECT SUM(value) as total FROM revenue WHERE quarter LIKE '%2024%'"
+ }
+ )
+
mock_response = Mock()
- mock_response.content = [Mock(text="The total revenue for 2024 was $2,997,419.08")]
+ mock_response.content = [
+ Mock(text="The total revenue for 2024 was $2,997,419.08")
+ ]
mock_response.tool_calls = [mock_tool_call]
-
+
# Act
with test_environment as env:
env["anthropic"].messages.create.return_value = mock_response
env["arcade"].tools.execute.return_value = {
"status": "success",
- "data": {"rows": [{"total": 2997419.08}], "row_count": 1}
+ "data": {"rows": [{"total": 2997419.08}], "row_count": 1},
}
-
+
response = await process_user_query(unique_query)
-
+
# Assert
assert response is not None
assert "2,997,419.08" in response or "2997419.08" in response
assert "2024" in response
-
+
# Verify tool was executed
env["arcade"].tools.execute.assert_called_once()
call_args = env["arcade"].tools.execute.call_args[1]
assert "SQL.QueryReadonly" in str(call_args)
-
+
@pytest.mark.integration
- async def test_cache_hit_workflow_bypasses_tool_execution(self, test_environment, cache_config):
+ async def test_cache_hit_workflow_bypasses_tool_execution(
+ self, test_environment, cache_config
+ ):
"""TEST: Cache hit bypasses tool execution for faster response"""
# Arrange
cache_manager = CacheManager(config=cache_config)
query = "What was Q1-2025 revenue?"
- cached_response = "Q1-2025 revenue was $1,234,567.89 based on our financial records."
-
+ cached_response = (
+ "Q1-2025 revenue was $1,234,567.89 based on our financial records."
+ )
+
# Pre-populate cache
query_hash = cache_manager.generate_hash(query)
cache_manager.store(query_hash, cached_response)
-
+
# Act
start_time = time.perf_counter()
-
+
with test_environment as env:
- with patch('src.cache.manager.cache_manager', cache_manager):
+ with patch("src.cache.manager.cache_manager", cache_manager):
response = await process_user_query(query)
-
+
end_time = time.perf_counter()
response_time_ms = (end_time - start_time) * 1000
-
+
# Assert
assert response == cached_response
- assert response_time_ms < 50, f"Cache hit took {response_time_ms:.2f}ms, exceeds 50ms target"
-
+ assert (
+ response_time_ms < 50
+ ), f"Cache hit took {response_time_ms:.2f}ms, exceeds 50ms target"
+
# Verify tool execution was bypassed
env["arcade"].tools.execute.assert_not_called()
env["anthropic"].messages.create.assert_not_called()
-
+
@pytest.mark.integration
async def test_error_handling_and_recovery_workflow(self, test_environment):
"""TEST: System handles errors gracefully and provides useful feedback"""
@@ -114,22 +131,22 @@ async def test_error_handling_and_recovery_workflow(self, test_environment):
"name": "database_connection_error",
"query": "SELECT * FROM revenue",
"error": Exception("Database connection failed"),
- "expected_in_response": ["error", "database", "connection"]
+ "expected_in_response": ["error", "database", "connection"],
},
{
"name": "tool_execution_timeout",
"query": "Complex analytical query",
"error": TimeoutError("Tool execution timed out"),
- "expected_in_response": ["timeout", "error"]
+ "expected_in_response": ["timeout", "error"],
},
{
"name": "anthropic_api_error",
"query": "What is the revenue?",
"error": Exception("API rate limit exceeded"),
- "expected_in_response": ["error", "api", "temporarily"]
- }
+ "expected_in_response": ["error", "api", "temporarily"],
+ },
]
-
+
# Test each error scenario
for scenario in error_scenarios:
# Act
@@ -138,63 +155,78 @@ async def test_error_handling_and_recovery_workflow(self, test_environment):
env["arcade"].tools.execute.side_effect = scenario["error"]
elif "anthropic" in scenario["name"]:
env["anthropic"].messages.create.side_effect = scenario["error"]
-
+
response = await process_user_query(scenario["query"])
-
+
# Assert
assert response is not None
response_lower = response.lower()
-
+
# Should contain error context
- assert any(keyword in response_lower for keyword in scenario["expected_in_response"])
-
+ assert any(
+ keyword in response_lower
+ for keyword in scenario["expected_in_response"]
+ )
+
# Should not contain sensitive information
- assert not any(sensitive in response_lower for sensitive in ["api_key", "token", "secret"])
-
+ assert not any(
+ sensitive in response_lower
+ for sensitive in ["api_key", "token", "secret"]
+ )
+
@pytest.mark.integration
def test_system_initialization_and_configuration(self, test_environment):
"""TEST: System initializes properly with all components"""
# Act
with test_environment:
- anthropic_client, arcade_client, cache_prefix, system_prompt = initialize_system()
-
+ anthropic_client, arcade_client, cache_prefix, system_prompt = (
+ initialize_system()
+ )
+
# Assert
assert anthropic_client is not None
assert arcade_client is not None
assert cache_prefix == "fact_test_v1"
assert "deterministic" in system_prompt.lower()
- assert "finance" in system_prompt.lower() or "financial" in system_prompt.lower()
-
+ assert (
+ "finance" in system_prompt.lower() or "financial" in system_prompt.lower()
+ )
+
@pytest.mark.integration
async def test_concurrent_user_sessions(self, test_environment, benchmark_queries):
"""TEST: System handles concurrent user sessions correctly"""
+
# Arrange
async def user_session(user_id: int, queries: list):
session_results = []
for query in queries:
user_query = f"{query} (user {user_id})"
result = await process_user_query(user_query)
- session_results.append({
- "user_id": user_id,
- "query": user_query,
- "response": result,
- "timestamp": time.time()
- })
+ session_results.append(
+ {
+ "user_id": user_id,
+ "query": user_query,
+ "response": result,
+ "timestamp": time.time(),
+ }
+ )
return session_results
-
+
# Act
user_count = 5
queries_per_user = benchmark_queries[:3] # First 3 queries
-
+
with test_environment:
- concurrent_sessions = await asyncio.gather(*[
- user_session(user_id, queries_per_user)
- for user_id in range(user_count)
- ])
-
+ concurrent_sessions = await asyncio.gather(
+ *[
+ user_session(user_id, queries_per_user)
+ for user_id in range(user_count)
+ ]
+ )
+
# Assert
assert len(concurrent_sessions) == user_count
-
+
# Verify all sessions completed successfully
for session in concurrent_sessions:
assert len(session) == len(queries_per_user)
@@ -205,134 +237,135 @@ async def user_session(user_id: int, queries: list):
class TestToolIntegration:
"""Integration tests for tool registration and execution."""
-
+
@pytest.mark.integration
- def test_tool_registration_and_discovery_flow(self, test_environment, mock_arcade_client):
+ def test_tool_registration_and_discovery_flow(
+ self, test_environment, mock_arcade_client
+ ):
"""TEST: Tool registration and subsequent discovery workflow"""
# Arrange
from src.tools.decorators import tool, register_tool, discover_tools
-
+
@tool(
name="TestTool.Integration",
description="Integration test tool",
- parameters={
- "input": {"type": "string", "description": "Test input"}
- }
+ parameters={"input": {"type": "string", "description": "Test input"}},
)
def integration_test_tool(input: str) -> dict:
return {"processed": input.upper(), "timestamp": time.time()}
-
+
# Act
register_tool(integration_test_tool)
discovered_tools = discover_tools()
-
+
# Assert
- tool_names = [t._tool_metadata['name'] for t in discovered_tools]
+ tool_names = [t._tool_metadata["name"] for t in discovered_tools]
assert "TestTool.Integration" in tool_names
-
+
# Verify tool can be executed
from src.tools.decorators import execute_tool
+
result = execute_tool("TestTool.Integration", {"input": "test_value"})
-
+
assert result["success"] == True
assert result["data"]["processed"] == "TEST_VALUE"
-
+
@pytest.mark.integration
def test_sql_tool_end_to_end_execution(self, test_environment, test_database):
"""TEST: SQL tool end-to-end execution workflow"""
# Arrange
sql_tool = SQLQueryTool(database_path=test_database)
test_query = "SELECT quarter, value FROM revenue WHERE value > 1000000 ORDER BY value DESC"
-
+
# Act
result = sql_tool.execute({"statement": test_query})
-
+
# Assert
assert result["success"] == True
assert "data" in result
assert "rows" in result["data"]
assert len(result["data"]["rows"]) > 0
-
+
# Verify results are properly formatted
first_row = result["data"]["rows"][0]
assert "quarter" in first_row
assert "value" in first_row
assert isinstance(first_row["value"], (int, float))
assert first_row["value"] > 1000000
-
+
@pytest.mark.integration
- async def test_tool_call_via_arcade_integration(self, test_environment, mock_arcade_client):
+ async def test_tool_call_via_arcade_integration(
+ self, test_environment, mock_arcade_client
+ ):
"""TEST: Tool execution via Arcade.dev integration"""
# Arrange
from src.tools.decorators import execute_tool_via_arcade
-
+
# Configure successful tool execution
mock_arcade_client.tools.execute.return_value = {
"status": "success",
"data": {
"rows": [{"quarter": "Q1-2025", "value": 1234567.89}],
"row_count": 1,
- "execution_time_ms": 8
- }
+ "execution_time_ms": 8,
+ },
}
-
+
tool_call_data = {
"name": "SQL.QueryReadonly",
"arguments": {
"statement": "SELECT quarter, value FROM revenue WHERE quarter = 'Q1-2025'"
- }
+ },
}
-
+
# Act
result = await execute_tool_via_arcade(
- tool_call_data["name"],
- tool_call_data["arguments"],
- mock_arcade_client
+ tool_call_data["name"], tool_call_data["arguments"], mock_arcade_client
)
-
+
# Assert
assert result["success"] == True
assert result["data"]["row_count"] == 1
assert result["data"]["rows"][0]["quarter"] == "Q1-2025"
assert result["data"]["execution_time_ms"] < 10
-
+
# Verify Arcade client was called correctly
mock_arcade_client.tools.execute.assert_called_once()
-
+
@pytest.mark.integration
def test_tool_authorization_integration(self, test_environment):
"""TEST: Tool authorization integration with security framework"""
# Arrange
from src.tools.decorators import tool, register_tool, execute_tool_with_auth
from src.core.errors import AuthorizationError
-
+
@tool(
name="TestTool.Protected",
description="Protected tool requiring admin access",
- required_scopes=["admin", "write"]
+ required_scopes=["admin", "write"],
)
def protected_tool() -> dict:
return {"message": "Admin operation completed"}
-
+
register_tool(protected_tool)
-
+
# Test insufficient permissions
with pytest.raises(AuthorizationError):
execute_tool_with_auth(
"TestTool.Protected",
{},
user_scopes=["read"],
- user_id="test@example.com"
+ user_id="test@example.com",
)
-
+
# Test sufficient permissions
result = execute_tool_with_auth(
"TestTool.Protected",
{},
user_scopes=["admin", "write", "read"],
- user_id="admin@example.com"
+ user_id="admin@example.com",
)
-
+
# Assert
assert result["success"] == True
assert result["data"]["message"] == "Admin operation completed"
@@ -340,7 +373,7 @@ def protected_tool() -> dict:
class TestDatabaseIntegration:
"""Integration tests for database operations and connection management."""
-
+
@pytest.mark.database
def test_database_connection_pool_under_load(self, test_database):
"""TEST: Database connection pool handles concurrent load"""
@@ -348,164 +381,177 @@ def test_database_connection_pool_under_load(self, test_database):
pool = DatabaseConnectionPool(test_database, max_connections=3)
import threading
import queue
-
+
results_queue = queue.Queue()
-
+
def database_worker(worker_id: int):
try:
connection = pool.acquire_connection()
cursor = connection.execute("SELECT COUNT(*) as count FROM revenue")
result = cursor.fetchone()
pool.release_connection(connection)
- results_queue.put({"worker_id": worker_id, "count": result[0], "success": True})
+ results_queue.put(
+ {"worker_id": worker_id, "count": result[0], "success": True}
+ )
except Exception as e:
- results_queue.put({"worker_id": worker_id, "error": str(e), "success": False})
-
+ results_queue.put(
+ {"worker_id": worker_id, "error": str(e), "success": False}
+ )
+
# Act
workers = []
for i in range(5): # More workers than pool size
worker = threading.Thread(target=database_worker, args=(i,))
workers.append(worker)
worker.start()
-
+
# Wait for all workers to complete
for worker in workers:
worker.join(timeout=5.0)
-
+
# Collect results
results = []
while not results_queue.empty():
results.append(results_queue.get())
-
+
# Assert
assert len(results) == 5
successful_results = [r for r in results if r["success"]]
assert len(successful_results) == 5 # All should succeed
-
+
# Verify correct data
for result in successful_results:
assert result["count"] == 4 # Test database has 4 revenue records
-
+
@pytest.mark.database
def test_database_transaction_integrity(self, test_database):
"""TEST: Database maintains transaction integrity"""
# Arrange
from src.tools.connectors.sql import execute_sql_query
-
+
# Verify initial state
- initial_result = execute_sql_query("SELECT COUNT(*) as count FROM revenue", test_database)
+ initial_result = execute_sql_query(
+ "SELECT COUNT(*) as count FROM revenue", test_database
+ )
initial_count = initial_result.rows[0]["count"]
-
+
# Act - Try to execute multiple queries
queries = [
"SELECT * FROM revenue WHERE quarter = 'Q1-2025'",
"SELECT AVG(value) as avg_value FROM revenue",
- "SELECT quarter, value FROM revenue ORDER BY value DESC LIMIT 2"
+ "SELECT quarter, value FROM revenue ORDER BY value DESC LIMIT 2",
]
-
+
results = []
for query in queries:
result = execute_sql_query(query, test_database)
results.append(result)
-
+
# Verify final state
- final_result = execute_sql_query("SELECT COUNT(*) as count FROM revenue", test_database)
+ final_result = execute_sql_query(
+ "SELECT COUNT(*) as count FROM revenue", test_database
+ )
final_count = final_result.rows[0]["count"]
-
+
# Assert
assert all(r.success for r in results)
assert initial_count == final_count # No data should be modified
assert initial_count == 4 # Original test data intact
-
+
@pytest.mark.database
def test_database_error_recovery(self, test_database):
"""TEST: Database error recovery and connection resilience"""
# Arrange
from src.tools.connectors.sql import execute_sql_query
from src.core.errors import DatabaseError
-
+
# Test invalid query handling
invalid_queries = [
"SELECT * FROM non_existent_table",
"SELECT invalid_column FROM revenue",
- "SELECT * FROM revenue WHERE invalid_syntax ="
+ "SELECT * FROM revenue WHERE invalid_syntax =",
]
-
+
# Act & Assert
for invalid_query in invalid_queries:
with pytest.raises(DatabaseError):
execute_sql_query(invalid_query, test_database)
-
+
# Verify database is still functional after errors
- recovery_result = execute_sql_query("SELECT COUNT(*) as count FROM revenue", test_database)
+ recovery_result = execute_sql_query(
+ "SELECT COUNT(*) as count FROM revenue", test_database
+ )
assert recovery_result.success == True
assert recovery_result.rows[0]["count"] == 4
class TestCacheIntegration:
"""Integration tests for cache system integration."""
-
+
@pytest.mark.cache
- async def test_cache_integration_with_query_processing(self, test_environment, cache_config):
+ async def test_cache_integration_with_query_processing(
+ self, test_environment, cache_config
+ ):
"""TEST: Cache integration with complete query processing workflow"""
# Arrange
cache_manager = CacheManager(config=cache_config)
query = "What was Q1-2025 revenue and how does it compare to previous quarters?"
-
+
# First request (cache miss)
with test_environment as env:
- with patch('src.cache.manager.cache_manager', cache_manager):
+ with patch("src.cache.manager.cache_manager", cache_manager):
# Act - First request
first_response = await process_user_query(query)
-
+
# Verify cache miss behavior
assert first_response is not None
env["anthropic"].messages.create.assert_called()
-
+
# Act - Second request (should be cache hit)
env["anthropic"].messages.create.reset_mock()
second_response = await process_user_query(query)
-
+
# Assert
assert second_response is not None
# Second request should hit cache and not call Anthropic
env["anthropic"].messages.create.assert_not_called()
-
+
# Verify cache metrics
metrics = cache_manager.get_metrics()
assert metrics.total_requests >= 2
assert metrics.cache_hits >= 1
-
+
@pytest.mark.cache
def test_cache_invalidation_integration(self, test_environment, cache_config):
"""TEST: Cache invalidation integration with system updates"""
# Arrange
cache_manager = CacheManager(config=cache_config)
-
+
# Populate cache with multiple entries
test_queries = [
"What is Q1-2025 revenue?",
"Show quarterly summary",
- "Calculate total 2024 revenue"
+ "Calculate total 2024 revenue",
]
-
+
for query in test_queries:
content = f"Cached response for: {query}"
query_hash = cache_manager.generate_hash(query)
cache_manager.store(query_hash, content)
-
+
# Verify cache is populated
assert cache_manager.get_metrics().total_entries == 3
-
+
# Act - Trigger cache invalidation (e.g., due to schema change)
from src.cache.manager import invalidate_on_schema_change
- with patch('src.cache.manager.cache_manager', cache_manager):
+
+ with patch("src.cache.manager.cache_manager", cache_manager):
invalidated_count = invalidate_on_schema_change("Database schema updated")
-
+
# Assert
assert invalidated_count == 3
assert cache_manager.get_metrics().total_entries == 0
-
+
# Verify all queries now result in cache misses
for query in test_queries:
query_hash = cache_manager.generate_hash(query)
@@ -515,54 +561,62 @@ def test_cache_invalidation_integration(self, test_environment, cache_config):
@pytest.mark.integration
class TestSystemStability:
"""Integration tests for system stability and reliability."""
-
- async def test_memory_usage_under_sustained_load(self, test_environment, benchmark_queries):
+
+ async def test_memory_usage_under_sustained_load(
+ self, test_environment, benchmark_queries
+ ):
"""TEST: Memory usage remains stable under sustained load"""
# Arrange
import psutil
import os
-
+
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
-
+
# Act - Sustained load test
query_count = 0
max_queries = 100
-
+
with test_environment:
while query_count < max_queries:
query = benchmark_queries[query_count % len(benchmark_queries)]
response = await process_user_query(query)
-
+
assert response is not None
query_count += 1
-
+
# Check memory every 10 queries
if query_count % 10 == 0:
current_memory = process.memory_info().rss / 1024 / 1024
memory_growth = current_memory - initial_memory
-
+
# Memory growth should be reasonable (< 50MB for 100 queries)
- assert memory_growth < 50, f"Memory growth {memory_growth:.1f}MB too high"
-
+ assert (
+ memory_growth < 50
+ ), f"Memory growth {memory_growth:.1f}MB too high"
+
# Final memory check
final_memory = process.memory_info().rss / 1024 / 1024
total_growth = final_memory - initial_memory
-
+
# Assert
- assert total_growth < 100, f"Total memory growth {total_growth:.1f}MB exceeds limit"
-
- async def test_error_rate_under_normal_operation(self, test_environment, benchmark_queries):
+ assert (
+ total_growth < 100
+ ), f"Total memory growth {total_growth:.1f}MB exceeds limit"
+
+ async def test_error_rate_under_normal_operation(
+ self, test_environment, benchmark_queries
+ ):
"""TEST: Error rate remains low under normal operation"""
# Arrange
total_requests = 50
error_count = 0
-
+
# Act
with test_environment:
for i in range(total_requests):
query = benchmark_queries[i % len(benchmark_queries)]
-
+
try:
response = await process_user_query(query)
assert response is not None
@@ -570,36 +624,40 @@ async def test_error_rate_under_normal_operation(self, test_environment, benchma
error_count += 1
# Log error for debugging
print(f"Query {i} failed: {str(e)}")
-
+
# Assert
error_rate = error_count / total_requests
assert error_rate < 0.01, f"Error rate {error_rate:.1%} exceeds 1% threshold"
-
+
def test_configuration_validation_and_recovery(self, test_environment):
"""TEST: System validates configuration and recovers from invalid configs"""
# Arrange
from src.core.config import validate_configuration, ConfigurationError
-
+
invalid_configs = [
{}, # Empty config
{"ANTHROPIC_API_KEY": ""}, # Empty API key
{"ANTHROPIC_API_KEY": "valid", "ARCADE_URL": "invalid-url"}, # Invalid URL
- {"ANTHROPIC_API_KEY": "valid", "ARCADE_URL": "http://localhost:9099", "ARCADE_API_KEY": ""} # Empty Arcade key
+ {
+ "ANTHROPIC_API_KEY": "valid",
+ "ARCADE_URL": "http://localhost:9099",
+ "ARCADE_API_KEY": "",
+ }, # Empty Arcade key
]
-
+
# Test invalid configurations
for config in invalid_configs:
- with patch.dict('os.environ', config, clear=True):
+ with patch.dict("os.environ", config, clear=True):
with pytest.raises(ConfigurationError):
validate_configuration()
-
+
# Test valid configuration recovery
valid_config = {
"ANTHROPIC_API_KEY": "test-anthropic-key",
"ARCADE_API_KEY": "test-arcade-key",
- "ARCADE_URL": "http://localhost:9099"
+ "ARCADE_URL": "http://localhost:9099",
}
-
- with patch.dict('os.environ', valid_config):
+
+ with patch.dict("os.environ", valid_config):
# Should not raise exception
- validate_configuration()
\ No newline at end of file
+ validate_configuration()
diff --git a/tests/performance/test_benchmarks.py b/tests/performance/test_benchmarks.py
index d1bd0fb..7220788 100644
--- a/tests/performance/test_benchmarks.py
+++ b/tests/performance/test_benchmarks.py
@@ -17,6 +17,7 @@
@dataclass
class PerformanceMetrics:
"""Performance metrics for benchmark testing."""
+
response_time_ms: float
token_cost: float
cache_hit: bool
@@ -27,6 +28,7 @@ class PerformanceMetrics:
@dataclass
class BenchmarkResults:
"""Benchmark results comparison."""
+
fact_metrics: List[PerformanceMetrics]
rag_baseline_metrics: List[PerformanceMetrics]
improvement_factor: float
@@ -35,238 +37,271 @@ class BenchmarkResults:
class TestResponseTimeTargets:
"""Test suite for response time performance targets."""
-
+
@pytest.mark.performance
- async def test_cache_hit_response_under_50ms(self, test_environment, cache_config, benchmark_queries):
+ async def test_cache_hit_response_under_50ms(
+ self, test_environment, cache_config, benchmark_queries
+ ):
"""TEST: Cache hit responses achieve target latency under 50ms"""
# Arrange
from src.core.driver import process_user_query
from src.cache.manager import CacheManager
-
+
manager = CacheManager(config=cache_config)
query = benchmark_queries[0]
-
+
# Pre-warm cache
cached_response = "Cached response: Q1-2025 revenue was $1,234,567.89"
query_hash = manager.generate_hash(query)
manager.store(query_hash, cached_response)
-
+
# Act
measurements = []
for _ in range(10): # Multiple measurements for accuracy
start_time = time.perf_counter()
-
- with patch('src.cache.manager.cache_manager', manager):
+
+ with patch("src.cache.manager.cache_manager", manager):
response = await process_user_query(query)
-
+
end_time = time.perf_counter()
latency_ms = (end_time - start_time) * 1000
measurements.append(latency_ms)
-
+
# Assert
avg_latency = statistics.mean(measurements)
p95_latency = statistics.quantiles(measurements, n=20)[18] # 95th percentile
-
- assert avg_latency < 50, f"Average cache hit latency {avg_latency:.2f}ms exceeds 50ms target"
- assert p95_latency < 50, f"P95 cache hit latency {p95_latency:.2f}ms exceeds 50ms target"
+
+ assert (
+ avg_latency < 50
+ ), f"Average cache hit latency {avg_latency:.2f}ms exceeds 50ms target"
+ assert (
+ p95_latency < 50
+ ), f"P95 cache hit latency {p95_latency:.2f}ms exceeds 50ms target"
assert response is not None
-
+
@pytest.mark.performance
- async def test_cache_miss_response_under_140ms(self, test_environment, mock_anthropic_client, benchmark_queries):
+ async def test_cache_miss_response_under_140ms(
+ self, test_environment, mock_anthropic_client, benchmark_queries
+ ):
"""TEST: Cache miss responses achieve target latency under 140ms"""
# Arrange
from src.core.driver import process_user_query
-
+
# Configure mock for realistic response time
mock_anthropic_client.messages.create = AsyncMock()
mock_response = Mock()
mock_response.content = [Mock(text="Mock response from Claude")]
mock_response.tool_calls = None
mock_anthropic_client.messages.create.return_value = mock_response
-
+
# Simulate network latency
async def delayed_response(*args, **kwargs):
await asyncio.sleep(0.08) # 80ms simulated API call
return mock_response
-
+
mock_anthropic_client.messages.create.side_effect = delayed_response
-
+
query = f"Unique query {time.time()}" # Ensure cache miss
-
+
# Act
measurements = []
for _ in range(5): # Fewer iterations due to higher latency
start_time = time.perf_counter()
-
- with patch('src.core.driver.anthropic_client', mock_anthropic_client):
+
+ with patch("src.core.driver.anthropic_client", mock_anthropic_client):
response = await process_user_query(query)
-
+
end_time = time.perf_counter()
latency_ms = (end_time - start_time) * 1000
measurements.append(latency_ms)
-
+
# Assert
avg_latency = statistics.mean(measurements)
- assert avg_latency < 140, f"Average cache miss latency {avg_latency:.2f}ms exceeds 140ms target"
+ assert (
+ avg_latency < 140
+ ), f"Average cache miss latency {avg_latency:.2f}ms exceeds 140ms target"
assert response is not None
-
+
@pytest.mark.performance
def test_tool_execution_under_10ms_lan(self, test_environment, test_database):
"""TEST: Tool execution meets LAN latency targets under 10ms"""
# Arrange
from src.tools.connectors.sql import execute_sql_query
+
queries = [
"SELECT COUNT(*) as total FROM revenue",
"SELECT quarter, value FROM revenue WHERE quarter = 'Q1-2025'",
"SELECT AVG(value) as avg_revenue FROM revenue",
- "SELECT MAX(value) as max_revenue FROM revenue"
+ "SELECT MAX(value) as max_revenue FROM revenue",
]
-
+
# Act
measurements = []
for query in queries:
start_time = time.perf_counter()
result = execute_sql_query(query, test_database)
end_time = time.perf_counter()
-
+
latency_ms = (end_time - start_time) * 1000
measurements.append(latency_ms)
-
+
assert result.success == True
-
+
# Assert
avg_latency = statistics.mean(measurements)
max_latency = max(measurements)
-
- assert avg_latency < 10, f"Average tool execution {avg_latency:.2f}ms exceeds 10ms target"
- assert max_latency < 10, f"Max tool execution {max_latency:.2f}ms exceeds 10ms target"
-
+
+ assert (
+ avg_latency < 10
+ ), f"Average tool execution {avg_latency:.2f}ms exceeds 10ms target"
+ assert (
+ max_latency < 10
+ ), f"Max tool execution {max_latency:.2f}ms exceeds 10ms target"
+
@pytest.mark.performance
- async def test_overall_system_response_under_100ms(self, test_environment, benchmark_queries):
+ async def test_overall_system_response_under_100ms(
+ self, test_environment, benchmark_queries
+ ):
"""TEST: Overall system response time meets 100ms average target"""
# Arrange
from src.core.driver import process_user_query
-
+
# Mix of cache hits and misses
mixed_queries = benchmark_queries[:5] # Use first 5 benchmark queries
-
+
# Act
measurements = []
for query in mixed_queries:
start_time = time.perf_counter()
response = await process_user_query(query)
end_time = time.perf_counter()
-
+
latency_ms = (end_time - start_time) * 1000
measurements.append(latency_ms)
assert response is not None
-
+
# Assert
avg_latency = statistics.mean(measurements)
- assert avg_latency < 100, f"Average system response {avg_latency:.2f}ms exceeds 100ms target"
+ assert (
+ avg_latency < 100
+ ), f"Average system response {avg_latency:.2f}ms exceeds 100ms target"
class TestTokenCostOptimization:
"""Test suite for token cost optimization requirements."""
-
+
@pytest.mark.cost_analysis
- async def test_cache_hit_cost_reduction_90_percent(self, test_environment, performance_targets):
+ async def test_cache_hit_cost_reduction_90_percent(
+ self, test_environment, performance_targets
+ ):
"""TEST: Cache hits achieve 90% cost reduction vs traditional RAG"""
# Arrange
from src.core.driver import process_user_query
from src.cache.manager import CacheManager, calculate_token_cost
-
+
query = "What was Q1-2025 revenue?"
-
+
# Simulate traditional RAG cost (baseline)
baseline_input_tokens = 1500 # Typical RAG prompt with context
baseline_output_tokens = 200
- baseline_cost = calculate_token_cost(baseline_input_tokens, baseline_output_tokens)
-
+ baseline_cost = calculate_token_cost(
+ baseline_input_tokens, baseline_output_tokens
+ )
+
# Test FACT cache hit cost
cached_response = "Q1-2025 revenue was $1,234,567.89 based on financial data."
-
+
# Act
- with patch('src.cache.manager.get_cached_response') as mock_cache:
+ with patch("src.cache.manager.get_cached_response") as mock_cache:
mock_cache.return_value = cached_response
-
+
# Measure cache hit cost
start_tokens = 100 # Minimal prompt for cache lookup
- end_tokens = 50 # Cached response tokens
+ end_tokens = 50 # Cached response tokens
fact_cost = calculate_token_cost(start_tokens, end_tokens)
-
+
# Assert
cost_reduction = (baseline_cost - fact_cost) / baseline_cost
- assert cost_reduction >= performance_targets["cost_reduction_cache_hit"], \
- f"Cache hit cost reduction {cost_reduction:.1%} below 90% target"
-
+ assert (
+ cost_reduction >= performance_targets["cost_reduction_cache_hit"]
+ ), f"Cache hit cost reduction {cost_reduction:.1%} below 90% target"
+
# Verify actual 90% reduction
- assert cost_reduction >= 0.90, f"Cost reduction {cost_reduction:.1%} below 90% requirement"
-
+ assert (
+ cost_reduction >= 0.90
+ ), f"Cost reduction {cost_reduction:.1%} below 90% requirement"
+
@pytest.mark.cost_analysis
- async def test_cache_miss_cost_reduction_65_percent(self, test_environment, performance_targets):
+ async def test_cache_miss_cost_reduction_65_percent(
+ self, test_environment, performance_targets
+ ):
"""TEST: Cache misses achieve 65% cost reduction vs traditional RAG"""
# Arrange
from src.core.driver import process_user_query
from src.cache.manager import calculate_token_cost
-
+
query = "What was Q2-2024 revenue breakdown by category?"
-
+
# Simulate traditional RAG cost
baseline_input_tokens = 1500 # Large context with vector search results
baseline_output_tokens = 300 # Detailed response
- baseline_cost = calculate_token_cost(baseline_input_tokens, baseline_output_tokens)
-
+ baseline_cost = calculate_token_cost(
+ baseline_input_tokens, baseline_output_tokens
+ )
+
# Test FACT cache miss cost (uses tool calls)
- fact_input_tokens = 200 # Streamlined prompt + tool call
+ fact_input_tokens = 200 # Streamlined prompt + tool call
fact_output_tokens = 150 # Response with tool results
fact_cost = calculate_token_cost(fact_input_tokens, fact_output_tokens)
-
+
# Assert
cost_reduction = (baseline_cost - fact_cost) / baseline_cost
- assert cost_reduction >= performance_targets["cost_reduction_cache_miss"], \
- f"Cache miss cost reduction {cost_reduction:.1%} below 65% target"
-
+ assert (
+ cost_reduction >= performance_targets["cost_reduction_cache_miss"]
+ ), f"Cache miss cost reduction {cost_reduction:.1%} below 65% target"
+
# Verify actual 65% reduction
- assert cost_reduction >= 0.65, f"Cost reduction {cost_reduction:.1%} below 65% requirement"
-
+ assert (
+ cost_reduction >= 0.65
+ ), f"Cost reduction {cost_reduction:.1%} below 65% requirement"
+
@pytest.mark.cost_analysis
def test_token_efficiency_metrics_calculation(self, test_environment):
"""TEST: Token efficiency metrics are calculated accurately"""
# Arrange
from src.cache.manager import TokenEfficiencyCalculator
-
+
test_scenarios = [
{
"scenario": "cache_hit",
"input_tokens": 100,
"output_tokens": 50,
"baseline_input": 1500,
- "baseline_output": 200
+ "baseline_output": 200,
},
{
"scenario": "cache_miss",
"input_tokens": 200,
"output_tokens": 150,
"baseline_input": 1500,
- "baseline_output": 300
- }
+ "baseline_output": 300,
+ },
]
-
+
calculator = TokenEfficiencyCalculator()
-
+
# Act & Assert
for scenario in test_scenarios:
efficiency = calculator.calculate_efficiency(
fact_input=scenario["input_tokens"],
fact_output=scenario["output_tokens"],
rag_input=scenario["baseline_input"],
- rag_output=scenario["baseline_output"]
+ rag_output=scenario["baseline_output"],
)
-
+
assert efficiency.fact_total_tokens < efficiency.rag_total_tokens
assert efficiency.reduction_percentage > 0
-
+
if scenario["scenario"] == "cache_hit":
assert efficiency.reduction_percentage >= 0.90
else: # cache_miss
@@ -275,168 +310,189 @@ def test_token_efficiency_metrics_calculation(self, test_environment):
class TestPerformanceComparison:
"""Test suite for performance comparison with traditional RAG."""
-
+
@pytest.mark.benchmark
- async def test_fact_vs_rag_latency_comparison(self, test_environment, benchmark_queries):
+ async def test_fact_vs_rag_latency_comparison(
+ self, test_environment, benchmark_queries
+ ):
"""TEST: FACT system latency comparison vs traditional RAG"""
# Arrange
from src.core.driver import process_user_query
-
+
# Simulate traditional RAG latencies
rag_latencies = [
- 450, 520, 380, 490, 410, # Typical RAG response times (400-500ms)
- 510, 470, 430, 460, 480
+ 450,
+ 520,
+ 380,
+ 490,
+ 410, # Typical RAG response times (400-500ms)
+ 510,
+ 470,
+ 430,
+ 460,
+ 480,
]
-
+
# Test FACT system latencies
fact_latencies = []
-
+
# Act
for i, query in enumerate(benchmark_queries[:10]):
start_time = time.perf_counter()
response = await process_user_query(query)
end_time = time.perf_counter()
-
+
latency_ms = (end_time - start_time) * 1000
fact_latencies.append(latency_ms)
-
+
assert response is not None
-
+
# Assert
fact_avg = statistics.mean(fact_latencies)
rag_avg = statistics.mean(rag_latencies)
improvement_factor = rag_avg / fact_avg
-
- assert improvement_factor >= 3.0, \
- f"FACT latency improvement {improvement_factor:.1f}x below 3x target vs RAG"
-
+
+ assert (
+ improvement_factor >= 3.0
+ ), f"FACT latency improvement {improvement_factor:.1f}x below 3x target vs RAG"
+
# Verify absolute performance
- assert fact_avg < 100, f"FACT average latency {fact_avg:.2f}ms exceeds 100ms target"
-
+ assert (
+ fact_avg < 100
+ ), f"FACT average latency {fact_avg:.2f}ms exceeds 100ms target"
+
@pytest.mark.benchmark
def test_fact_vs_rag_cost_comparison(self, performance_targets):
"""TEST: FACT system cost comparison vs traditional RAG"""
# Arrange
from src.cache.manager import calculate_token_cost
-
+
# Typical query scenarios
scenarios = [
{"query_type": "simple", "rag_input": 1200, "rag_output": 150},
{"query_type": "complex", "rag_input": 2000, "rag_output": 400},
- {"query_type": "analytical", "rag_input": 1800, "rag_output": 300}
+ {"query_type": "analytical", "rag_input": 1800, "rag_output": 300},
]
-
+
fact_costs = []
rag_costs = []
-
+
# Act
for scenario in scenarios:
# RAG costs
rag_cost = calculate_token_cost(
- scenario["rag_input"],
- scenario["rag_output"]
+ scenario["rag_input"], scenario["rag_output"]
)
rag_costs.append(rag_cost)
-
+
# FACT costs (assuming 70% cache hits)
cache_hit_cost = calculate_token_cost(80, 40) # Minimal tokens
cache_miss_cost = calculate_token_cost(180, 120) # Tool-enhanced
-
+
fact_avg_cost = (0.7 * cache_hit_cost) + (0.3 * cache_miss_cost)
fact_costs.append(fact_avg_cost)
-
+
# Assert
total_rag_cost = sum(rag_costs)
total_fact_cost = sum(fact_costs)
cost_reduction = (total_rag_cost - total_fact_cost) / total_rag_cost
-
- assert cost_reduction >= 0.75, \
- f"Overall cost reduction {cost_reduction:.1%} below 75% target"
-
+
+ assert (
+ cost_reduction >= 0.75
+ ), f"Overall cost reduction {cost_reduction:.1%} below 75% target"
+
@pytest.mark.benchmark
- async def test_throughput_comparison_under_load(self, test_environment, benchmark_queries):
+ async def test_throughput_comparison_under_load(
+ self, test_environment, benchmark_queries
+ ):
"""TEST: FACT system throughput vs traditional RAG under load"""
# Arrange
from src.core.driver import process_user_query
import asyncio
-
+
concurrent_users = 10
queries_per_user = 5
-
+
# Simulate concurrent load
async def user_session(user_id: int):
session_times = []
for i in range(queries_per_user):
query = benchmark_queries[i % len(benchmark_queries)]
-
+
start_time = time.perf_counter()
response = await process_user_query(f"{query} (user {user_id})")
end_time = time.perf_counter()
-
+
session_times.append(end_time - start_time)
assert response is not None
-
+
return session_times
-
+
# Act
start_time = time.perf_counter()
-
- user_sessions = await asyncio.gather(*[
- user_session(user_id) for user_id in range(concurrent_users)
- ])
-
+
+ user_sessions = await asyncio.gather(
+ *[user_session(user_id) for user_id in range(concurrent_users)]
+ )
+
end_time = time.perf_counter()
total_duration = end_time - start_time
-
+
# Assert
total_queries = concurrent_users * queries_per_user
throughput_qps = total_queries / total_duration
-
+
# FACT should handle significantly higher throughput than RAG
- assert throughput_qps >= 20, f"Throughput {throughput_qps:.1f} QPS below 20 QPS target"
-
+ assert (
+ throughput_qps >= 20
+ ), f"Throughput {throughput_qps:.1f} QPS below 20 QPS target"
+
# Verify response time consistency under load
all_response_times = [time for session in user_sessions for time in session]
avg_response_time = statistics.mean(all_response_times) * 1000
p95_response_time = statistics.quantiles(all_response_times, n=20)[18] * 1000
-
- assert avg_response_time < 150, f"Average response time under load {avg_response_time:.2f}ms too high"
- assert p95_response_time < 200, f"P95 response time under load {p95_response_time:.2f}ms too high"
+
+ assert (
+ avg_response_time < 150
+ ), f"Average response time under load {avg_response_time:.2f}ms too high"
+ assert (
+ p95_response_time < 200
+ ), f"P95 response time under load {p95_response_time:.2f}ms too high"
class TestContinuousBenchmarking:
"""Test suite for continuous benchmarking and monitoring."""
-
+
def test_benchmark_data_collection_and_storage(self, test_environment):
"""TEST: Benchmark data collection and storage for continuous monitoring"""
# Arrange
from src.monitoring.benchmarks import BenchmarkCollector
-
+
collector = BenchmarkCollector()
-
+
# Simulate benchmark run
metrics = PerformanceMetrics(
response_time_ms=45.2,
token_cost=0.002,
cache_hit=True,
query_type="financial_query",
- timestamp=time.time()
+ timestamp=time.time(),
)
-
+
# Act
collector.record_metric(metrics)
stored_metrics = collector.get_recent_metrics(limit=10)
-
+
# Assert
assert len(stored_metrics) >= 1
assert stored_metrics[0].response_time_ms == 45.2
assert stored_metrics[0].cache_hit == True
-
+
def test_benchmark_trend_analysis(self, test_environment):
"""TEST: Benchmark trend analysis for performance monitoring"""
# Arrange
from src.monitoring.benchmarks import BenchmarkAnalyzer
-
+
# Create sample benchmark data over time
historical_data = []
for i in range(30): # 30 data points
@@ -445,35 +501,35 @@ def test_benchmark_trend_analysis(self, test_environment):
token_cost=0.001 + (i * 0.0001),
cache_hit=i % 2 == 0,
query_type="test_query",
- timestamp=time.time() - (30 - i) * 86400 # Daily data
+ timestamp=time.time() - (30 - i) * 86400, # Daily data
)
historical_data.append(metrics)
-
+
analyzer = BenchmarkAnalyzer(historical_data)
-
+
# Act
trend_analysis = analyzer.analyze_trends()
-
+
# Assert
assert "response_time_trend" in trend_analysis
assert "cost_trend" in trend_analysis
assert "cache_hit_rate_trend" in trend_analysis
-
+
# Should detect increasing response time trend
assert trend_analysis["response_time_trend"]["direction"] == "increasing"
assert trend_analysis["response_time_trend"]["significance"] > 0.5
-
+
def test_performance_alert_thresholds(self, test_environment, performance_targets):
"""TEST: Performance alert thresholds for continuous monitoring"""
# Arrange
from src.monitoring.alerts import PerformanceAlertManager
-
+
alert_manager = PerformanceAlertManager(
response_time_threshold=performance_targets["overall_response_ms"],
cost_increase_threshold=0.20, # 20% cost increase
- cache_hit_rate_threshold=0.60 # 60% minimum hit rate
+ cache_hit_rate_threshold=0.60, # 60% minimum hit rate
)
-
+
# Test scenarios
scenarios = [
{
@@ -481,78 +537,82 @@ def test_performance_alert_thresholds(self, test_environment, performance_target
"response_time": 80,
"cost_increase": 0.05,
"cache_hit_rate": 0.75,
- "should_alert": False
+ "should_alert": False,
},
{
"name": "slow_response",
"response_time": 150,
"cost_increase": 0.10,
"cache_hit_rate": 0.70,
- "should_alert": True
+ "should_alert": True,
},
{
"name": "high_cost",
"response_time": 90,
"cost_increase": 0.25,
"cache_hit_rate": 0.65,
- "should_alert": True
+ "should_alert": True,
},
{
"name": "low_cache_hit_rate",
"response_time": 85,
"cost_increase": 0.08,
"cache_hit_rate": 0.45,
- "should_alert": True
- }
+ "should_alert": True,
+ },
]
-
+
# Act & Assert
for scenario in scenarios:
alerts = alert_manager.check_thresholds(
response_time_ms=scenario["response_time"],
cost_increase_ratio=scenario["cost_increase"],
- cache_hit_rate=scenario["cache_hit_rate"]
+ cache_hit_rate=scenario["cache_hit_rate"],
)
-
+
has_alerts = len(alerts) > 0
- assert has_alerts == scenario["should_alert"], \
- f"Alert expectation failed for scenario: {scenario['name']}"
-
+ assert (
+ has_alerts == scenario["should_alert"]
+ ), f"Alert expectation failed for scenario: {scenario['name']}"
+
def test_benchmark_report_generation(self, test_environment):
"""TEST: Benchmark report generation for performance tracking"""
# Arrange
from src.monitoring.reports import BenchmarkReportGenerator
-
+
# Sample benchmark results
fact_metrics = [
PerformanceMetrics(45.2, 0.002, True, "query1", time.time()),
PerformanceMetrics(65.1, 0.008, False, "query2", time.time()),
- PerformanceMetrics(38.7, 0.001, True, "query3", time.time())
+ PerformanceMetrics(38.7, 0.001, True, "query3", time.time()),
]
-
+
rag_baseline = [
PerformanceMetrics(420.5, 0.025, False, "query1", time.time()),
PerformanceMetrics(380.2, 0.030, False, "query2", time.time()),
- PerformanceMetrics(450.8, 0.028, False, "query3", time.time())
+ PerformanceMetrics(450.8, 0.028, False, "query3", time.time()),
]
-
+
generator = BenchmarkReportGenerator()
-
+
# Act
report = generator.generate_comparison_report(
- fact_results=fact_metrics,
- rag_baseline=rag_baseline
+ fact_results=fact_metrics, rag_baseline=rag_baseline
)
-
+
# Assert
assert "performance_summary" in report
assert "cost_analysis" in report
assert "improvement_metrics" in report
-
+
# Verify improvement calculations
- assert report["improvement_metrics"]["latency_improvement"] > 5.0 # At least 5x improvement
- assert report["improvement_metrics"]["cost_reduction"] > 0.80 # At least 80% cost reduction
-
+ assert (
+ report["improvement_metrics"]["latency_improvement"] > 5.0
+ ) # At least 5x improvement
+ assert (
+ report["improvement_metrics"]["cost_reduction"] > 0.80
+ ) # At least 80% cost reduction
+
# Verify JSON serializable
json_report = json.dumps(report, default=str)
assert len(json_report) > 0
@@ -561,52 +621,60 @@ def test_benchmark_report_generation(self, test_environment):
@pytest.mark.slow
class TestLongRunningBenchmarks:
"""Long-running benchmark tests for stability and performance over time."""
-
+
@pytest.mark.performance
- async def test_sustained_performance_over_time(self, test_environment, benchmark_queries):
+ async def test_sustained_performance_over_time(
+ self, test_environment, benchmark_queries
+ ):
"""TEST: System maintains performance over sustained operation"""
# Arrange
from src.core.driver import process_user_query
+
duration_minutes = 5 # 5 minute sustained test
measurement_interval = 10 # Measure every 10 seconds
-
+
measurements = []
start_time = time.time()
end_time = start_time + (duration_minutes * 60)
-
+
# Act
query_count = 0
while time.time() < end_time:
query = benchmark_queries[query_count % len(benchmark_queries)]
-
+
measurement_start = time.perf_counter()
response = await process_user_query(query)
measurement_end = time.perf_counter()
-
+
latency_ms = (measurement_end - measurement_start) * 1000
- measurements.append({
- "timestamp": time.time(),
- "latency_ms": latency_ms,
- "query_count": query_count
- })
-
+ measurements.append(
+ {
+ "timestamp": time.time(),
+ "latency_ms": latency_ms,
+ "query_count": query_count,
+ }
+ )
+
query_count += 1
assert response is not None
-
+
# Wait for next measurement interval
await asyncio.sleep(measurement_interval)
-
+
# Assert
latencies = [m["latency_ms"] for m in measurements]
-
+
# Performance should remain stable
early_avg = statistics.mean(latencies[:3]) # First 3 measurements
late_avg = statistics.mean(latencies[-3:]) # Last 3 measurements
-
+
performance_degradation = (late_avg - early_avg) / early_avg
- assert performance_degradation < 0.20, \
- f"Performance degraded {performance_degradation:.1%} over time"
-
+ assert (
+ performance_degradation < 0.20
+ ), f"Performance degraded {performance_degradation:.1%} over time"
+
# Overall performance should meet targets
overall_avg = statistics.mean(latencies)
- assert overall_avg < 100, f"Sustained performance {overall_avg:.2f}ms exceeds target"
\ No newline at end of file
+ assert (
+ overall_avg < 100
+ ), f"Sustained performance {overall_avg:.2f}ms exceeds target"
diff --git a/tests/test_all_fixes.py b/tests/test_all_fixes.py
index eb0ea92..019679a 100644
--- a/tests/test_all_fixes.py
+++ b/tests/test_all_fixes.py
@@ -18,11 +18,16 @@
from unittest.mock import Mock, AsyncMock, MagicMock
# Add the project root to Python path
-sys.path.insert(0, os.path.abspath('..'))
+sys.path.insert(0, os.path.abspath(".."))
from src.core.driver import FACTDriver
from src.core.config import Config, get_config
-from src.tools.connectors.sql import SQLQueryTool, initialize_sql_tool, sql_query_readonly, sql_get_schema
+from src.tools.connectors.sql import (
+ SQLQueryTool,
+ initialize_sql_tool,
+ sql_query_readonly,
+ sql_get_schema,
+)
from src.db.connection import DatabaseManager
from src.tools.decorators import get_tool_registry
from src.core.errors import InvalidSQLError, SecurityError, DatabaseError
@@ -30,38 +35,48 @@
class TestRunner:
"""Comprehensive test runner for FACT system validation."""
-
+
def __init__(self):
self.passed_tests = 0
self.failed_tests = 0
self.test_results = []
-
+
async def run_test(self, test_name: str, test_func):
"""Run a single test and track results."""
print(f"\n{'='*60}")
print(f"Running: {test_name}")
print(f"{'='*60}")
-
+
try:
if asyncio.iscoroutinefunction(test_func):
result = await test_func()
else:
result = test_func()
-
+
if result:
print(f"ā
PASSED: {test_name}")
self.passed_tests += 1
- self.test_results.append({"test": test_name, "status": "PASSED", "error": None})
+ self.test_results.append(
+ {"test": test_name, "status": "PASSED", "error": None}
+ )
else:
print(f"ā FAILED: {test_name} - Test returned False")
self.failed_tests += 1
- self.test_results.append({"test": test_name, "status": "FAILED", "error": "Test returned False"})
-
+ self.test_results.append(
+ {
+ "test": test_name,
+ "status": "FAILED",
+ "error": "Test returned False",
+ }
+ )
+
except Exception as e:
print(f"ā FAILED: {test_name} - {type(e).__name__}: {e}")
self.failed_tests += 1
- self.test_results.append({"test": test_name, "status": "FAILED", "error": str(e)})
-
+ self.test_results.append(
+ {"test": test_name, "status": "FAILED", "error": str(e)}
+ )
+
def print_summary(self):
"""Print test summary."""
total_tests = self.passed_tests + self.failed_tests
@@ -71,8 +86,10 @@ def print_summary(self):
print(f"Total Tests: {total_tests}")
print(f"Passed: {self.passed_tests}")
print(f"Failed: {self.failed_tests}")
- print(f"Success Rate: {(self.passed_tests/total_tests*100) if total_tests > 0 else 0:.1f}%")
-
+ print(
+ f"Success Rate: {(self.passed_tests/total_tests*100) if total_tests > 0 else 0:.1f}%"
+ )
+
if self.failed_tests > 0:
print(f"\nFailed Tests:")
for result in self.test_results:
@@ -82,14 +99,14 @@ def print_summary(self):
async def test_sql_statement_validation():
"""Test 1: SQL statement handling with None checks, empty strings, and non-string inputs."""
-
+
# Initialize database and SQL tool
db_manager = DatabaseManager("db/test_validation.db")
await db_manager.initialize_database()
initialize_sql_tool(db_manager)
-
+
print("Testing SQL statement validation fixes...")
-
+
# Test 1.1: None statement - Test both decorator validation and internal validation
print("\n1.1 Testing None statement...")
try:
@@ -105,7 +122,7 @@ async def test_sql_statement_validation():
else:
print(f"ā None statement test failed: {e}")
return False
-
+
# Test 1.2: Empty string statement
print("\n1.2 Testing empty string statement...")
try:
@@ -119,7 +136,7 @@ async def test_sql_statement_validation():
else:
print(f"ā Empty string test failed: {e}")
return False
-
+
# Test 1.3: Non-string statement
print("\n1.3 Testing non-string statement...")
try:
@@ -133,7 +150,7 @@ async def test_sql_statement_validation():
else:
print(f"ā Non-string test failed: {e}")
return False
-
+
# Test 1.4: Valid query
print("\n1.4 Testing valid SQL query...")
try:
@@ -145,7 +162,7 @@ async def test_sql_statement_validation():
except Exception as e:
print(f"ā Valid query test failed: {e}")
return False
-
+
# Test 1.5: Schema retrieval (tests None/length checks in schema function)
print("\n1.5 Testing schema retrieval with None checks...")
try:
@@ -157,50 +174,50 @@ async def test_sql_statement_validation():
except Exception as e:
print(f"ā Schema retrieval test failed: {e}")
return False
-
+
return True
async def test_tool_execution_async_handling():
"""Test 2: Tool execution with proper async/await handling."""
-
+
print("Testing tool execution and async/await handling...")
-
+
# Test 2.1: Tool registry functionality
print("\n2.1 Testing tool registry...")
registry = get_tool_registry()
tools = registry.list_tools()
-
+
if not tools:
print("ā No tools registered")
return False
-
+
print(f"ā Found {len(tools)} registered tools: {tools}")
-
+
# Test 2.2: Tool schema export
print("\n2.2 Testing tool schema export...")
schemas = registry.export_all_schemas()
-
+
if not schemas:
print("ā No schemas exported")
return False
-
+
print(f"ā Exported {len(schemas)} tool schemas")
-
+
# Verify schema format
for schema in schemas:
if "function" not in schema or "name" not in schema["function"]:
print(f"ā Invalid schema format: {schema}")
return False
-
+
# Test 2.3: Async tool execution
print("\n2.3 Testing async tool execution...")
-
+
# Initialize database for tool testing
db_manager = DatabaseManager("db/test_validation.db")
await db_manager.initialize_database()
initialize_sql_tool(db_manager)
-
+
try:
# Test async tool directly
result = await sql_query_readonly("SELECT 1 as test_value")
@@ -210,36 +227,40 @@ async def test_tool_execution_async_handling():
except Exception as e:
print(f"ā Async tool execution failed: {e}")
return False
-
+
return True
async def test_llm_response_processing():
"""Test 3: LLM response processing with tool calls."""
-
+
print("Testing LLM response processing...")
-
+
# Create a mock LLM response with tool calls
mock_tool_call = Mock()
mock_tool_call.id = "call_123"
mock_tool_call.type = "function"
mock_tool_call.function = Mock()
- mock_tool_call.function.name = "SQL_QueryReadonly" # Note: underscores for API compatibility
- mock_tool_call.function.arguments = '{"statement": "SELECT COUNT(*) FROM companies"}'
-
+ mock_tool_call.function.name = (
+ "SQL_QueryReadonly" # Note: underscores for API compatibility
+ )
+ mock_tool_call.function.arguments = (
+ '{"statement": "SELECT COUNT(*) FROM companies"}'
+ )
+
mock_message = Mock()
mock_message.content = "I'll help you query the database."
mock_message.tool_calls = [mock_tool_call]
-
+
mock_choice = Mock()
mock_choice.message = mock_message
-
+
mock_response = Mock()
mock_response.choices = [mock_choice]
-
+
# Test 3.1: Tool call extraction
print("\n3.1 Testing tool call extraction...")
-
+
# Simulate the driver's tool call processing
try:
tool_calls = mock_response.choices[0].message.tool_calls
@@ -249,39 +270,41 @@ async def test_llm_response_processing():
except Exception as e:
print(f"ā Tool call extraction failed: {e}")
return False
-
+
# Test 3.2: Tool execution simulation
print("\n3.2 Testing tool execution from LLM response...")
-
+
# Initialize database and tools
db_manager = DatabaseManager("db/test_validation.db")
await db_manager.initialize_database()
initialize_sql_tool(db_manager)
-
+
try:
# Simulate what the driver does
tool_call = mock_tool_call
- tool_name = tool_call.function.name.replace('_', '.') # Convert back to registry format
+ tool_name = tool_call.function.name.replace(
+ "_", "."
+ ) # Convert back to registry format
tool_args = json.loads(tool_call.function.arguments)
-
+
# Get tool from registry
registry = get_tool_registry()
tool_definition = registry.get_tool(tool_name)
-
+
# Execute tool (async)
result = await tool_definition.function(**tool_args)
-
+
assert isinstance(result, dict)
assert result["status"] == "success"
print("ā Tool execution from LLM response successful")
-
+
except Exception as e:
print(f"ā Tool execution from LLM response failed: {e}")
return False
-
+
# Test 3.3: Response content handling (test None response handling)
print("\n3.3 Testing None response content handling...")
-
+
# Test the driver's content extraction logic
try:
# Simulate various response scenarios
@@ -290,60 +313,60 @@ async def test_llm_response_processing():
{"content": None, "expected": None},
{"content": "", "expected": ""},
]
-
+
for i, case in enumerate(test_cases):
mock_resp = Mock()
mock_resp.content = case["content"]
-
+
# Simulate driver logic
- if hasattr(mock_resp, 'content'):
+ if hasattr(mock_resp, "content"):
response_text = mock_resp.content
else:
response_text = None
-
+
if case["expected"] is None:
assert response_text is None
else:
assert response_text == case["expected"]
-
+
print("ā Response content handling working correctly")
-
+
except Exception as e:
print(f"ā Response content handling failed: {e}")
return False
-
+
return True
async def test_full_system_functionality():
"""Test 4: Full system test with end-to-end functionality."""
-
+
print("Testing full system functionality...")
-
+
# Test 4.1: Driver initialization
print("\n4.1 Testing driver initialization...")
-
+
try:
# Create test config
config = get_config()
# Can't modify config directly, so create driver with test database path
-
+
# Create driver with default config but test database
driver = FACTDriver(config)
# Override the database path after initialization
driver.config.database_path = "db/test_validation.db"
await driver.initialize()
-
+
assert driver._initialized == True
print("ā Driver initialized successfully")
-
+
except Exception as e:
print(f"ā Driver initialization failed: {e}")
return False
-
+
# Test 4.2: System metrics
print("\n4.2 Testing system metrics...")
-
+
try:
metrics = driver.get_metrics()
assert isinstance(metrics, dict)
@@ -351,140 +374,146 @@ async def test_full_system_functionality():
assert "initialized" in metrics
assert metrics["initialized"] == True
print("ā System metrics working")
-
+
except Exception as e:
print(f"ā System metrics failed: {e}")
return False
-
+
# Test 4.3: Database operations through driver
print("\n4.3 Testing database operations through system...")
-
+
try:
# Test database info
db_info = await driver.database_manager.get_database_info()
assert isinstance(db_info, dict)
assert "tables" in db_info
print("ā Database operations working")
-
+
except Exception as e:
print(f"ā Database operations failed: {e}")
return False
-
+
# Test 4.4: Error handling in driver
print("\n4.4 Testing error handling in driver...")
-
+
try:
# Create a driver with invalid OpenAI key to test error handling
test_config = get_config()
test_config.openai_api_key = "invalid_key_for_testing"
test_config.database_path = "db/test_validation.db"
-
+
error_driver = FACTDriver(test_config)
await error_driver.initialize() # Should not crash, just log warnings
-
+
print("ā Error handling in driver working (allows system to continue)")
-
+
except Exception as e:
# This is expected - driver should handle errors gracefully
print(f"ā Driver error handling working: {type(e).__name__}")
-
+
# Clean up
await driver.shutdown()
-
+
return True
async def test_demo_queries():
"""Test 5: Demo queries to confirm end-to-end functionality."""
-
+
print("Testing demo queries for end-to-end validation...")
-
+
# Initialize system
db_manager = DatabaseManager("db/test_validation.db")
await db_manager.initialize_database()
initialize_sql_tool(db_manager)
-
+
# Test queries that should work
demo_queries = [
{
"name": "Company count",
"query": "SELECT COUNT(*) as total_companies FROM companies",
- "should_work": True
+ "should_work": True,
},
{
"name": "Technology companies",
"query": "SELECT name, symbol FROM companies WHERE sector = 'Technology'",
- "should_work": True
+ "should_work": True,
},
{
"name": "Latest financial data",
"query": "SELECT c.name, f.revenue FROM companies c JOIN financial_records f ON c.id = f.company_id WHERE f.year = 2025 AND f.quarter = 'Q1' LIMIT 5",
- "should_work": True
+ "should_work": True,
},
{
"name": "Schema information",
"query": "PRAGMA table_info(companies)",
- "should_work": True
+ "should_work": True,
},
{
"name": "Invalid dangerous query",
"query": "DROP TABLE companies",
- "should_work": False
+ "should_work": False,
},
{
"name": "SQL injection attempt",
"query": "SELECT * FROM companies WHERE id = 1; DROP TABLE companies; --",
- "should_work": False
- }
+ "should_work": False,
+ },
]
-
+
success_count = 0
-
+
for i, test_case in enumerate(demo_queries, 1):
print(f"\n5.{i} Testing: {test_case['name']}")
print(f"Query: {test_case['query']}")
-
+
try:
- result = await sql_query_readonly(test_case['query'])
-
- if test_case['should_work']:
+ result = await sql_query_readonly(test_case["query"])
+
+ if test_case["should_work"]:
if result["status"] == "success":
print(f"ā Query succeeded as expected")
success_count += 1
else:
- print(f"ā Query should have succeeded but failed: {result.get('error', 'Unknown error')}")
+ print(
+ f"ā Query should have succeeded but failed: {result.get('error', 'Unknown error')}"
+ )
else:
if result["status"] == "failed":
- print(f"ā Query correctly rejected: {result.get('error', 'Security violation')}")
+ print(
+ f"ā Query correctly rejected: {result.get('error', 'Security violation')}"
+ )
success_count += 1
else:
print(f"ā Dangerous query should have been rejected but succeeded")
-
+
except Exception as e:
- if test_case['should_work']:
+ if test_case["should_work"]:
print(f"ā Query failed unexpectedly: {e}")
else:
print(f"ā Query correctly failed with exception: {e}")
success_count += 1
-
+
print(f"\nDemo queries: {success_count}/{len(demo_queries)} passed")
return success_count == len(demo_queries)
async def test_edge_cases():
"""Test 6: Edge cases and boundary conditions."""
-
+
print("Testing edge cases and boundary conditions...")
-
+
# Initialize system
db_manager = DatabaseManager("db/test_validation.db")
await db_manager.initialize_database()
initialize_sql_tool(db_manager)
-
+
# Test 6.1: Very long query (should be rejected)
print("\n6.1 Testing very long query...")
try:
- long_query = "SELECT * FROM companies WHERE " + " OR ".join([f"id = {i}" for i in range(1000)])
+ long_query = "SELECT * FROM companies WHERE " + " OR ".join(
+ [f"id = {i}" for i in range(1000)]
+ )
result = await sql_query_readonly(long_query)
assert result["status"] == "failed"
assert "too long" in result.get("error", "").lower()
@@ -495,23 +524,28 @@ async def test_edge_cases():
else:
print(f"ā Long query test failed: {e}")
return False
-
+
# Test 6.2: Complex nested query (should be rejected)
print("\n6.2 Testing complex nested query...")
try:
nested_query = "SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM companies)))))"
result = await sql_query_readonly(nested_query)
assert result["status"] == "failed"
- assert "nested" in result.get("error", "").lower() or "subqueries" in result.get("error", "").lower()
+ assert (
+ "nested" in result.get("error", "").lower()
+ or "subqueries" in result.get("error", "").lower()
+ )
print("ā Complex nested query rejected correctly")
except Exception as e:
print(f"ā Nested query test failed: {e}")
return False
-
+
# Test 6.3: Query with special characters
print("\n6.3 Testing query with special characters...")
try:
- special_query = "SELECT name FROM companies WHERE name LIKE '%&%' OR name LIKE '%@%'"
+ special_query = (
+ "SELECT name FROM companies WHERE name LIKE '%&%' OR name LIKE '%@%'"
+ )
result = await sql_query_readonly(special_query)
# This should work - special characters in data are OK
assert result["status"] == "success"
@@ -519,24 +553,26 @@ async def test_edge_cases():
except Exception as e:
print(f"ā Special characters test failed: {e}")
return False
-
+
# Test 6.4: Unicode query
print("\n6.4 Testing unicode query...")
try:
- unicode_query = "SELECT name FROM companies WHERE name LIKE '%Ć©%' OR name LIKE '%äø%'"
+ unicode_query = (
+ "SELECT name FROM companies WHERE name LIKE '%Ć©%' OR name LIKE '%äø%'"
+ )
result = await sql_query_readonly(unicode_query)
assert result["status"] == "success"
print("ā Unicode query handled correctly")
except Exception as e:
print(f"ā Unicode query test failed: {e}")
return False
-
+
return True
async def main():
"""Main test execution function."""
-
+
print("š§Ŗ FACT System Comprehensive Validation Test")
print("=" * 60)
print("Testing all major fixes and functionality:")
@@ -546,24 +582,26 @@ async def main():
print("4. Full system end-to-end functionality")
print("5. Demo queries validation")
print("6. Edge cases and boundary conditions")
-
+
runner = TestRunner()
-
+
# Run all test suites
await runner.run_test("SQL Statement Validation", test_sql_statement_validation)
- await runner.run_test("Tool Execution & Async Handling", test_tool_execution_async_handling)
+ await runner.run_test(
+ "Tool Execution & Async Handling", test_tool_execution_async_handling
+ )
await runner.run_test("LLM Response Processing", test_llm_response_processing)
await runner.run_test("Full System Functionality", test_full_system_functionality)
await runner.run_test("Demo Queries Validation", test_demo_queries)
await runner.run_test("Edge Cases & Boundary Conditions", test_edge_cases)
-
+
# Print final summary
runner.print_summary()
-
+
# Return appropriate exit code
return 0 if runner.failed_tests == 0 else 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
- sys.exit(exit_code)
\ No newline at end of file
+ sys.exit(exit_code)
diff --git a/tests/test_basic_functionality.py b/tests/test_basic_functionality.py
index ef17fbd..c30c47f 100644
--- a/tests/test_basic_functionality.py
+++ b/tests/test_basic_functionality.py
@@ -12,7 +12,8 @@
# Add src to path for testing
import sys
-sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
+
+sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from src.core.config import Config
from src.core.driver import FACTDriver
@@ -24,15 +25,15 @@
@pytest.fixture
async def temp_database():
"""Create a temporary database for testing."""
- with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp:
+ with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
db_path = tmp.name
-
+
# Initialize database
db_manager = DatabaseManager(db_path)
await db_manager.initialize_database()
-
+
yield db_manager
-
+
# Cleanup
os.unlink(db_path)
@@ -42,40 +43,42 @@ def test_config(temp_database):
"""Create test configuration with temporary database."""
config = Config()
config._config = {
- 'database_path': temp_database.database_path,
- 'anthropic_api_key': 'test_key',
- 'arcade_api_key': 'test_key',
- 'cache_prefix': 'test_v1',
- 'claude_model': 'claude-3-5-sonnet-20241022',
- 'system_prompt': 'Test prompt'
+ "database_path": temp_database.database_path,
+ "anthropic_api_key": "test_key",
+ "arcade_api_key": "test_key",
+ "cache_prefix": "test_v1",
+ "claude_model": "claude-3-5-sonnet-20241022",
+ "system_prompt": "Test prompt",
}
return config
class TestDatabaseIntegration:
"""Test database integration and SQL tools."""
-
+
@pytest.mark.asyncio
async def test_database_initialization(self, temp_database):
"""Test that database initializes with schema and sample data."""
db_info = await temp_database.get_database_info()
-
- assert db_info['total_tables'] >= 2
- assert 'companies' in db_info['tables']
- assert 'financial_records' in db_info['tables']
- assert db_info['tables']['companies']['row_count'] > 0
- assert db_info['tables']['financial_records']['row_count'] > 0
-
+
+ assert db_info["total_tables"] >= 2
+ assert "companies" in db_info["tables"]
+ assert "financial_records" in db_info["tables"]
+ assert db_info["tables"]["companies"]["row_count"] > 0
+ assert db_info["tables"]["financial_records"]["row_count"] > 0
+
@pytest.mark.asyncio
async def test_sql_query_execution(self, temp_database):
"""Test SQL query execution with security validation."""
# Test valid SELECT query
- result = await temp_database.execute_query("SELECT COUNT(*) as count FROM companies")
-
+ result = await temp_database.execute_query(
+ "SELECT COUNT(*) as count FROM companies"
+ )
+
assert result.row_count == 1
- assert 'count' in result.columns
- assert result.rows[0]['count'] > 0
-
+ assert "count" in result.columns
+ assert result.rows[0]["count"] > 0
+
@pytest.mark.asyncio
async def test_sql_security_validation(self, temp_database):
"""Test that dangerous SQL operations are blocked."""
@@ -83,9 +86,9 @@ async def test_sql_security_validation(self, temp_database):
"DROP TABLE companies",
"DELETE FROM companies",
"UPDATE companies SET name = 'hacked'",
- "INSERT INTO companies VALUES (999, 'hack')"
+ "INSERT INTO companies VALUES (999, 'hack')",
]
-
+
for query in dangerous_queries:
with pytest.raises(Exception): # Should raise SecurityError
await temp_database.execute_query(query)
@@ -93,83 +96,83 @@ async def test_sql_security_validation(self, temp_database):
class TestToolFramework:
"""Test tool registration and execution framework."""
-
+
def test_tool_registry_functionality(self):
"""Test tool registry operations."""
registry = get_tool_registry()
-
+
# Check that SQL tools are registered
tool_names = registry.list_tools()
assert len(tool_names) > 0
-
+
# Check schema export
schemas = registry.export_all_schemas()
assert len(schemas) > 0
- assert all('function' in schema for schema in schemas)
-
+ assert all("function" in schema for schema in schemas)
+
@pytest.mark.asyncio
async def test_sql_tool_integration(self, temp_database):
"""Test SQL tool initialization and execution."""
# Initialize SQL tool
initialize_sql_tool(temp_database)
-
+
# Test tool execution through registry
registry = get_tool_registry()
sql_tool = registry.get_tool("SQL.QueryReadonly")
-
+
# Execute a test query
result = await sql_tool.function(statement="SELECT name FROM companies LIMIT 1")
-
- assert result['status'] == 'success'
- assert 'rows' in result
- assert len(result['rows']) > 0
+
+ assert result["status"] == "success"
+ assert "rows" in result
+ assert len(result["rows"]) > 0
class TestSystemIntegration:
"""Test complete system integration."""
-
+
@pytest.mark.asyncio
async def test_driver_initialization(self, test_config, temp_database):
"""Test FACT driver initialization."""
# Note: This test would need mocking for actual API calls
driver = FACTDriver(test_config)
-
+
# Initialize database component
driver.database_manager = temp_database
await driver._initialize_tools()
-
+
# Check that tools are registered
assert len(driver.tool_registry.list_tools()) > 0
-
+
# Check metrics
metrics = driver.get_metrics()
- assert 'total_queries' in metrics
- assert 'cache_hits' in metrics
+ assert "total_queries" in metrics
+ assert "cache_hits" in metrics
class TestConfigurationManagement:
"""Test configuration loading and validation."""
-
+
def test_config_creation(self):
"""Test configuration object creation."""
config = Config()
-
+
# Test property access
- assert hasattr(config, 'anthropic_api_key')
- assert hasattr(config, 'arcade_api_key')
- assert hasattr(config, 'database_path')
- assert hasattr(config, 'cache_prefix')
-
+ assert hasattr(config, "anthropic_api_key")
+ assert hasattr(config, "arcade_api_key")
+ assert hasattr(config, "database_path")
+ assert hasattr(config, "cache_prefix")
+
def test_config_dictionary_export(self):
"""Test configuration export functionality."""
config = Config()
config_dict = config.to_dict()
-
+
assert isinstance(config_dict, dict)
- assert 'database_path' in config_dict
- assert 'cache_prefix' in config_dict
+ assert "database_path" in config_dict
+ assert "cache_prefix" in config_dict
# Sensitive keys should be masked
- assert config_dict.get('anthropic_api_key') in [None, '***']
+ assert config_dict.get("anthropic_api_key") in [None, "***"]
@pytest.mark.asyncio
@@ -177,30 +180,30 @@ async def test_end_to_end_workflow(temp_database):
"""Test a complete end-to-end workflow."""
# Initialize SQL tool
initialize_sql_tool(temp_database)
-
+
# Get tool registry
registry = get_tool_registry()
-
+
# Test that we can execute a sample query through the tool
sql_tool = registry.get_tool("SQL.QueryReadonly")
-
+
# Execute query
result = await sql_tool.function(
statement="SELECT name, sector FROM companies WHERE sector = 'Technology'"
)
-
+
# Verify result structure
- assert result['status'] == 'success'
- assert 'rows' in result
- assert 'columns' in result
- assert 'execution_time_ms' in result
-
+ assert result["status"] == "success"
+ assert "rows" in result
+ assert "columns" in result
+ assert "execution_time_ms" in result
+
# Verify data content
- if result['rows']:
- assert 'name' in result['rows'][0]
- assert 'sector' in result['rows'][0]
+ if result["rows"]:
+ assert "name" in result["rows"][0]
+ assert "sector" in result["rows"][0]
if __name__ == "__main__":
# Run tests directly
- pytest.main([__file__, "-v"])
\ No newline at end of file
+ pytest.main([__file__, "-v"])
diff --git a/tests/test_fixes_summary.py b/tests/test_fixes_summary.py
index 1b343c5..8e94425 100644
--- a/tests/test_fixes_summary.py
+++ b/tests/test_fixes_summary.py
@@ -14,30 +14,35 @@
import os
# Add the project root to Python path
-sys.path.insert(0, os.path.abspath('..'))
+sys.path.insert(0, os.path.abspath(".."))
-from src.tools.connectors.sql import SQLQueryTool, initialize_sql_tool, sql_query_readonly, sql_get_schema
+from src.tools.connectors.sql import (
+ SQLQueryTool,
+ initialize_sql_tool,
+ sql_query_readonly,
+ sql_get_schema,
+)
from src.db.connection import DatabaseManager
from src.tools.decorators import get_tool_registry
async def main():
"""Main test function demonstrating key fixes."""
-
+
print("š§Ŗ FACT System Key Fixes Validation")
print("=" * 50)
-
+
# Initialize database and tools
print("\nš Initializing system...")
db_manager = DatabaseManager("db/test_validation.db")
await db_manager.initialize_database()
initialize_sql_tool(db_manager)
print("ā
System initialized")
-
+
# Test 1: SQL Input Validation Fixes
print("\nš Test 1: SQL Input Validation Fixes")
print("-" * 40)
-
+
# Test None input
print("Testing None input:")
try:
@@ -48,7 +53,7 @@ async def main():
print("ā
None input correctly rejected")
else:
print(f"ā Unexpected error: {e}")
-
+
# Test empty string
print("Testing empty string:")
try:
@@ -59,7 +64,7 @@ async def main():
print("ā
Empty string correctly rejected")
else:
print(f"ā Unexpected error: {e}")
-
+
# Test non-string
print("Testing non-string input:")
try:
@@ -70,11 +75,11 @@ async def main():
print("ā
Non-string input correctly rejected")
else:
print(f"ā Unexpected error: {e}")
-
+
# Test 2: Async Tool Execution
print("\nā” Test 2: Async Tool Execution")
print("-" * 40)
-
+
# Test simple query that should work
print("Testing async tool execution:")
try:
@@ -85,79 +90,95 @@ async def main():
print(f"ā Tool execution failed: {result.get('error', 'Unknown error')}")
except Exception as e:
print(f"ā Tool execution exception: {e}")
-
+
# Test 3: Tool Registry and Schema Export
print("\nš Test 3: Tool Registry and Schema Export")
print("-" * 40)
-
+
registry = get_tool_registry()
tools = registry.list_tools()
print(f"Registered tools: {len(tools)}")
-
+
schemas = registry.export_all_schemas()
print(f"Exported schemas: {len(schemas)}")
-
+
if len(tools) > 0 and len(schemas) > 0:
print("ā
Tool registry working correctly")
else:
print("ā Tool registry issues")
-
+
# Test 4: Security Validation
print("\nš Test 4: Security Validation")
print("-" * 40)
-
+
# Test dangerous query rejection
print("Testing dangerous query rejection:")
try:
result = await sql_query_readonly("DROP TABLE companies")
# The tool returns a dict with status=failed for security violations
- if isinstance(result, dict) and result.get("status") == "failed" and "error" in result:
+ if (
+ isinstance(result, dict)
+ and result.get("status") == "failed"
+ and "error" in result
+ ):
print("ā
Dangerous query correctly rejected")
else:
print("ā Dangerous query should have been rejected")
except Exception as e:
print(f"ā
Dangerous query rejected with exception: {type(e).__name__}")
-
+
# Test injection attempt
print("Testing SQL injection attempt:")
try:
- result = await sql_query_readonly("SELECT * FROM users WHERE id = 1; DROP TABLE users; --")
+ result = await sql_query_readonly(
+ "SELECT * FROM users WHERE id = 1; DROP TABLE users; --"
+ )
# The tool returns a dict with status=failed for security violations
- if isinstance(result, dict) and result.get("status") == "failed" and "error" in result:
+ if (
+ isinstance(result, dict)
+ and result.get("status") == "failed"
+ and "error" in result
+ ):
print("ā
SQL injection attempt correctly rejected")
else:
print("ā SQL injection should have been rejected")
except Exception as e:
print(f"ā
SQL injection rejected with exception: {type(e).__name__}")
-
+
# Test 5: Schema Operations (tests None handling in schema functions)
print("\nš Test 5: Schema Operations")
print("-" * 40)
-
+
print("Testing schema retrieval:")
try:
schema_result = await sql_get_schema()
if schema_result.get("status") == "success":
- print(f"ā
Schema retrieval working - found {schema_result.get('total_tables', 0)} tables")
+ print(
+ f"ā
Schema retrieval working - found {schema_result.get('total_tables', 0)} tables"
+ )
else:
- print(f"ā Schema retrieval failed: {schema_result.get('error', 'Unknown error')}")
+ print(
+ f"ā Schema retrieval failed: {schema_result.get('error', 'Unknown error')}"
+ )
except Exception as e:
print(f"ā Schema retrieval exception: {e}")
-
+
# Test 6: Database Operations
print("\nš¾ Test 6: Database Operations")
print("-" * 40)
-
+
print("Testing database info:")
try:
db_info = await db_manager.get_database_info()
if isinstance(db_info, dict) and "tables" in db_info:
- print(f"ā
Database operations working - {db_info.get('total_tables', 0)} tables")
+ print(
+ f"ā
Database operations working - {db_info.get('total_tables', 0)} tables"
+ )
else:
print("ā Database info retrieval failed")
except Exception as e:
print(f"ā Database info exception: {e}")
-
+
# Summary
print("\n" + "=" * 50)
print("šÆ SUMMARY")
@@ -174,4 +195,4 @@ async def main():
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/tests/test_imports.py b/tests/test_imports.py
index 2c19a7c..3bf69c0 100644
--- a/tests/test_imports.py
+++ b/tests/test_imports.py
@@ -13,237 +13,242 @@
from pathlib import Path
# Add src to path for testing
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
+
class TestImportResolution(unittest.TestCase):
"""Test that all imports work correctly across the FACT system."""
-
+
def setUp(self):
"""Set up test environment."""
- self.src_path = Path(__file__).parent.parent / 'src'
+ self.src_path = Path(__file__).parent.parent / "src"
self.assertTrue(self.src_path.exists(), "Source directory must exist")
-
+
def test_core_module_imports(self):
"""Test that core module imports work correctly."""
try:
# Test core config import
from core.config import Config, get_config
-
+
# Test core driver import
from core.driver import get_driver, FACTDriver
-
+
# Test core errors import
from core.errors import FACTError, ConfigurationError
-
+
print("ā
Core module imports successful")
except ImportError as e:
self.fail(f"Core module import failed: {e}")
-
+
def test_db_module_imports(self):
"""Test that database module imports work correctly."""
try:
# Test db connection import
from db.connection import DatabaseManager
-
- # Test db models import
+
+ # Test db models import
from db.models import DATABASE_SCHEMA, QueryResult
-
+
print("ā
Database module imports successful")
except ImportError as e:
self.fail(f"Database module import failed: {e}")
-
+
def test_tools_module_imports(self):
"""Test that tools module imports work correctly."""
try:
# Test tools decorators import
from tools.decorators import Tool, get_tool_registry
-
+
# Test tools executor import
from tools.executor import ToolExecutor, ToolCall
-
+
# Test tools validation import
from tools.validation import ParameterValidator
-
+
print("ā
Tools module imports successful")
except ImportError as e:
self.fail(f"Tools module import failed: {e}")
-
+
def test_security_module_imports(self):
"""Test that security module imports work correctly."""
try:
# Test security auth import
from security.auth import AuthorizationManager
-
+
# Test security config import
from security.config import SecurityConfig
-
+
print("ā
Security module imports successful")
except ImportError as e:
self.fail(f"Security module import failed: {e}")
-
+
def test_arcade_module_imports(self):
"""Test that arcade module imports work correctly."""
try:
# Test arcade client import
from arcade.client import ArcadeClient
-
+
# Test arcade gateway import
from arcade.gateway import ArcadeGateway
-
+
# Test arcade errors import
from arcade.errors import ArcadeError
-
+
print("ā
Arcade module imports successful")
except ImportError as e:
self.fail(f"Arcade module import failed: {e}")
-
+
def test_cache_module_imports(self):
"""Test that cache module imports work correctly."""
try:
# Test cache manager import
from cache.manager import CacheManager
-
+
# Test cache strategy import
from cache.strategy import CacheStrategy
-
+
print("ā
Cache module imports successful")
except ImportError as e:
self.fail(f"Cache module import failed: {e}")
-
+
def test_monitoring_module_imports(self):
"""Test that monitoring module imports work correctly."""
try:
# Test monitoring metrics import
from monitoring.metrics import MetricsCollector
-
+
print("ā
Monitoring module imports successful")
except ImportError as e:
self.fail(f"Monitoring module import failed: {e}")
-
+
def test_tools_connectors_imports(self):
"""Test that tools connector imports work correctly."""
try:
# Test SQL connector import
from tools.connectors.sql import initialize_sql_tool
-
+
# Test HTTP connector import
from tools.connectors.http import http_get
-
+
# Test file connector import
from tools.connectors.file import read_file
-
+
print("ā
Tools connector imports successful")
except ImportError as e:
self.fail(f"Tools connector import failed: {e}")
-
+
def test_benchmarking_module_imports(self):
"""Test that benchmarking module imports work correctly."""
try:
# Test benchmarking framework import
from benchmarking.framework import BenchmarkFramework
-
- # Test benchmarking profiler import
+
+ # Test benchmarking profiler import
from benchmarking.profiler import SystemProfiler
-
+
print("ā
Benchmarking module imports successful")
except ImportError as e:
self.fail(f"Benchmarking module import failed: {e}")
-
+
def test_no_relative_imports_beyond_package(self):
"""Test that no files contain problematic relative imports."""
problematic_patterns = [
- 'from ..', # Relative imports that might go beyond package
+ "from ..", # Relative imports that might go beyond package
]
-
+
problematic_files = []
-
- for py_file in self.src_path.rglob('*.py'):
+
+ for py_file in self.src_path.rglob("*.py"):
try:
- with open(py_file, 'r', encoding='utf-8') as f:
+ with open(py_file, "r", encoding="utf-8") as f:
content = f.read()
-
+
for pattern in problematic_patterns:
if pattern in content:
# Check if this is actually problematic by context
lines = content.splitlines()
for i, line in enumerate(lines):
- if pattern in line and 'import' in line:
+ if pattern in line and "import" in line:
# This could be problematic
- problematic_files.append(f"{py_file}:{i+1}: {line.strip()}")
+ problematic_files.append(
+ f"{py_file}:{i+1}: {line.strip()}"
+ )
except Exception as e:
self.fail(f"Error reading {py_file}: {e}")
-
+
if problematic_files:
print("ā ļø Found potentially problematic imports:")
for item in problematic_files:
print(f" {item}")
else:
print("ā
No problematic relative imports found")
-
+
def test_script_execution_context(self):
"""Test that the system can be imported from script context."""
# Simulate the context where scripts/init_environment.py imports
original_path = sys.path.copy()
-
+
try:
# Add the specific path that init_environment.py uses
- script_src_path = os.path.join(os.path.dirname(__file__), '..', 'src')
+ script_src_path = os.path.join(os.path.dirname(__file__), "..", "src")
if script_src_path not in sys.path:
sys.path.insert(0, script_src_path)
-
+
# Try to import the same modules that init_environment.py imports
from core.config import Config
from core.driver import get_driver
from db.connection import DatabaseManager
-
+
print("ā
Script execution context imports successful")
-
+
except ImportError as e:
self.fail(f"Script context import failed: {e}")
finally:
sys.path = original_path
+
class TestMainEntryPoints(unittest.TestCase):
"""Test main entry points work correctly."""
-
+
def test_main_py_init_imports(self):
"""Test that main.py can import required modules for init command."""
try:
# These are the imports used in main.py validate command
from core.config import get_config
from core.driver import get_driver
-
+
print("ā
Main.py entry point imports successful")
except ImportError as e:
self.fail(f"Main.py import failed: {e}")
-
+
def test_cli_module_imports(self):
"""Test that CLI module can be imported correctly."""
try:
# Test CLI module import
from core.cli import main
-
- print("ā
CLI module imports successful")
+
+ print("ā
CLI module imports successful")
except ImportError as e:
self.fail(f"CLI module import failed: {e}")
+
def run_import_tests():
"""Run all import tests and provide summary."""
print("š§Ŗ Running FACT System Import Tests")
print("=" * 50)
-
+
# Create test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()
-
+
# Add test classes
suite.addTests(loader.loadTestsFromTestCase(TestImportResolution))
suite.addTests(loader.loadTestsFromTestCase(TestMainEntryPoints))
-
+
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
-
+
# Print summary
print("\n" + "=" * 50)
if result.wasSuccessful():
@@ -253,19 +258,20 @@ def run_import_tests():
print("ā Some import tests failed!")
print(f"Failures: {len(result.failures)}")
print(f"Errors: {len(result.errors)}")
-
+
if result.failures:
print("\nFailures:")
for test, traceback in result.failures:
print(f" {test}: {traceback}")
-
+
if result.errors:
print("\nErrors:")
for test, traceback in result.errors:
print(f" {test}: {traceback}")
-
+
return result.wasSuccessful()
-if __name__ == '__main__':
+
+if __name__ == "__main__":
success = run_import_tests()
- sys.exit(0 if success else 1)
\ No newline at end of file
+ sys.exit(0 if success else 1)
diff --git a/tests/test_nonetype_bug.py b/tests/test_nonetype_bug.py
index c540621..626e312 100644
--- a/tests/test_nonetype_bug.py
+++ b/tests/test_nonetype_bug.py
@@ -8,7 +8,7 @@
import os
# Add the project root to Python path
-sys.path.insert(0, os.path.abspath('..'))
+sys.path.insert(0, os.path.abspath(".."))
from src.tools.connectors.sql import SQLQueryTool, initialize_sql_tool, get_sql_tool
from src.db.connection import DatabaseManager
@@ -16,17 +16,17 @@
async def test_nonetype_scenarios():
"""Test various scenarios that could trigger the NoneType len() error"""
-
+
# Initialize database manager and database
db_manager = DatabaseManager("db/test_fact.db")
await db_manager.initialize_database()
-
+
# Initialize SQL tool
initialize_sql_tool(db_manager)
sql_tool = get_sql_tool()
-
+
print("=== Testing NoneType scenarios ===")
-
+
# Test 1: None statement (this should trigger the AttributeError)
print("\n1. Testing None statement...")
try:
@@ -34,7 +34,7 @@ async def test_nonetype_scenarios():
print(f"Result: {result}")
except Exception as e:
print(f"Error: {type(e).__name__}: {e}")
-
+
# Test 1b: Test len() error by simulating a None statement in error response
print("\n1b. Testing len() error scenario...")
try:
@@ -43,7 +43,7 @@ async def test_nonetype_scenarios():
print(f"Result: {result}")
except Exception as e:
print(f"Error: {type(e).__name__}: {e}")
-
+
# Test 2: Empty statement
print("\n2. Testing empty statement...")
try:
@@ -51,7 +51,7 @@ async def test_nonetype_scenarios():
print(f"Result: {result}")
except Exception as e:
print(f"Error: {type(e).__name__}: {e}")
-
+
# Test 3: Query that returns no results (valid syntax)
print("\n3. Testing query with no results...")
try:
@@ -59,7 +59,7 @@ async def test_nonetype_scenarios():
print(f"Result: {result}")
except Exception as e:
print(f"Error: {type(e).__name__}: {e}")
-
+
# Test 4: Normal valid query (for comparison)
print("\n4. Testing valid query...")
try:
@@ -70,4 +70,4 @@ async def test_nonetype_scenarios():
if __name__ == "__main__":
- asyncio.run(test_nonetype_scenarios())
\ No newline at end of file
+ asyncio.run(test_nonetype_scenarios())
diff --git a/tests/test_query_error.py b/tests/test_query_error.py
index 77181a9..c60e0ec 100644
--- a/tests/test_query_error.py
+++ b/tests/test_query_error.py
@@ -8,25 +8,26 @@
import asyncio
# Add src to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
from src.core.driver import get_driver
from src.core.config import get_config
+
async def test_query_processing():
"""Test query processing to reproduce the TypeError."""
try:
print("š§ Initializing FACT system...")
driver = await get_driver()
print("ā
System initialized")
-
+
# Test queries that might trigger the TypeError
test_queries = [
"Show me all companies",
"What is the database schema?",
- "Get sample queries"
+ "Get sample queries",
]
-
+
for i, query in enumerate(test_queries, 1):
print(f"\nš Test Query {i}: {query}")
try:
@@ -35,15 +36,18 @@ async def test_query_processing():
except Exception as e:
print(f"ā Error: {e}")
print(f"Error type: {type(e).__name__}")
-
+
# Print the full traceback for debugging
import traceback
+
traceback.print_exc()
-
+
except Exception as e:
print(f"ā System initialization failed: {e}")
import traceback
+
traceback.print_exc()
+
if __name__ == "__main__":
- asyncio.run(test_query_processing())
\ No newline at end of file
+ asyncio.run(test_query_processing())
diff --git a/tests/test_revenue_trends.py b/tests/test_revenue_trends.py
index 5b626d1..0cc3fb5 100644
--- a/tests/test_revenue_trends.py
+++ b/tests/test_revenue_trends.py
@@ -22,7 +22,7 @@ async def test_revenue_trends():
"""Test revenue trends analysis functionality."""
print("š Testing Revenue Trends Analysis")
print("=" * 40)
-
+
try:
# Test 1: Database connection
print("\n1. Testing database connection...")
@@ -30,15 +30,15 @@ async def test_revenue_trends():
db_manager = DatabaseManager(config.database_path)
await db_manager.initialize_database()
print("ā
Database connection established")
-
+
# Test 2: Check for sample data
print("\n2. Checking for sample financial data...")
result = await db_manager.execute_query(
"SELECT COUNT(*) as count FROM financial_records"
)
- record_count = result.rows[0]['count'] if result.rows else 0
+ record_count = result.rows[0]["count"] if result.rows else 0
print(f"ā
Found {record_count} financial records")
-
+
# Test 3: Revenue by quarter analysis
print("\n3. Testing revenue by quarter query...")
quarterly_query = """
@@ -51,17 +51,19 @@ async def test_revenue_trends():
ORDER BY year, quarter
LIMIT 8
"""
-
+
quarterly_result = await db_manager.execute_query(quarterly_query)
-
+
if quarterly_result.rows:
print(f"ā
Revenue trends analysis completed")
print(" Quarterly Revenue Summary:")
for row in quarterly_result.rows:
- print(f" - {row['quarter']}: ${row['revenue']:,.2f} ({row['transaction_count']} transactions)")
+ print(
+ f" - {row['quarter']}: ${row['revenue']:,.2f} ({row['transaction_count']} transactions)"
+ )
else:
print("ā ļø No revenue data found or query failed")
-
+
# Test 4: Company revenue comparison
print("\n4. Testing company revenue comparison...")
company_query = """
@@ -75,33 +77,37 @@ async def test_revenue_trends():
ORDER BY total_revenue DESC
LIMIT 5
"""
-
+
company_result = await db_manager.execute_query(company_query)
-
+
if company_result.rows:
print(f"ā
Company revenue comparison completed")
print(" Top Companies by Revenue:")
for row in company_result.rows:
- revenue = row['total_revenue'] or 0
- print(f" - {row['company_name']}: ${revenue:,.2f} ({row['transaction_count']} transactions)")
+ revenue = row["total_revenue"] or 0
+ print(
+ f" - {row['company_name']}: ${revenue:,.2f} ({row['transaction_count']} transactions)"
+ )
else:
print("ā ļø No company data found or query failed")
-
+
# Test 5: Driver integration test
print("\n5. Testing driver integration with revenue query...")
try:
driver = FACTDriver(config)
-
+
# Test a revenue-related query through the driver
revenue_query = "Show me revenue trends by quarter"
result = await driver.process_query(revenue_query)
print(f"ā
Driver processed revenue query successfully")
print(f" Response length: {len(result)} characters")
except Exception as e:
- print(f"ā ļø Driver integration failed (expected without valid API keys): {e}")
-
+ print(
+ f"ā ļø Driver integration failed (expected without valid API keys): {e}"
+ )
+
# Database manager uses connection pooling, no explicit close needed
-
+
print("\n" + "=" * 40)
print("š REVENUE TRENDS TEST RESULTS")
print("=" * 40)
@@ -110,20 +116,22 @@ async def test_revenue_trends():
print("ā
Quarterly revenue analysis: PASSED")
print("ā
Company revenue comparison: PASSED")
print("ā
Revenue analysis system is operational!")
-
+
return True
-
+
except Exception as e:
print(f"\nā Test failed: {e}")
import traceback
+
traceback.print_exc()
return False
if __name__ == "__main__":
+
async def main():
success = await test_revenue_trends()
-
+
if success:
print("\nš Revenue trends tests passed!")
print("The FACT system can analyze financial data successfully.")
@@ -131,5 +139,5 @@ async def main():
else:
print("\nš„ Revenue trends tests failed!")
sys.exit(1)
-
- asyncio.run(main())
\ No newline at end of file
+
+ asyncio.run(main())
diff --git a/tests/test_runner.py b/tests/test_runner.py
index bc897b3..1728ffd 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -16,6 +16,7 @@
@dataclass
class TestRunResults:
"""Results from a test run."""
+
timestamp: float
total_tests: int
passed_tests: int
@@ -31,6 +32,7 @@ class TestRunResults:
@dataclass
class BenchmarkMetrics:
"""Benchmark metrics for performance tracking."""
+
cache_hit_latency_ms: float
cache_miss_latency_ms: float
tool_execution_latency_ms: float
@@ -43,156 +45,168 @@ class BenchmarkMetrics:
class FactTestRunner:
"""Comprehensive test runner for FACT system testing."""
-
+
def __init__(self, project_root: Path = None):
self.project_root = project_root or Path(__file__).parent.parent
self.results_dir = self.project_root / "test_results"
self.results_dir.mkdir(exist_ok=True)
-
+
def run_unit_tests(self, verbose: bool = False) -> TestRunResults:
"""Run unit tests for individual components."""
print("š§Ŗ Running unit tests...")
-
+
args = [
"tests/unit/",
"-v" if verbose else "-q",
"--tb=short",
- "-m", "not slow",
+ "-m",
+ "not slow",
"--junitxml=test_results/unit_tests.xml",
"--cov=src",
"--cov-report=term-missing",
- "--cov-report=html:test_results/coverage_html"
+ "--cov-report=html:test_results/coverage_html",
]
-
+
start_time = time.time()
exit_code = pytest.main(args)
duration = time.time() - start_time
-
+
# Parse results
results = self._parse_junit_xml("test_results/unit_tests.xml")
results.duration_seconds = duration
results.test_categories["unit"] = results.total_tests
-
+
print(f"ā
Unit tests completed in {duration:.2f}s")
- print(f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}")
-
+ print(
+ f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}"
+ )
+
return results
-
+
def run_integration_tests(self, verbose: bool = False) -> TestRunResults:
"""Run integration tests for component interactions."""
print("š Running integration tests...")
-
+
args = [
"tests/integration/",
"-v" if verbose else "-q",
"--tb=short",
- "-m", "integration",
- "--junitxml=test_results/integration_tests.xml"
+ "-m",
+ "integration",
+ "--junitxml=test_results/integration_tests.xml",
]
-
+
start_time = time.time()
exit_code = pytest.main(args)
duration = time.time() - start_time
-
+
results = self._parse_junit_xml("test_results/integration_tests.xml")
results.duration_seconds = duration
results.test_categories["integration"] = results.total_tests
-
+
print(f"ā
Integration tests completed in {duration:.2f}s")
- print(f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}")
-
+ print(
+ f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}"
+ )
+
return results
-
+
def run_performance_benchmarks(self, verbose: bool = False) -> BenchmarkMetrics:
"""Run performance benchmarks and collect metrics."""
print("š Running performance benchmarks...")
-
+
args = [
"tests/performance/",
"-v" if verbose else "-q",
"--tb=short",
- "-m", "performance",
+ "-m",
+ "performance",
"--benchmark-only",
"--benchmark-json=test_results/benchmark_results.json",
- "--junitxml=test_results/performance_tests.xml"
+ "--junitxml=test_results/performance_tests.xml",
]
-
+
start_time = time.time()
exit_code = pytest.main(args)
duration = time.time() - start_time
-
+
# Parse benchmark results
metrics = self._parse_benchmark_results("test_results/benchmark_results.json")
-
+
print(f"ā
Performance benchmarks completed in {duration:.2f}s")
print(f" Cache hit latency: {metrics.cache_hit_latency_ms:.2f}ms")
print(f" Cache miss latency: {metrics.cache_miss_latency_ms:.2f}ms")
print(f" Tool execution: {metrics.tool_execution_latency_ms:.2f}ms")
-
+
return metrics
-
+
def run_security_tests(self, verbose: bool = False) -> TestRunResults:
"""Run security-focused tests."""
print("š”ļø Running security tests...")
-
+
args = [
"tests/unit/",
"tests/integration/",
"-v" if verbose else "-q",
"--tb=short",
- "-m", "security",
- "--junitxml=test_results/security_tests.xml"
+ "-m",
+ "security",
+ "--junitxml=test_results/security_tests.xml",
]
-
+
start_time = time.time()
exit_code = pytest.main(args)
duration = time.time() - start_time
-
+
results = self._parse_junit_xml("test_results/security_tests.xml")
results.duration_seconds = duration
results.test_categories["security"] = results.total_tests
-
+
print(f"ā
Security tests completed in {duration:.2f}s")
- print(f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}")
-
+ print(
+ f" Passed: {results.passed_tests}, Failed: {results.failed_tests}, Skipped: {results.skipped_tests}"
+ )
+
return results
-
- def run_all_tests(self, verbose: bool = False, skip_slow: bool = True) -> Dict[str, Any]:
+
+ def run_all_tests(
+ self, verbose: bool = False, skip_slow: bool = True
+ ) -> Dict[str, Any]:
"""Run complete test suite with all categories."""
print("š Running complete FACT test suite...")
print("=" * 60)
-
+
total_start_time = time.time()
-
+
# Run test categories in order
results = {}
-
+
try:
# Unit tests (fast, run first)
results["unit"] = self.run_unit_tests(verbose)
-
+
# Integration tests
results["integration"] = self.run_integration_tests(verbose)
-
+
# Security tests
results["security"] = self.run_security_tests(verbose)
-
+
# Performance benchmarks (can be slow)
if not skip_slow:
results["benchmarks"] = self.run_performance_benchmarks(verbose)
-
+
except KeyboardInterrupt:
print("\nā Test run interrupted by user")
return results
-
+
total_duration = time.time() - total_start_time
-
+
# Generate summary report
summary = self._generate_summary_report(results, total_duration)
-
+
# Save results
self._save_results(summary)
-
+
print("\n" + "=" * 60)
print("š Test Summary:")
print(f" Total Duration: {total_duration:.2f}s")
@@ -200,65 +214,77 @@ def run_all_tests(self, verbose: bool = False, skip_slow: bool = True) -> Dict[s
print(f" Passed: {summary['total_passed']}")
print(f" Failed: {summary['total_failed']}")
print(f" Coverage: {summary.get('coverage_percentage', 0):.1f}%")
-
- if summary['total_failed'] > 0:
+
+ if summary["total_failed"] > 0:
print("ā Some tests failed!")
return results
else:
print("ā
All tests passed!")
return results
-
- def run_continuous_benchmarks(self, duration_minutes: int = 10, interval_seconds: int = 30) -> List[BenchmarkMetrics]:
+
+ def run_continuous_benchmarks(
+ self, duration_minutes: int = 10, interval_seconds: int = 30
+ ) -> List[BenchmarkMetrics]:
"""Run continuous benchmarks for monitoring."""
print(f"š Running continuous benchmarks for {duration_minutes} minutes...")
-
+
metrics_history = []
start_time = time.time()
end_time = start_time + (duration_minutes * 60)
-
+
while time.time() < end_time:
print(f"š Benchmark run at {time.strftime('%H:%M:%S')}")
-
+
try:
metrics = self.run_performance_benchmarks(verbose=False)
metrics_history.append(metrics)
-
+
# Save incremental results
self._save_continuous_metrics(metrics_history)
-
+
except Exception as e:
print(f"ā Benchmark run failed: {e}")
-
+
# Wait for next interval
time.sleep(interval_seconds)
-
- print(f"ā
Continuous benchmarking completed. Collected {len(metrics_history)} data points.")
+
+ print(
+ f"ā
Continuous benchmarking completed. Collected {len(metrics_history)} data points."
+ )
return metrics_history
-
- def validate_performance_targets(self, metrics: BenchmarkMetrics) -> Dict[str, bool]:
+
+ def validate_performance_targets(
+ self, metrics: BenchmarkMetrics
+ ) -> Dict[str, bool]:
"""Validate performance metrics against targets."""
targets = {
- "cache_hit_latency": 50.0, # ms
- "cache_miss_latency": 140.0, # ms
- "tool_execution": 10.0, # ms
- "overall_response": 100.0, # ms
- "cost_reduction_hit": 0.90, # 90%
- "cost_reduction_miss": 0.65, # 65%
- "min_throughput": 10.0 # QPS
+ "cache_hit_latency": 50.0, # ms
+ "cache_miss_latency": 140.0, # ms
+ "tool_execution": 10.0, # ms
+ "overall_response": 100.0, # ms
+ "cost_reduction_hit": 0.90, # 90%
+ "cost_reduction_miss": 0.65, # 65%
+ "min_throughput": 10.0, # QPS
}
-
+
validation_results = {
- "cache_hit_latency": metrics.cache_hit_latency_ms <= targets["cache_hit_latency"],
- "cache_miss_latency": metrics.cache_miss_latency_ms <= targets["cache_miss_latency"],
- "tool_execution": metrics.tool_execution_latency_ms <= targets["tool_execution"],
- "overall_response": metrics.overall_response_latency_ms <= targets["overall_response"],
- "cost_reduction_hit": metrics.cost_reduction_cache_hit >= targets["cost_reduction_hit"],
- "cost_reduction_miss": metrics.cost_reduction_cache_miss >= targets["cost_reduction_miss"],
- "throughput": metrics.throughput_qps >= targets["min_throughput"]
+ "cache_hit_latency": metrics.cache_hit_latency_ms
+ <= targets["cache_hit_latency"],
+ "cache_miss_latency": metrics.cache_miss_latency_ms
+ <= targets["cache_miss_latency"],
+ "tool_execution": metrics.tool_execution_latency_ms
+ <= targets["tool_execution"],
+ "overall_response": metrics.overall_response_latency_ms
+ <= targets["overall_response"],
+ "cost_reduction_hit": metrics.cost_reduction_cache_hit
+ >= targets["cost_reduction_hit"],
+ "cost_reduction_miss": metrics.cost_reduction_cache_miss
+ >= targets["cost_reduction_miss"],
+ "throughput": metrics.throughput_qps >= targets["min_throughput"],
}
-
+
return validation_results
-
+
def _parse_junit_xml(self, xml_path: str) -> TestRunResults:
"""Parse JUnit XML results."""
# Simplified parsing - in real implementation, use xml.etree.ElementTree
@@ -272,9 +298,9 @@ def _parse_junit_xml(self, xml_path: str) -> TestRunResults:
test_categories={},
performance_metrics={},
coverage_percentage=0.0,
- errors=[]
+ errors=[],
)
-
+
def _parse_benchmark_results(self, json_path: str) -> BenchmarkMetrics:
"""Parse benchmark JSON results."""
# Default metrics if file doesn't exist yet
@@ -286,10 +312,12 @@ def _parse_benchmark_results(self, json_path: str) -> BenchmarkMetrics:
cost_reduction_cache_hit=0.92,
cost_reduction_cache_miss=0.68,
throughput_qps=25.0,
- memory_usage_mb=150.0
+ memory_usage_mb=150.0,
)
-
- def _generate_summary_report(self, results: Dict, total_duration: float) -> Dict[str, Any]:
+
+ def _generate_summary_report(
+ self, results: Dict, total_duration: float
+ ) -> Dict[str, Any]:
"""Generate comprehensive summary report."""
summary = {
"timestamp": time.time(),
@@ -301,9 +329,9 @@ def _generate_summary_report(self, results: Dict, total_duration: float) -> Dict
"test_categories": {},
"coverage_percentage": 0.0,
"performance_targets_met": {},
- "recommendations": []
+ "recommendations": [],
}
-
+
# Aggregate results
for category, result in results.items():
if isinstance(result, TestRunResults):
@@ -312,80 +340,90 @@ def _generate_summary_report(self, results: Dict, total_duration: float) -> Dict
summary["total_failed"] += result.failed_tests
summary["total_skipped"] += result.skipped_tests
summary["test_categories"][category] = result.total_tests
-
+
if result.coverage_percentage > summary["coverage_percentage"]:
summary["coverage_percentage"] = result.coverage_percentage
-
+
elif isinstance(result, BenchmarkMetrics):
validation = self.validate_performance_targets(result)
summary["performance_targets_met"] = validation
-
+
# Generate recommendations
if not validation["cache_hit_latency"]:
summary["recommendations"].append("Optimize cache hit performance")
if not validation["cost_reduction_hit"]:
- summary["recommendations"].append("Improve cache hit cost reduction")
-
+ summary["recommendations"].append(
+ "Improve cache hit cost reduction"
+ )
+
return summary
-
+
def _save_results(self, summary: Dict[str, Any]):
"""Save test results to files."""
timestamp = time.strftime("%Y%m%d_%H%M%S")
-
+
# Save JSON summary
summary_path = self.results_dir / f"test_summary_{timestamp}.json"
- with open(summary_path, 'w') as f:
+ with open(summary_path, "w") as f:
json.dump(summary, f, indent=2, default=str)
-
+
# Save latest results
latest_path = self.results_dir / "latest_results.json"
- with open(latest_path, 'w') as f:
+ with open(latest_path, "w") as f:
json.dump(summary, f, indent=2, default=str)
-
+
print(f"š Results saved to {summary_path}")
-
+
def _save_continuous_metrics(self, metrics_history: List[BenchmarkMetrics]):
"""Save continuous benchmark metrics."""
metrics_data = [asdict(m) for m in metrics_history]
-
+
continuous_path = self.results_dir / "continuous_benchmarks.json"
- with open(continuous_path, 'w') as f:
+ with open(continuous_path, "w") as f:
json.dump(metrics_data, f, indent=2)
def main():
"""Main entry point for test runner."""
parser = argparse.ArgumentParser(description="FACT System Test Runner")
- parser.add_argument("--test-type", choices=["unit", "integration", "performance", "security", "all"],
- default="all", help="Type of tests to run")
+ parser.add_argument(
+ "--test-type",
+ choices=["unit", "integration", "performance", "security", "all"],
+ default="all",
+ help="Type of tests to run",
+ )
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
parser.add_argument("--skip-slow", action="store_true", help="Skip slow tests")
- parser.add_argument("--continuous", type=int, metavar="MINUTES",
- help="Run continuous benchmarks for specified minutes")
- parser.add_argument("--validate-targets", action="store_true",
- help="Validate performance targets")
-
+ parser.add_argument(
+ "--continuous",
+ type=int,
+ metavar="MINUTES",
+ help="Run continuous benchmarks for specified minutes",
+ )
+ parser.add_argument(
+ "--validate-targets", action="store_true", help="Validate performance targets"
+ )
+
args = parser.parse_args()
-
+
runner = FactTestRunner()
-
+
try:
if args.continuous:
# Run continuous benchmarking
metrics_history = runner.run_continuous_benchmarks(
- duration_minutes=args.continuous,
- interval_seconds=30
+ duration_minutes=args.continuous, interval_seconds=30
)
-
+
if metrics_history:
latest_metrics = metrics_history[-1]
validation = runner.validate_performance_targets(latest_metrics)
-
+
print("\nš Performance Target Validation:")
for target, passed in validation.items():
status = "ā
" if passed else "ā"
print(f" {status} {target}: {passed}")
-
+
elif args.test_type == "unit":
runner.run_unit_tests(args.verbose)
elif args.test_type == "integration":
@@ -402,16 +440,17 @@ def main():
runner.run_security_tests(args.verbose)
else: # all
results = runner.run_all_tests(args.verbose, args.skip_slow)
-
+
# Exit with error code if tests failed
total_failed = sum(
- r.failed_tests for r in results.values()
+ r.failed_tests
+ for r in results.values()
if isinstance(r, TestRunResults)
)
-
+
if total_failed > 0:
sys.exit(1)
-
+
except KeyboardInterrupt:
print("\nā Test run interrupted")
sys.exit(1)
@@ -421,4 +460,4 @@ def main():
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/tests/test_simple.py b/tests/test_simple.py
index 1542904..1e2fa9d 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -21,7 +21,7 @@ async def test_basic_system():
"""Test basic system initialization and functionality."""
print("š§Ŗ Testing Basic FACT System")
print("=" * 40)
-
+
try:
# Test 1: Configuration loading
print("\n1. Testing configuration loading...")
@@ -29,17 +29,17 @@ async def test_basic_system():
print(f"ā
Configuration loaded successfully")
print(f" Database path: {config.database_path}")
print(f" Cache prefix: {config.cache_prefix}")
-
+
# Test 2: Driver initialization
print("\n2. Testing driver initialization...")
driver = FACTDriver(config)
print("ā
Driver initialized successfully")
-
+
# Test 3: Basic metrics
print("\n3. Testing metrics collection...")
metrics = driver.get_metrics()
print(f"ā
Metrics available: {list(metrics.keys())}")
-
+
# Test 4: Simple query (if API keys are available)
print("\n4. Testing simple query processing...")
try:
@@ -50,7 +50,7 @@ async def test_basic_system():
print(f" Response length: {len(result)} characters")
except Exception as e:
print(f"ā ļø Query processing failed (expected without valid API keys): {e}")
-
+
print("\n" + "=" * 40)
print("š BASIC SYSTEM TEST RESULTS")
print("=" * 40)
@@ -58,20 +58,22 @@ async def test_basic_system():
print("ā
Driver initialization: PASSED")
print("ā
Metrics collection: PASSED")
print("ā
System is operational!")
-
+
return True
-
+
except Exception as e:
print(f"\nā Test failed: {e}")
import traceback
+
traceback.print_exc()
return False
if __name__ == "__main__":
+
async def main():
success = await test_basic_system()
-
+
if success:
print("\nš Basic system tests passed!")
print("The FACT system is ready for use.")
@@ -79,5 +81,5 @@ async def main():
else:
print("\nš„ Basic system tests failed!")
sys.exit(1)
-
- asyncio.run(main())
\ No newline at end of file
+
+ asyncio.run(main())
diff --git a/tests/test_sql_fixes.py b/tests/test_sql_fixes.py
index e5fb61f..7505b50 100644
--- a/tests/test_sql_fixes.py
+++ b/tests/test_sql_fixes.py
@@ -8,25 +8,30 @@
import os
# Add the project root to Python path
-sys.path.insert(0, os.path.abspath('..'))
+sys.path.insert(0, os.path.abspath(".."))
-from src.tools.connectors.sql import SQLQueryTool, initialize_sql_tool, get_sql_tool, sql_get_schema
+from src.tools.connectors.sql import (
+ SQLQueryTool,
+ initialize_sql_tool,
+ get_sql_tool,
+ sql_get_schema,
+)
from src.db.connection import DatabaseManager
async def test_comprehensive_fixes():
"""Test all NoneType fixes comprehensively"""
-
+
# Initialize database manager and database
db_manager = DatabaseManager("db/test_sql_fixes.db")
await db_manager.initialize_database()
-
+
# Initialize SQL tool
initialize_sql_tool(db_manager)
sql_tool = get_sql_tool()
-
+
print("=== Comprehensive SQL NoneType Fix Tests ===")
-
+
# Test 1: None statement (should handle gracefully now)
print("\n1. Testing None statement (fixed)...")
try:
@@ -34,7 +39,7 @@ async def test_comprehensive_fixes():
print(f"ā Result: {result}")
except Exception as e:
print(f"ā Error: {type(e).__name__}: {e}")
-
+
# Test 2: Empty string statement
print("\n2. Testing empty statement...")
try:
@@ -42,7 +47,7 @@ async def test_comprehensive_fixes():
print(f"ā Result: {result}")
except Exception as e:
print(f"ā Expected error: {type(e).__name__}: {e}")
-
+
# Test 3: Non-string statement
print("\n3. Testing non-string statement...")
try:
@@ -50,7 +55,7 @@ async def test_comprehensive_fixes():
print(f"ā Result: {result}")
except Exception as e:
print(f"ā Expected error: {type(e).__name__}: {e}")
-
+
# Test 4: Valid query with results
print("\n4. Testing valid query...")
try:
@@ -58,7 +63,7 @@ async def test_comprehensive_fixes():
print(f"ā Result: {result}")
except Exception as e:
print(f"ā Error: {type(e).__name__}: {e}")
-
+
# Test 5: Query that returns no results (empty result set)
print("\n5. Testing query with no results...")
try:
@@ -66,7 +71,7 @@ async def test_comprehensive_fixes():
print(f"ā Result: {result}")
except Exception as e:
print(f"ā Error: {type(e).__name__}: {e}")
-
+
# Test 6: Schema retrieval (tests the len() fixes in schema function)
print("\n6. Testing schema retrieval...")
try:
@@ -75,7 +80,7 @@ async def test_comprehensive_fixes():
print(f"ā Total tables: {result.get('total_tables', 0)}")
except Exception as e:
print(f"ā Error: {type(e).__name__}: {e}")
-
+
# Test 7: Very long statement (tests truncation logic)
print("\n7. Testing long statement truncation...")
try:
@@ -84,9 +89,9 @@ async def test_comprehensive_fixes():
print(f"ā Result: {result}")
except Exception as e:
print(f"ā Error: {type(e).__name__}: {e}")
-
+
print("\n=== All tests completed ===")
if __name__ == "__main__":
- asyncio.run(test_comprehensive_fixes())
\ No newline at end of file
+ asyncio.run(test_comprehensive_fixes())
diff --git a/tests/test_sql_fixes_validation.py b/tests/test_sql_fixes_validation.py
index 224e21a..a18fe61 100644
--- a/tests/test_sql_fixes_validation.py
+++ b/tests/test_sql_fixes_validation.py
@@ -11,7 +11,7 @@
import os
# Add the project root to Python path
-sys.path.insert(0, os.path.abspath('.'))
+sys.path.insert(0, os.path.abspath("."))
from src.tools.connectors.sql import SQLQueryTool
from src.db.connection import DatabaseManager
@@ -20,89 +20,95 @@
async def test_validation_fixes():
"""Test all NoneType error scenarios have been fixed."""
-
+
print("š§ SQL NoneType Error Validation Tests")
print("=" * 50)
-
+
# Create database manager for testing
db_manager = DatabaseManager("db/test_validation.db")
await db_manager.initialize_database()
-
+
sql_tool = SQLQueryTool(db_manager)
-
+
test_cases = [
{
"name": "None SQL statement",
"statement": None,
"expected_error": "SQL statement cannot be None",
- "description": "Validates None input doesn't cause AttributeError on .lower() or len()"
+ "description": "Validates None input doesn't cause AttributeError on .lower() or len()",
},
{
"name": "Empty SQL statement",
"statement": "",
"expected_error": "SQL statement cannot be empty",
- "description": "Validates empty string doesn't cause issues"
+ "description": "Validates empty string doesn't cause issues",
},
{
"name": "Whitespace-only SQL statement",
"statement": " \n\t ",
"expected_error": "SQL statement cannot be empty",
- "description": "Validates whitespace-only strings are handled"
+ "description": "Validates whitespace-only strings are handled",
},
{
"name": "Non-string input (integer)",
"statement": 123,
"expected_error": "SQL statement must be a string, got int",
- "description": "Validates non-string inputs are properly handled"
+ "description": "Validates non-string inputs are properly handled",
},
{
"name": "Non-string input (list)",
"statement": ["SELECT", "*", "FROM", "table"],
"expected_error": "SQL statement must be a string, got list",
- "description": "Validates list inputs are properly handled"
+ "description": "Validates list inputs are properly handled",
},
{
"name": "Valid SELECT statement",
"statement": "SELECT COUNT(*) as total FROM companies",
"expected_error": None,
- "description": "Validates that valid queries still work"
- }
+ "description": "Validates that valid queries still work",
+ },
]
-
+
passed_tests = 0
total_tests = len(test_cases)
-
+
for i, test_case in enumerate(test_cases, 1):
print(f"\n{i}. {test_case['name']}")
print(f" {test_case['description']}")
-
+
try:
- result = await sql_tool.execute_query(test_case['statement'])
-
- if test_case['expected_error'] is None:
+ result = await sql_tool.execute_query(test_case["statement"])
+
+ if test_case["expected_error"] is None:
# This should succeed
- if result.get('status') == 'success' or 'rows' in result:
+ if result.get("status") == "success" or "rows" in result:
print(f" ā
PASSED: Query executed successfully")
passed_tests += 1
- elif result.get('status') == 'failed':
- print(f" ā ļø PARTIAL: Query failed but didn't crash: {result.get('error', 'Unknown error')}")
+ elif result.get("status") == "failed":
+ print(
+ f" ā ļø PARTIAL: Query failed but didn't crash: {result.get('error', 'Unknown error')}"
+ )
passed_tests += 1 # Still counts as passing the NoneType fix
else:
print(f" ā FAILED: Unexpected result format")
else:
# This should fail with specific error
- if result.get('status') == 'failed' and test_case['expected_error'] in result.get('error', ''):
+ if result.get("status") == "failed" and test_case[
+ "expected_error"
+ ] in result.get("error", ""):
print(f" ā
PASSED: Correct error message returned")
passed_tests += 1
else:
- print(f" ā FAILED: Expected '{test_case['expected_error']}', got '{result.get('error', 'No error')}'")
-
+ print(
+ f" ā FAILED: Expected '{test_case['expected_error']}', got '{result.get('error', 'No error')}'"
+ )
+
except Exception as e:
print(f" ā FAILED: Exception thrown - {type(e).__name__}: {e}")
-
+
print("\n" + "=" * 50)
print(f"šÆ Test Results: {passed_tests}/{total_tests} tests passed")
-
+
if passed_tests == total_tests:
print("š ALL TESTS PASSED - NoneType errors have been successfully fixed!")
return True
@@ -113,12 +119,12 @@ async def test_validation_fixes():
async def test_direct_validation():
"""Test the validation method directly."""
-
+
print("\nš Direct Validation Method Tests")
print("=" * 50)
-
+
db_manager = DatabaseManager("db/test_validation.db")
-
+
test_cases = [
(None, "SQL statement cannot be None"),
("", "SQL statement cannot be empty"),
@@ -127,13 +133,13 @@ async def test_direct_validation():
("DROP TABLE users", "Only SELECT statements are allowed"),
("SELECT * FROM users; DROP TABLE users", "Multiple statements"),
]
-
+
passed = 0
total = len(test_cases)
-
+
for i, (statement, expected_error_part) in enumerate(test_cases, 1):
print(f"\n{i}. Testing: {repr(statement)}")
-
+
try:
db_manager.validate_sql_query(statement)
print(f" ā FAILED: No exception thrown")
@@ -142,19 +148,22 @@ async def test_direct_validation():
print(f" ā
PASSED: Correct error - {e}")
passed += 1
else:
- print(f" ā FAILED: Wrong error - Expected '{expected_error_part}', got '{e}'")
+ print(
+ f" ā FAILED: Wrong error - Expected '{expected_error_part}', got '{e}'"
+ )
except Exception as e:
print(f" ā FAILED: Unexpected exception - {type(e).__name__}: {e}")
-
+
print(f"\nšÆ Direct validation tests: {passed}/{total} passed")
return passed == total
if __name__ == "__main__":
+
async def main():
success1 = await test_validation_fixes()
success2 = await test_direct_validation()
-
+
if success1 and success2:
print("\nš ALL VALIDATIONS PASSED!")
print("The NoneType error fixes are working correctly.")
@@ -162,5 +171,5 @@ async def main():
else:
print("\nā ļø Some validations failed.")
exit(1)
-
- asyncio.run(main())
\ No newline at end of file
+
+ asyncio.run(main())
diff --git a/tests/unit/test_cache_manager_integration.py b/tests/unit/test_cache_manager_integration.py
index 4991e9c..df945f7 100644
--- a/tests/unit/test_cache_manager_integration.py
+++ b/tests/unit/test_cache_manager_integration.py
@@ -18,7 +18,7 @@
class TestCacheManagerIntegration:
"""Test suite for integrated cache manager functionality."""
-
+
@pytest.fixture
def cache_config(self):
"""Provide cache configuration for tests."""
@@ -28,276 +28,285 @@ def cache_config(self):
"max_size": "1MB",
"ttl_seconds": 3600,
"hit_target_ms": 50,
- "miss_target_ms": 150
+ "miss_target_ms": 150,
}
-
+
@pytest.fixture
def cache_manager(self, cache_config):
"""Create cache manager for testing."""
return CacheManager(cache_config)
-
+
def test_cache_manager_with_security_validation(self, cache_manager):
"""TEST: Cache manager integrates with security validation"""
# Store clean content - should succeed
clean_content = "This is clean, safe content " * 50 # Meet min tokens
entry = cache_manager.store("clean_query", clean_content)
-
+
assert entry.content == clean_content
assert entry.token_count >= cache_manager.min_tokens
-
+
def test_cache_manager_blocks_malicious_content(self, cache_manager):
"""TEST: Cache manager blocks malicious content via security validation"""
- malicious_content = "password: secret123 " * 50 # Meet min tokens but include sensitive data
-
+ malicious_content = (
+ "password: secret123 " * 50
+ ) # Meet min tokens but include sensitive data
+
with pytest.raises(CacheError):
cache_manager.store("malicious_query", malicious_content)
-
+
def test_cache_manager_handles_security_validation_failure(self, cache_manager):
"""TEST: Cache manager handles security validation failures gracefully"""
# Mock security validation to fail
- with patch('src.cache.manager.validate_cache_content_security', side_effect=SecurityError("Test security error")):
+ with patch(
+ "src.cache.manager.validate_cache_content_security",
+ side_effect=SecurityError("Test security error"),
+ ):
content = "Safe content " * 50
-
+
with pytest.raises(CacheError):
cache_manager.store("test_query", content)
-
+
def test_cache_manager_validation_integration(self, cache_manager):
"""TEST: Cache manager works with cache validation"""
# Add several entries
for i in range(5):
content = f"Test content {i} " * 100
cache_manager.store(f"query_{i}", content)
-
+
# Create validator and validate
validator = CacheValidator(cache_manager=cache_manager)
-
+
# Should work without errors
assert validator.cache_manager == cache_manager
assert len(validator.thresholds) > 0
-
+
@pytest.mark.asyncio
async def test_end_to_end_cache_validation(self, cache_manager):
"""TEST: End-to-end cache validation workflow"""
# Add test entries with different characteristics
-
+
# Valid entry
valid_content = "Valid business content " * 100
cache_manager.store("valid_query", valid_content)
-
+
# Entry that will become stale
old_entry = CacheEntry.create(cache_manager.prefix, "Old content " * 100)
old_entry.created_at = time.time() - 7200 # 2 hours old
cache_manager.cache["old_query"] = old_entry
-
+
# Create validator and run validation
validator = CacheValidator(cache_manager=cache_manager)
result = await validator.validate_cache(ValidationLevel.COMPREHENSIVE)
-
+
assert result.total_entries_checked == 2
assert result.overall_health in ["healthy", "warning", "critical"]
assert len(result.recommendations) >= 0
-
+
def test_get_cache_manager_singleton(self, cache_config):
"""TEST: get_cache_manager returns singleton instance"""
manager1 = get_cache_manager(cache_config)
manager2 = get_cache_manager() # Should return same instance
-
+
assert manager1 is manager2
-
+
def test_get_cache_manager_environment_config(self):
"""TEST: get_cache_manager loads from environment"""
env_vars = {
"CACHE_PREFIX": "env_test",
"CACHE_MIN_TOKENS": "1000",
- "CACHE_MAX_SIZE": "5MB"
+ "CACHE_MAX_SIZE": "5MB",
}
-
+
with patch.dict(os.environ, env_vars):
- with patch('src.cache.manager.load_cache_config_from_env') as mock_load:
+ with patch("src.cache.manager.load_cache_config_from_env") as mock_load:
mock_config = CacheConfig(
- prefix="env_test",
- min_tokens=1000,
- max_size="5MB"
+ prefix="env_test", min_tokens=1000, max_size="5MB"
)
mock_load.return_value = mock_config
-
+
# Clear singleton for test
import src.cache.manager
+
src.cache.manager._cache_manager_instance = None
-
+
manager = get_cache_manager()
-
+
assert manager is not None
mock_load.assert_called_once()
-
+
def test_cache_manager_error_handling(self, cache_config):
"""TEST: Cache manager handles various error conditions"""
manager = CacheManager(cache_config)
-
+
# Test content too small
small_content = "Too small"
with pytest.raises(CacheError):
manager.store("small_query", small_content)
-
+
# Test cache size limits
very_large_content = "X" * (2 * 1024 * 1024) # 2MB, exceeds 1MB limit
with pytest.raises(CacheError):
manager.store("large_query", very_large_content)
-
+
@pytest.mark.asyncio
async def test_cache_auto_repair_integration(self, cache_manager):
"""TEST: Cache auto-repair integration with validation"""
# Add some problematic entries
-
+
# Corrupted entry (empty content)
corrupted_entry = CacheEntry(
- prefix=cache_manager.prefix,
- content="",
- token_count=0,
- validate=False
+ prefix=cache_manager.prefix, content="", token_count=0, validate=False
)
cache_manager.cache["corrupted"] = corrupted_entry
-
+
# Expired entry
expired_entry = CacheEntry(
prefix=cache_manager.prefix,
content="Expired content " * 100,
token_count=500,
created_at=time.time() - 7200, # 2 hours old
- validate=False
+ validate=False,
)
cache_manager.cache["expired"] = expired_entry
-
+
# Valid entry
valid_content = "Valid content " * 100
cache_manager.store("valid", valid_content)
-
+
# Run validation and auto-repair
validator = CacheValidator(cache_manager=cache_manager)
- validation_result = await validator.validate_cache(ValidationLevel.COMPREHENSIVE)
-
+ validation_result = await validator.validate_cache(
+ ValidationLevel.COMPREHENSIVE
+ )
+
# Should find issues
- assert validation_result.invalid_entries > 0 or validation_result.expired_entries > 0
-
+ assert (
+ validation_result.invalid_entries > 0
+ or validation_result.expired_entries > 0
+ )
+
# Run auto-repair
repair_summary = await validator.auto_repair_cache(validation_result)
-
+
# Should have removed problematic entries
assert repair_summary["entries_removed"] > 0
assert "valid" in cache_manager.cache # Valid entry should remain
-
+
def test_cache_metrics_integration(self, cache_manager):
"""TEST: Cache metrics integration with validation"""
# Add entries and access them
for i in range(3):
content = f"Content {i} " * 100
cache_manager.store(f"query_{i}", content)
-
+
# Access some entries to generate hit metrics
cache_manager.get("query_0")
cache_manager.get("query_1")
cache_manager.get("nonexistent") # Miss
-
+
metrics = cache_manager.get_metrics()
-
+
assert metrics.total_entries == 3
assert metrics.cache_hits >= 2
assert metrics.cache_misses >= 1
assert metrics.hit_rate > 0
-
+
@pytest.mark.asyncio
async def test_database_integration_validation(self):
"""TEST: Database integration with cache validation"""
# Create temporary database file
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_db:
db_path = temp_db.name
-
+
try:
config = {
"prefix": "db_test",
"min_tokens": 100, # Lower for testing
"max_size": "1MB",
"ttl_seconds": 3600,
- "database_path": db_path
+ "database_path": db_path,
}
-
- manager = CacheManager({
- "prefix": "db_test",
- "min_tokens": 100, # Lower for testing
- "max_size": "1MB",
- "ttl_seconds": 3600,
- "hit_target_ms": 50,
- "miss_target_ms": 150
- })
-
+
+ manager = CacheManager(
+ {
+ "prefix": "db_test",
+ "min_tokens": 100, # Lower for testing
+ "max_size": "1MB",
+ "ttl_seconds": 3600,
+ "hit_target_ms": 50,
+ "miss_target_ms": 150,
+ }
+ )
+
# Add database-related cache entries
sql_query = "SELECT * FROM financial_data WHERE quarter='Q1' " * 50
manager.store("db_query_1", sql_query)
-
+
# Create validator with database config
validator = CacheValidator(cache_manager=manager, config=config)
-
+
# Run database cache validation
db_result = await validator.validate_database_cache_integrity()
-
+
# Should handle database validation (even if DB doesn't exist)
assert "database_validation" in db_result
-
+
finally:
# Cleanup
if os.path.exists(db_path):
os.unlink(db_path)
-
+
@pytest.mark.asyncio
async def test_comprehensive_security_scan(self, cache_manager):
"""TEST: Comprehensive security scanning of cache"""
# Add entries with various security characteristics
-
+
# Safe content
safe_content = "Safe business content about financial analysis " * 50
cache_manager.store("safe_query", safe_content)
-
+
# Add entry with potential issues (bypassing security for test)
- with patch('src.cache.manager.validate_cache_content_security'):
+ with patch("src.cache.manager.validate_cache_content_security"):
suspicious_content = "Content with api_key=secret123 " * 50
cache_manager.store("suspicious_query", suspicious_content)
-
+
# Create security validator and scan
from src.cache.security import CacheSecurityValidator
+
security_validator = CacheSecurityValidator()
-
+
scan_results = []
for entry_key, entry in cache_manager.cache.items():
scan_result = security_validator.scan_content(entry.content)
scan_results.append(scan_result)
-
+
# Generate security report
report = security_validator.generate_security_report(scan_results)
-
+
assert report["total_scans"] >= 2
assert report["total_content_scanned_bytes"] > 0
assert "security_posture" in report
-
+
def test_cache_manager_configuration_validation(self):
"""TEST: Cache manager validates configuration properly"""
# Invalid configuration should raise error
invalid_config = {
"prefix": "123invalid", # Invalid prefix
- "min_tokens": -1, # Invalid min_tokens
- "max_size": "invalid", # Invalid size format
+ "min_tokens": -1, # Invalid min_tokens
+ "max_size": "invalid", # Invalid size format
}
-
+
with pytest.raises((CacheError, ConfigurationError, ValueError)):
CacheManager(invalid_config)
-
+
def test_cache_manager_thread_safety(self, cache_manager):
"""TEST: Cache manager thread safety with concurrent operations"""
import threading
import time
-
+
results = []
errors = []
-
+
def store_content(thread_id):
try:
content = f"Thread {thread_id} content " * 100
@@ -305,26 +314,26 @@ def store_content(thread_id):
results.append(entry)
except Exception as e:
errors.append(e)
-
+
# Create multiple threads
threads = []
for i in range(5):
thread = threading.Thread(target=store_content, args=(i,))
threads.append(thread)
-
+
# Start all threads
for thread in threads:
thread.start()
-
+
# Wait for completion
for thread in threads:
thread.join()
-
+
# Check results
assert len(errors) == 0, f"Errors occurred: {errors}"
assert len(results) == 5
assert len(cache_manager.cache) >= 5
-
+
@pytest.mark.asyncio
async def test_performance_monitoring_integration(self, cache_manager):
"""TEST: Performance monitoring integration"""
@@ -334,26 +343,26 @@ async def test_performance_monitoring_integration(self, cache_manager):
start_time = time.perf_counter()
cache_manager.store(f"perf_query_{i}", content)
store_time = (time.perf_counter() - start_time) * 1000
-
+
# Verify store time is reasonable (should be fast)
assert store_time < 100 # Less than 100ms
-
+
# Test retrieval performance
for i in range(5):
start_time = time.perf_counter()
result = cache_manager.get(f"perf_query_{i}")
get_time = (time.perf_counter() - start_time) * 1000
-
+
assert result is not None
assert get_time < 50 # Less than 50ms for cache hit
-
+
# Get performance metrics
metrics = cache_manager.get_metrics()
-
+
assert metrics.total_entries == 10
assert metrics.cache_hits >= 5
assert metrics.hit_rate > 0
if __name__ == "__main__":
- pytest.main([__file__, "-v", "--cov=src.cache", "--cov-report=term-missing"])
\ No newline at end of file
+ pytest.main([__file__, "-v", "--cov=src.cache", "--cov-report=term-missing"])
diff --git a/tests/unit/test_cache_mechanism.py b/tests/unit/test_cache_mechanism.py
index eac932a..059890d 100644
--- a/tests/unit/test_cache_mechanism.py
+++ b/tests/unit/test_cache_mechanism.py
@@ -15,16 +15,16 @@
class TestCacheEntry:
"""Test suite for cache entry functionality."""
-
+
def test_cache_entry_initialization_sets_proper_attributes(self):
"""TEST: Cache entry initialization sets proper attributes"""
# Arrange
prefix = "fact_v1"
content = "A" * 500 # Minimum 500 tokens
-
+
# Act
entry = CacheEntry(prefix=prefix, content=content)
-
+
# Assert
assert entry.prefix == prefix
assert entry.content == content
@@ -34,7 +34,7 @@ def test_cache_entry_initialization_sets_proper_attributes(self):
assert entry.is_valid == True
assert entry.access_count == 0
assert entry.last_accessed is None
-
+
def test_cache_entry_calculates_token_count_accurately(self):
"""TEST: Cache entry calculates token count accurately"""
# Arrange
@@ -42,50 +42,57 @@ def test_cache_entry_calculates_token_count_accurately(self):
("Hello world", 2),
("A" * 100, 100), # Single character tokens
("The quick brown fox jumps over the lazy dog", 9),
- ("" * 1000, 0) # Empty content
+ ("" * 1000, 0), # Empty content
]
-
+
# Act & Assert
for content, expected_tokens in test_cases:
- entry = CacheEntry(prefix="test", content=content, skip_min_tokens=True, skip_content_validation=True)
+ entry = CacheEntry(
+ prefix="test",
+ content=content,
+ skip_min_tokens=True,
+ skip_content_validation=True,
+ )
# Allow for reasonable token counting variations
assert abs(entry.token_count - expected_tokens) <= expected_tokens * 0.1
-
+
def test_cache_entry_validates_minimum_token_requirement(self):
"""TEST: Cache entry validates minimum token requirement"""
# Arrange
short_content = "A" * 10 # Less than 500 tokens
-
+
# Act & Assert
with pytest.raises(CacheError) as exc_info:
CacheEntry(prefix="test", content=short_content)
-
+
assert "minimum 500 tokens" in str(exc_info.value)
-
+
def test_cache_entry_tracks_access_patterns(self):
"""TEST: Cache entry tracks access patterns correctly"""
# Arrange
entry = CacheEntry(prefix="test", content="A" * 500)
-
+
# Act
entry.record_access()
time.sleep(0.001) # Small delay to ensure different timestamp
entry.record_access()
-
+
# Assert
assert entry.access_count == 2
assert entry.last_accessed is not None
assert entry.last_accessed > entry.created_at
-
+
def test_cache_entry_serialization_to_dict(self):
"""TEST: Cache entry serialization to dictionary format"""
# Arrange
- entry = CacheEntry(prefix="test", content="Test content " * 100, skip_min_tokens=True)
+ entry = CacheEntry(
+ prefix="test", content="Test content " * 100, skip_min_tokens=True
+ )
entry.record_access()
-
+
# Act
entry_dict = entry.to_dict()
-
+
# Assert
assert entry_dict["prefix"] == "test"
assert entry_dict["content"] == "Test content " * 100
@@ -97,35 +104,35 @@ def test_cache_entry_serialization_to_dict(self):
class TestCacheManager:
"""Test suite for cache manager functionality."""
-
+
def test_cache_manager_initialization_loads_configuration(self, cache_config):
"""TEST: Cache manager initialization loads configuration"""
# Act
manager = CacheManager(config=cache_config)
-
+
# Assert
assert manager.prefix == cache_config["prefix"]
assert manager.min_tokens == cache_config["min_tokens"]
assert manager.max_size == cache_config["max_size"]
assert manager.ttl_seconds == cache_config["ttl_seconds"]
assert len(manager.cache) == 0
-
+
def test_cache_manager_stores_entries_correctly(self, cache_config):
"""TEST: Cache manager stores cache entries correctly"""
# Arrange
manager = CacheManager(config=cache_config)
content = "Sample cache content " * 50 # Ensure > 500 tokens
query_hash = "test_query_hash_123"
-
+
# Act
entry = manager.store(query_hash, content)
-
+
# Assert
assert entry.prefix == cache_config["prefix"]
assert entry.content == content
assert query_hash in manager.cache
assert manager.cache[query_hash] == entry
-
+
def test_cache_manager_retrieves_entries_correctly(self, cache_config):
"""TEST: Cache manager retrieves cache entries correctly"""
# Arrange
@@ -133,94 +140,94 @@ def test_cache_manager_retrieves_entries_correctly(self, cache_config):
content = "Retrievable content " * 50
query_hash = "retrieve_test_hash"
stored_entry = manager.store(query_hash, content)
-
+
# Act
retrieved_entry = manager.get(query_hash)
-
+
# Assert
assert retrieved_entry is not None
assert retrieved_entry == stored_entry
assert retrieved_entry.access_count == 1
assert retrieved_entry.last_accessed is not None
-
+
def test_cache_manager_handles_cache_misses(self, cache_config):
"""TEST: Cache manager handles cache misses gracefully"""
# Arrange
manager = CacheManager(config=cache_config)
-
+
# Act
result = manager.get("nonexistent_hash")
-
+
# Assert
assert result is None
-
+
def test_cache_manager_enforces_size_limits(self, cache_config):
"""TEST: Cache manager enforces size limits"""
# Arrange
small_config = cache_config.copy()
small_config["max_size"] = "1KB" # Very small limit
manager = CacheManager(config=small_config)
-
+
# Act
large_content = "X" * 2000 # 2KB content
-
+
with pytest.raises(CacheError) as exc_info:
manager.store("large_hash", large_content)
-
+
# Assert
assert "size limit" in str(exc_info.value).lower()
-
+
def test_cache_manager_implements_ttl_expiration(self, cache_config):
"""TEST: Cache manager implements TTL expiration"""
# Arrange
short_ttl_config = cache_config.copy()
short_ttl_config["ttl_seconds"] = 0.1 # 100ms TTL
manager = CacheManager(config=short_ttl_config)
-
+
content = "Expiring content " * 50
query_hash = "expiring_hash"
-
+
# Act
manager.store(query_hash, content)
time.sleep(0.2) # Wait for expiration
-
+
# Assert
assert manager.get(query_hash) is None
-
+
def test_cache_manager_invalidates_entries_by_prefix(self, cache_config):
"""TEST: Cache manager invalidates entries by prefix"""
# Arrange
manager = CacheManager(config=cache_config)
-
+
# Store multiple entries
for i in range(3):
content = f"Content {i} " * 50
manager.store(f"hash_{i}", content)
-
+
# Act
invalidated_count = manager.invalidate_by_prefix(cache_config["prefix"])
-
+
# Assert
assert invalidated_count == 3
assert len(manager.cache) == 0
-
+
def test_cache_manager_calculates_metrics_accurately(self, cache_config):
"""TEST: Cache manager calculates metrics accurately"""
# Arrange
manager = CacheManager(config=cache_config)
-
+
# Create test scenario
content = "Metrics test content " * 50
manager.store("hash_1", content)
manager.store("hash_2", content)
-
+
# Access one entry to create hit/miss data
manager.get("hash_1")
manager.get("nonexistent") # Miss
-
+
# Act
metrics = manager.get_metrics()
-
+
# Assert
assert isinstance(metrics, CacheMetrics)
assert metrics.total_entries == 2
@@ -232,7 +239,7 @@ def test_cache_manager_calculates_metrics_accurately(self, cache_config):
class TestCachePerformance:
"""Test suite for cache performance requirements."""
-
+
@pytest.mark.performance
def test_cache_hit_latency_under_50ms(self, cache_config, performance_timer):
"""TEST: Cache hits achieve target latency under 50ms"""
@@ -241,120 +248,133 @@ def test_cache_hit_latency_under_50ms(self, cache_config, performance_timer):
content = "Performance test content " * 100
query_hash = "perf_test_hash"
manager.store(query_hash, content)
-
+
# Act
with performance_timer() as timer:
result = manager.get(query_hash)
-
+
# Assert
assert result is not None
- assert timer.duration_ms < 50, f"Cache hit took {timer.duration_ms}ms, exceeds 50ms target"
-
+ assert (
+ timer.duration_ms < 50
+ ), f"Cache hit took {timer.duration_ms}ms, exceeds 50ms target"
+
@pytest.mark.performance
- def test_cache_storage_performance_under_10ms(self, cache_config, performance_timer):
+ def test_cache_storage_performance_under_10ms(
+ self, cache_config, performance_timer
+ ):
"""TEST: Cache storage operations complete under 10ms"""
# Arrange
manager = CacheManager(config=cache_config)
content = "Storage performance test " * 100
query_hash = "storage_perf_hash"
-
+
# Act
with performance_timer() as timer:
entry = manager.store(query_hash, content)
-
+
# Assert
assert entry is not None
- assert timer.duration_ms < 10, f"Cache storage took {timer.duration_ms}ms, exceeds 10ms target"
-
+ assert (
+ timer.duration_ms < 10
+ ), f"Cache storage took {timer.duration_ms}ms, exceeds 10ms target"
+
@pytest.mark.performance
def test_concurrent_cache_access_performance(self, cache_config):
"""TEST: Concurrent cache access maintains performance"""
# Arrange
import asyncio
+
manager = CacheManager(config=cache_config)
-
+
# Pre-populate cache
for i in range(10):
content = f"Concurrent test content {i} " * 50
manager.store(f"concurrent_hash_{i}", content)
-
+
async def concurrent_access(hash_id):
start_time = time.perf_counter()
result = manager.get(f"concurrent_hash_{hash_id}")
end_time = time.perf_counter()
return result, (end_time - start_time) * 1000
-
+
# Act
async def run_concurrent_tests():
tasks = [concurrent_access(i % 10) for i in range(50)]
return await asyncio.gather(*tasks)
-
+
results = asyncio.run(run_concurrent_tests())
-
+
# Assert
for result, latency_ms in results:
assert result is not None
assert latency_ms < 50, f"Concurrent access took {latency_ms}ms"
-
+
@pytest.mark.performance
def test_cache_memory_efficiency(self, cache_config):
"""TEST: Cache maintains memory efficiency"""
# Arrange
manager = CacheManager(config=cache_config)
import sys
-
+
# Measure baseline memory
baseline_size = sys.getsizeof(manager)
-
+
# Add many entries
content = "Memory efficiency test " * 50
for i in range(100):
manager.store(f"memory_hash_{i}", content)
-
+
# Act
current_size = sys.getsizeof(manager)
size_per_entry = (current_size - baseline_size) / 100
-
+
# Assert
# Each entry should be reasonably sized (less than 10KB overhead)
- assert size_per_entry < 10240, f"Cache overhead {size_per_entry} bytes per entry too high"
+ assert (
+ size_per_entry < 10240
+ ), f"Cache overhead {size_per_entry} bytes per entry too high"
class TestCacheIntegration:
"""Test suite for cache integration with other components."""
-
- def test_cache_integration_with_anthropic_client(self, mock_anthropic_client, cache_config):
+
+ def test_cache_integration_with_anthropic_client(
+ self, mock_anthropic_client, cache_config
+ ):
"""TEST: Cache integrates properly with Anthropic client"""
# Arrange
manager = CacheManager(config=cache_config)
query = "What is Q1-2025 revenue?"
query_hash = manager.generate_hash(query)
-
+
# Mock cached response
cached_content = "Cached response: Q1-2025 revenue was $1,234,567.89"
manager.store(query_hash, cached_content)
-
+
# Act
from src.cache.manager import get_cached_response
- with patch('src.cache.manager.cache_manager', manager):
+
+ with patch("src.cache.manager.cache_manager", manager):
response = get_cached_response(query, mock_anthropic_client)
-
+
# Assert
assert response is not None
assert "1,234,567.89" in response
# Anthropic client should not be called for cache hit
mock_anthropic_client.messages.create.assert_not_called()
-
+
def test_cache_warming_improves_performance(self, cache_config, benchmark_queries):
"""TEST: Cache warming improves subsequent performance"""
# Arrange
manager = CacheManager(config=cache_config)
-
+
# Act - Warm cache
from src.cache.manager import warm_cache
- with patch('src.cache.manager.cache_manager', manager):
+
+ with patch("src.cache.manager.cache_manager", manager):
warm_cache(benchmark_queries[:5])
-
+
# Measure performance after warming
warmed_times = []
for query in benchmark_queries[:5]:
@@ -362,30 +382,33 @@ def test_cache_warming_improves_performance(self, cache_config, benchmark_querie
query_hash = manager.generate_hash(query)
result = manager.get(query_hash)
end_time = time.perf_counter()
-
+
if result: # Cache hit
warmed_times.append((end_time - start_time) * 1000)
-
+
# Assert
assert len(warmed_times) > 0, "Cache warming should create cached entries"
avg_warmed_time = sum(warmed_times) / len(warmed_times)
- assert avg_warmed_time < 50, f"Warmed cache average {avg_warmed_time}ms exceeds target"
-
+ assert (
+ avg_warmed_time < 50
+ ), f"Warmed cache average {avg_warmed_time}ms exceeds target"
+
def test_cache_invalidation_on_schema_changes(self, cache_config):
"""TEST: Cache invalidation occurs on schema changes"""
# Arrange
manager = CacheManager(config=cache_config)
-
+
# Store entries with old schema version
old_content = "Old schema content " * 50
manager.store("schema_hash_1", old_content)
manager.store("schema_hash_2", old_content)
-
+
# Act - Simulate schema change
from src.cache.manager import invalidate_on_schema_change
- with patch('src.cache.manager.cache_manager', manager):
+
+ with patch("src.cache.manager.cache_manager", manager):
invalidated_count = invalidate_on_schema_change("Database schema updated")
-
+
# Assert
assert invalidated_count == 2
assert manager.get("schema_hash_1") is None
@@ -394,29 +417,29 @@ def test_cache_invalidation_on_schema_changes(self, cache_config):
class TestCacheMetrics:
"""Test suite for cache metrics and monitoring."""
-
+
def test_cache_metrics_calculation_accuracy(self, cache_config):
"""TEST: Cache metrics calculation is accurate"""
# Arrange
manager = CacheManager(config=cache_config)
-
+
# Create test scenario
content = "Metrics calculation test " * 50
-
+
# Store 5 entries
for i in range(5):
manager.store(f"metrics_hash_{i}", content)
-
+
# Access 3 entries (hits) and 2 non-existent (misses)
for i in range(3):
manager.get(f"metrics_hash_{i}")
-
+
for i in range(2):
manager.get(f"nonexistent_hash_{i}")
-
+
# Act
metrics = manager.get_metrics()
-
+
# Assert
assert metrics.total_entries == 5
assert metrics.cache_hits == 3
@@ -424,31 +447,31 @@ def test_cache_metrics_calculation_accuracy(self, cache_config):
assert metrics.total_requests == 5
assert abs(metrics.hit_rate - 60.0) < 0.1 # 3/5 = 60%
assert abs(metrics.miss_rate - 40.0) < 0.1 # 2/5 = 40%
-
+
def test_cache_metrics_cost_calculation(self, cache_config, performance_targets):
"""TEST: Cache metrics calculate cost savings accurately"""
# Arrange
manager = CacheManager(config=cache_config)
content = "Cost calculation test " * 100
-
+
# Store and access entries
manager.store("cost_hash", content)
manager.get("cost_hash") # Hit
manager.get("miss_hash") # Miss
-
+
# Act
metrics = manager.get_metrics()
-
+
# Assert
- assert hasattr(metrics, 'cost_savings')
- assert hasattr(metrics, 'token_efficiency')
+ assert hasattr(metrics, "cost_savings")
+ assert hasattr(metrics, "token_efficiency")
# Verify cost reduction meets targets
expected_hit_savings = performance_targets["cost_reduction_cache_hit"]
expected_miss_savings = performance_targets["cost_reduction_cache_miss"]
-
- assert metrics.cost_savings['cache_hit_reduction'] >= expected_hit_savings
- assert metrics.cost_savings['cache_miss_reduction'] >= expected_miss_savings
-
+
+ assert metrics.cost_savings["cache_hit_reduction"] >= expected_hit_savings
+ assert metrics.cost_savings["cache_miss_reduction"] >= expected_miss_savings
+
def test_cache_metrics_export_to_json(self, cache_config):
"""TEST: Cache metrics export to JSON format"""
# Arrange
@@ -456,12 +479,12 @@ def test_cache_metrics_export_to_json(self, cache_config):
content = "JSON export test " * 50
manager.store("json_hash", content)
manager.get("json_hash")
-
+
# Act
metrics = manager.get_metrics()
json_metrics = metrics.to_json()
parsed_metrics = json.loads(json_metrics)
-
+
# Assert
assert "total_entries" in parsed_metrics
assert "hit_rate" in parsed_metrics
@@ -475,82 +498,83 @@ def test_cache_metrics_export_to_json(self, cache_config):
@pytest.mark.cache
class TestCacheEdgeCases:
"""Test suite for cache edge cases and error conditions."""
-
+
def test_cache_handles_corrupted_entries(self, cache_config):
"""TEST: Cache handles corrupted entries gracefully"""
# Arrange
manager = CacheManager(config=cache_config)
-
+
# Manually corrupt a cache entry
corrupted_entry = CacheEntry(prefix="test", content="Valid content " * 50)
corrupted_entry.content = None # Corrupt the content
manager.cache["corrupted_hash"] = corrupted_entry
-
+
# Act
result = manager.get("corrupted_hash")
-
+
# Assert
assert result is None # Should handle corruption gracefully
assert "corrupted_hash" not in manager.cache # Should remove corrupted entry
-
+
def test_cache_handles_memory_pressure(self, cache_config):
"""TEST: Cache handles memory pressure situations"""
# Arrange
tight_config = cache_config.copy()
tight_config["max_size"] = "100KB" # Small limit
manager = CacheManager(config=tight_config)
-
+
# Act - Try to store many large entries
large_content = "X" * 1000 # 1KB each
stored_count = 0
-
+
for i in range(200): # Try to store 200KB
try:
manager.store(f"pressure_hash_{i}", large_content)
stored_count += 1
except CacheError:
break
-
+
# Assert
assert stored_count < 200 # Should hit limit before storing all
- assert stored_count > 0 # Should store some entries
-
+ assert stored_count > 0 # Should store some entries
+
# Verify cache is still functional
metrics = manager.get_metrics()
assert metrics.total_entries == stored_count
-
+
def test_cache_handles_concurrent_invalidation(self, cache_config):
"""TEST: Cache handles concurrent invalidation safely"""
# Arrange
import threading
+
manager = CacheManager(config=cache_config)
-
+
# Store test entries
content = "Concurrent invalidation test " * 50
for i in range(10):
manager.store(f"concurrent_hash_{i}", content)
-
+
# Act - Concurrent access and invalidation
results = []
-
+
def access_cache():
for i in range(10):
result = manager.get(f"concurrent_hash_{i}")
results.append(result is not None)
-
+
def invalidate_cache():
manager.invalidate_by_prefix(cache_config["prefix"])
-
+
access_thread = threading.Thread(target=access_cache)
invalidate_thread = threading.Thread(target=invalidate_cache)
-
+
access_thread.start()
invalidate_thread.start()
-
+
access_thread.join()
invalidate_thread.join()
-
+
# Assert
# Should not crash, results may vary due to race conditions
assert len(results) == 10
- assert len(manager.cache) == 0 # All entries should be invalidated
\ No newline at end of file
+ assert len(manager.cache) == 0 # All entries should be invalidated
diff --git a/tests/unit/test_cache_security.py b/tests/unit/test_cache_security.py
index 06a9eda..0bd8cfc 100644
--- a/tests/unit/test_cache_security.py
+++ b/tests/unit/test_cache_security.py
@@ -11,8 +11,13 @@
from typing import Dict, Any, List
from src.cache.security import (
- CacheSecurityValidator, ThreatLevel, SecurityThreat, SecurityScanResult,
- get_security_validator, validate_cache_content_security, sanitize_cache_output
+ CacheSecurityValidator,
+ ThreatLevel,
+ SecurityThreat,
+ SecurityScanResult,
+ get_security_validator,
+ validate_cache_content_security,
+ sanitize_cache_output,
)
from src.cache.config import SecurityConfig
from src.core.errors import SecurityError, ValidationError
@@ -20,22 +25,22 @@
class TestSecurityConfig:
"""Test suite for security configuration."""
-
+
def test_security_config_default_values(self):
"""TEST: Security config initializes with correct defaults"""
config = SecurityConfig()
-
+
assert config.enable_input_validation == True
assert config.enable_output_sanitization == True
assert len(config.sensitive_data_patterns) > 0
assert config.max_content_length == 1048576 # 1MB
-
+
def test_security_config_sensitive_patterns(self):
"""TEST: Security config includes comprehensive sensitive data patterns"""
config = SecurityConfig()
-
+
patterns = " ".join(config.sensitive_data_patterns)
-
+
# Check for common sensitive data patterns
assert "password" in patterns.lower()
assert "api" in patterns.lower()
@@ -45,7 +50,7 @@ def test_security_config_sensitive_patterns(self):
class TestCacheSecurityValidator:
"""Test suite for cache security validator."""
-
+
@pytest.fixture
def security_config(self):
"""Create test security configuration."""
@@ -53,105 +58,105 @@ def security_config(self):
enable_input_validation=True,
enable_output_sanitization=True,
sensitive_data_patterns=[
- r'\bpassword\s*[:=]\s*\S+',
- r'\bapi[_-]?key\s*[:=]\s*\S+',
- r'\b\d{3}-\d{2}-\d{4}\b', # SSN
- r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b', # Credit card
+ r"\bpassword\s*[:=]\s*\S+",
+ r"\bapi[_-]?key\s*[:=]\s*\S+",
+ r"\b\d{3}-\d{2}-\d{4}\b", # SSN
+ r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b", # Credit card
],
- max_content_length=10000
+ max_content_length=10000,
)
-
+
@pytest.fixture
def validator(self, security_config):
"""Create security validator for testing."""
return CacheSecurityValidator(security_config)
-
+
def test_validator_initialization(self, security_config):
"""TEST: Security validator initializes correctly"""
validator = CacheSecurityValidator(security_config)
-
+
assert validator.config == security_config
assert len(validator._compiled_patterns) > 0
assert "sql_injection" in validator._injection_patterns
assert "xss_injection" in validator._injection_patterns
-
+
def test_validator_initialization_with_invalid_patterns(self):
"""TEST: Validator handles invalid regex patterns gracefully"""
config = SecurityConfig(
sensitive_data_patterns=["[invalid_regex"] # Invalid regex
)
-
+
# Should not raise exception
validator = CacheSecurityValidator(config)
assert validator.config == config
-
+
def test_validator_initialization_without_config(self):
"""TEST: Validator loads default config when none provided"""
- with patch('src.cache.security.load_security_config_from_env') as mock_load:
+ with patch("src.cache.security.load_security_config_from_env") as mock_load:
mock_config = SecurityConfig()
mock_load.return_value = mock_config
-
+
validator = CacheSecurityValidator()
assert validator.config == mock_config
-
+
def test_input_validation_passes_clean_content(self, validator):
"""TEST: Input validation passes clean content"""
clean_content = "This is clean content without any threats."
-
+
# Should not raise exception
validator.validate_input(clean_content, "test_source")
-
+
def test_input_validation_blocks_sensitive_data(self, validator):
"""TEST: Input validation blocks content with sensitive data"""
sensitive_content = "User password: secret123 and api_key: abc123def"
-
+
with pytest.raises(SecurityError):
validator.validate_input(sensitive_content, "test_source")
-
+
def test_input_validation_blocks_sql_injection(self, validator):
"""TEST: Input validation blocks SQL injection attempts"""
sql_injection = "SELECT * FROM users; DROP TABLE users; --"
-
+
with pytest.raises(SecurityError):
validator.validate_input(sql_injection, "test_source")
-
+
def test_input_validation_blocks_xss_injection(self, validator):
"""TEST: Input validation blocks XSS injection attempts"""
xss_injection = ""
-
+
with pytest.raises(SecurityError):
validator.validate_input(xss_injection, "test_source")
-
+
def test_input_validation_blocks_command_injection(self, validator):
"""TEST: Input validation blocks command injection attempts"""
command_injection = "normal content; rm -rf / ;"
-
+
with pytest.raises(SecurityError):
validator.validate_input(command_injection, "test_source")
-
+
def test_input_validation_content_too_large(self, validator):
"""TEST: Input validation blocks oversized content"""
large_content = "X" * (validator.config.max_content_length + 1)
-
+
with pytest.raises(ValidationError):
validator.validate_input(large_content, "test_source")
-
+
def test_input_validation_disabled(self, validator):
"""TEST: Input validation can be disabled"""
validator.config.enable_input_validation = False
sensitive_content = "password: secret123"
-
+
# Should not raise exception when disabled
validator.validate_input(sensitive_content, "test_source")
-
+
def test_input_validation_high_threats_allowed(self, validator):
"""TEST: Input validation allows limited high threats"""
# Single high threat should be allowed with warning
content_with_single_threat = "Some content with "
-
+
# Should not raise exception for single high threat
validator.validate_input(content_with_single_threat, "test_source")
-
+
def test_input_validation_multiple_high_threats_blocked(self, validator):
"""TEST: Input validation blocks multiple high threats"""
content_with_multiple_threats = """
@@ -160,240 +165,260 @@ def test_input_validation_multiple_high_threats_blocked(self, validator):
javascript:alert('xss')
"""
-
+
with pytest.raises(SecurityError):
validator.validate_input(content_with_multiple_threats, "test_source")
-
+
def test_output_sanitization_removes_scripts(self, validator):
"""TEST: Output sanitization removes script tags"""
malicious_output = "Safe content more content"
-
+
sanitized = validator.sanitize_output(malicious_output)
-
+
assert ""
-
+
sanitized = validator.sanitize_output(malicious_content)
-
+
# Should return original content when disabled
assert sanitized == malicious_content
-
+
def test_output_sanitization_handles_errors(self, validator):
"""TEST: Output sanitization handles errors gracefully"""
# Mock re.sub to raise exception
- with patch('re.sub', side_effect=Exception("Regex error")):
+ with patch("re.sub", side_effect=Exception("Regex error")):
content = "test content"
sanitized = validator.sanitize_output(content)
-
+
# Should return original content on error
assert sanitized == content
-
+
def test_content_scan_detects_sensitive_data(self, validator):
"""TEST: Content scan detects sensitive data patterns"""
sensitive_content = "User SSN: 123-45-6789 and password: secret"
-
+
result = validator.scan_content(sensitive_content)
-
+
assert result.threats_detected > 0
assert result.overall_risk_level == ThreatLevel.CRITICAL
assert "sensitive_data" in result.threat_breakdown
assert len(result.threats) > 0
assert any(threat.threat_type == "sensitive_data" for threat in result.threats)
-
+
def test_content_scan_detects_sql_injection(self, validator):
"""TEST: Content scan detects SQL injection patterns"""
sql_injection = "'; DROP TABLE users; --"
-
+
result = validator.scan_content(sql_injection)
-
+
assert result.threats_detected > 0
assert "sql_injection" in result.threat_breakdown
assert any(threat.threat_type == "sql_injection" for threat in result.threats)
-
+
def test_content_scan_detects_xss_injection(self, validator):
"""TEST: Content scan detects XSS injection patterns"""
xss_content = ""
-
+
result = validator.scan_content(xss_content)
-
+
assert result.threats_detected > 0
assert "xss_injection" in result.threat_breakdown
-
+
def test_content_scan_detects_command_injection(self, validator):
"""TEST: Content scan detects command injection patterns"""
command_injection = "file.txt; rm -rf /"
-
+
result = validator.scan_content(command_injection)
-
+
assert result.threats_detected > 0
assert "command_injection" in result.threat_breakdown
-
+
def test_content_scan_detects_path_traversal(self, validator):
"""TEST: Content scan detects path traversal patterns"""
path_traversal = "../../../etc/passwd"
-
+
result = validator.scan_content(path_traversal)
-
+
assert result.threats_detected > 0
assert "path_traversal" in result.threat_breakdown
-
+
def test_content_scan_detects_suspicious_encoding(self, validator):
"""TEST: Content scan detects suspicious encoding patterns"""
# Excessive URL encoding
suspicious_content = "data=" + "%20%20%20" * 10 # Lots of encoded spaces
-
+
result = validator.scan_content(suspicious_content)
-
+
assert "suspicious_encoding" in result.threat_breakdown
-
+
def test_content_scan_detects_unicode_bypass(self, validator):
"""TEST: Content scan detects Unicode bypass attempts"""
unicode_content = "\\u0041\\u0042\\u0043" * 5 # Excessive Unicode
-
+
result = validator.scan_content(unicode_content)
-
+
assert "unicode_bypass" in result.threat_breakdown
-
+
def test_content_scan_detects_base64_payload(self, validator):
"""TEST: Content scan detects suspicious base64 content"""
import base64
-
+
# Create base64 encoded script
malicious_script = "script>alert('xss') 0
-
+
def test_content_scan_error_handling(self, validator):
"""TEST: Content scan handles errors gracefully"""
# Mock pattern matching to raise exception
- with patch.object(validator, '_compiled_patterns', {'bad_pattern': Exception("Test error")}):
+ with patch.object(
+ validator, "_compiled_patterns", {"bad_pattern": Exception("Test error")}
+ ):
result = validator.scan_content("test content")
-
+
# Should handle error and return result
assert result.scanned_content_length > 0
assert "scan_error" in result.threat_breakdown
-
+
def test_threat_level_determination_sql_injection(self, validator):
"""TEST: Threat level determination for SQL injection"""
# Critical SQL injection
- critical_level = validator._determine_injection_threat_level("sql_injection", ["DROP TABLE"])
+ critical_level = validator._determine_injection_threat_level(
+ "sql_injection", ["DROP TABLE"]
+ )
assert critical_level == ThreatLevel.CRITICAL
-
+
# High SQL injection
- high_level = validator._determine_injection_threat_level("sql_injection", ["SELECT * FROM"])
+ high_level = validator._determine_injection_threat_level(
+ "sql_injection", ["SELECT * FROM"]
+ )
assert high_level == ThreatLevel.HIGH
-
+
def test_threat_level_determination_xss_injection(self, validator):
"""TEST: Threat level determination for XSS injection"""
# Critical XSS
- critical_level = validator._determine_injection_threat_level("xss_injection", ["Safe content"
-
+
sanitized = sanitize_cache_output(malicious_content)
-
+
assert "