diff --git a/README.md b/README.md index 07ce5a9..e6eb85b 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,6 @@ make dev Additional MCP servers are configured in `agent-chat-cli.config.yaml` and prompts added within the `prompts` folder. -Optionally, MCP servers can be lazy-loaded via chat inference, which is useful if you have many MCP servers or MCP servers with many tools; set `mcp_server_inference: true` to enable it. - ## Development - Install pre-commit hooks via [pre-commit](https://pre-commit.com/) diff --git a/agent-chat-cli.config.yaml b/agent-chat-cli.config.yaml index c58cb4c..270e8b2 100644 --- a/agent-chat-cli.config.yaml +++ b/agent-chat-cli.config.yaml @@ -8,10 +8,6 @@ model: haiku # Enable streaming include_partial_messages: true -# Enable dynamic/lazy MCP server inference. Useful if one has many MCP servers or -# many tools, or is cost conscious about loading everything up front. -mcp_server_inference: false - # Global tool restrictions disallowed_tools: ["Bash"] diff --git a/src/agent_chat_cli/components/header.py b/src/agent_chat_cli/components/header.py index b22adfd..8395397 100644 --- a/src/agent_chat_cli/components/header.py +++ b/src/agent_chat_cli/components/header.py @@ -4,6 +4,7 @@ from agent_chat_cli.components.spacer import Spacer from agent_chat_cli.utils.config import load_config +from agent_chat_cli.utils.mcp_server_status import MCPServerStatus class Header(Widget): @@ -18,7 +19,8 @@ def compose(self) -> ComposeResult: ) yield Label( - f"[dim]Available MCP Servers:[/dim] {mcp_servers}", + f"[dim]Available MCP Servers: {mcp_servers}[/dim]", + id="header-mcp-servers", ) if agents: @@ -35,3 +37,29 @@ def compose(self) -> ComposeResult: id="header-instructions", classes="header-instructions", ) + + def on_mount(self) -> None: + MCPServerStatus.subscribe(self._handle_mcp_server_status) + + def on_unmount(self) -> None: + MCPServerStatus.unsubscribe(self._handle_mcp_server_status) + + def _handle_mcp_server_status(self) -> None: + 
config = load_config() + server_names = list(config.mcp_servers.keys()) + + server_parts = [] + for name in server_names: + is_connected = MCPServerStatus.is_connected(name) + + if is_connected: + server_parts.append(f"{name}") + else: + # Error connecting to MCP + server_parts.append(f"[#ffa2dc][strike]{name}[/strike][/]") + + mcp_servers = ", ".join(server_parts) + markup = f"[dim]Available MCP Servers:[/dim] {mcp_servers}" + + label = self.query_one("#header-mcp-servers", Label) + label.update(markup) diff --git a/src/agent_chat_cli/core/agent_loop.py b/src/agent_chat_cli/core/agent_loop.py index 76cb65d..2dc8d6f 100644 --- a/src/agent_chat_cli/core/agent_loop.py +++ b/src/agent_chat_cli/core/agent_loop.py @@ -8,6 +8,7 @@ ) from claude_agent_sdk.types import ( AssistantMessage, + Message, SystemMessage, TextBlock, ToolUseBlock, @@ -23,8 +24,8 @@ get_sdk_config, ) from agent_chat_cli.utils.enums import AgentMessageType, ContentType, ControlCommand -from agent_chat_cli.core.mcp_inference import infer_mcp_servers from agent_chat_cli.utils.logger import log_json +from agent_chat_cli.utils.mcp_server_status import MCPServerStatus if TYPE_CHECKING: from agent_chat_cli.app import AgentChatCLIApp @@ -46,7 +47,6 @@ def __init__( self.config = load_config() self.session_id = session_id self.available_servers = get_available_servers() - self.inferred_servers: set[str] = set() self.client: ClaudeSDKClient @@ -58,78 +58,33 @@ def __init__( self.interrupting = False async def start(self) -> None: - # Boot MCP servers lazily - if self.config.mcp_server_inference: - await self._initialize_client(mcp_servers={}) - else: - # Boot MCP servers all at once - mcp_servers = { - name: config.model_dump() - for name, config in self.available_servers.items() - } - - await self._initialize_client(mcp_servers=mcp_servers) + mcp_servers = { + name: config.model_dump() for name, config in self.available_servers.items() + } + + await self._initialize_client(mcp_servers=mcp_servers) 
self._running = True while self._running: user_input = await self.query_queue.get() - # Check for new convo flags if isinstance(user_input, ControlCommand): if user_input == ControlCommand.NEW_CONVERSATION: - self.inferred_servers.clear() - - await self.client.disconnect() - - # Reset MCP servers based on config settings - if self.config.mcp_server_inference: - await self._initialize_client(mcp_servers={}) - else: - mcp_servers = { - name: config.model_dump() - for name, config in self.available_servers.items() - } - - await self._initialize_client(mcp_servers=mcp_servers) - continue - - # Infer MCP servers based on user messages in chat - if self.config.mcp_server_inference: - inference_result = await infer_mcp_servers( - user_message=user_input, - available_servers=self.available_servers, - inferred_servers=self.inferred_servers, - session_id=self.session_id, - ) - - # If there are new results, create an updated mcp_server list - if inference_result["new_servers"]: - server_list = ", ".join(inference_result["new_servers"]) - - self.app.actions.post_system_message( - f"Connecting to {server_list}..." 
- ) - - await asyncio.sleep(0.1) - - # If there's updates, we reinitialize the agent SDK (with the - # persisted session_id from the turn, stored in the instance) await self.client.disconnect() mcp_servers = { name: config.model_dump() - for name, config in inference_result["selected_servers"].items() + for name, config in self.available_servers.items() } await self._initialize_client(mcp_servers=mcp_servers) + continue self.interrupting = False - # Send query await self.client.query(user_input) - # Wait for messages from Claude async for message in self.client.receive_response(): if self.interrupting: continue @@ -154,7 +109,7 @@ async def _initialize_client(self, mcp_servers: dict) -> None: await self.client.connect() - async def _handle_message(self, message: Any) -> None: + async def _handle_message(self, message: Message) -> None: if isinstance(message, SystemMessage): log_json(message.data) @@ -164,6 +119,9 @@ async def _handle_message(self, message: Any) -> None: # When initializing the chat, we store the session_id for later self.session_id = message.data["session_id"] + # Report connected / error status back to UI + MCPServerStatus.update(message.data["mcp_servers"]) + # Handle streaming messages if hasattr(message, "event"): event = message.event # type: ignore[attr-defined] @@ -215,7 +173,7 @@ async def _can_use_tool( self, tool_name: str, tool_input: dict[str, Any], - context: ToolPermissionContext, + _context: ToolPermissionContext, ) -> PermissionResult: """Agent SDK handler for tool use permissions""" diff --git a/src/agent_chat_cli/core/mcp_inference.py b/src/agent_chat_cli/core/mcp_inference.py deleted file mode 100644 index eaa9f95..0000000 --- a/src/agent_chat_cli/core/mcp_inference.py +++ /dev/null @@ -1,106 +0,0 @@ -from textwrap import dedent -from typing import Any - -from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions -from claude_agent_sdk.types import ResultMessage - -from agent_chat_cli.utils.config import MCPServerConfig - 
-_inference_client: ClaudeSDKClient | None = None - - -async def _get_inference_client( - available_servers: dict[str, MCPServerConfig], -) -> ClaudeSDKClient: - global _inference_client - - if _inference_client is not None: - return _inference_client - - server_descriptions = "\n".join( - [ - f"- {name}: {config.description}" - for name, config in available_servers.items() - ] - ) - - system_prompt = dedent( - f""" - You are an MCP server inference engine. Based on the user's message, determine which MCP servers are needed to fulfill the request. - - Available MCP servers: - {server_descriptions} - - Return ONLY the names of servers that are likely needed for this request. If no specific servers are needed, return an empty array. - - Examples: - - "Show me my GitHub issues" → ["github"] - - "Open a browser tab" → ["chrome"] - - "What's the weather?" → [] - - "Search my Notion workspace and open related GitHub PRs" → ["notion", "github"] - """ - ).strip() - - inference_options = ClaudeAgentOptions( - model="haiku", - output_format={ - "type": "json_schema", - "schema": { - "type": "object", - "properties": { - "servers": { - "type": "array", - "items": {"type": "string"}, - "description": "List of MCP server names to connect to", - } - }, - "required": ["servers"], - }, - }, - system_prompt=system_prompt, - mcp_servers={}, - ) - - _inference_client = ClaudeSDKClient(options=inference_options) - - await _inference_client.connect() - - return _inference_client - - -async def infer_mcp_servers( - user_message: str, - available_servers: dict[str, MCPServerConfig], - inferred_servers: set[str], - session_id: str | None = None, -) -> dict[str, Any]: - if not available_servers: - return {"selected_servers": {}, "new_servers": []} - - client = await _get_inference_client(available_servers) - - selected_server_names: list[str] = [] - - await client.query(user_message) - - async for message in client.receive_response(): - if isinstance(message, ResultMessage): - if 
hasattr(message, "structured_output") and message.structured_output: - selected_server_names = message.structured_output.get("servers", []) - - new_servers = [ - name for name in selected_server_names if name not in inferred_servers - ] - - inferred_servers.update(selected_server_names) - - selected_servers = { - name: available_servers[name] - for name in selected_server_names - if name in available_servers - } - - return { - "selected_servers": selected_servers, - "new_servers": new_servers, - } diff --git a/src/agent_chat_cli/core/styles.tcss b/src/agent_chat_cli/core/styles.tcss index bb7eb66..5675daf 100644 --- a/src/agent_chat_cli/core/styles.tcss +++ b/src/agent_chat_cli/core/styles.tcss @@ -10,7 +10,7 @@ MarkdownH6 { } Screen { - background: transparent; + background: #222; } Spacer { diff --git a/src/agent_chat_cli/docs/architecture.md b/src/agent_chat_cli/docs/architecture.md index 85754c1..268bfb1 100644 --- a/src/agent_chat_cli/docs/architecture.md +++ b/src/agent_chat_cli/docs/architecture.md @@ -26,19 +26,10 @@ Manages the conversation loop with Claude SDK: - Parses SDK messages into structured AgentMessage objects - Emits AgentMessageType events (STREAM_EVENT, ASSISTANT, RESULT) - Manages session persistence via session_id -- Supports dynamic MCP server inference and loading - Implements `_can_use_tool` callback for interactive tool permission requests - Uses `permission_lock` (asyncio.Lock) to serialize parallel permission requests - Manages `permission_response_queue` for user responses to tool permission prompts -#### MCP Server Inference (`system/mcp_inference.py`) -Intelligently determines which MCP servers are needed for each query: -- Uses a persistent Haiku client for fast inference (~1-3s after initial boot) -- Analyzes user queries to infer required servers -- Maintains a cached set of inferred servers across conversation -- Returns only newly needed servers to minimize reconnections -- Can be disabled via `mcp_server_inference: false` 
config option - #### Message Bus (`system/message_bus.py`) Routes agent messages to appropriate UI components: - Handles streaming text updates @@ -76,7 +67,7 @@ Loads and validates YAML configuration: ## Data Flow -### Standard Query Flow (with MCP Inference enabled) +### Standard Query Flow ``` User Input @@ -87,16 +78,7 @@ MessagePosted event → ChatHistory (immediate UI update) ↓ Actions.query(user_input) → AgentLoop.query_queue.put() ↓ -AgentLoop: MCP Server Inference (if enabled) - ↓ -infer_mcp_servers(user_message) → Haiku query - ↓ -If new servers needed: - - Post SYSTEM message ("Connecting to [servers]...") - - Disconnect client - - Reconnect with new servers (preserving session_id) - ↓ -Claude SDK (streaming response with connected MCP tools) +Claude SDK (all enabled servers pre-connected at startup) ↓ AgentLoop._handle_message ↓ @@ -109,22 +91,6 @@ Match on AgentMessageType: - RESULT → Reset thinking indicator ``` -### Query Flow (with MCP Inference disabled) - -``` -User Input - ↓ -UserInput.on_input_submitted - ↓ -MessagePosted event → ChatHistory (immediate UI update) - ↓ -Actions.query(user_input) → AgentLoop.query_queue.put() - ↓ -Claude SDK (all servers pre-connected at startup) - ↓ -[Same as above from _handle_message onwards] -``` - ### Control Commands Flow ``` User Action (ESC, Ctrl+N, "clear", "exit") @@ -188,36 +154,12 @@ Configuration is loaded from `agent-chat-cli.config.yaml`: - **system_prompt**: Base system prompt (supports file paths) - **model**: Claude model to use - **include_partial_messages**: Enable streaming responses (default: true) -- **mcp_server_inference**: Enable dynamic MCP server inference (default: true) - - When `true`: App boots instantly without MCP servers, connects only when needed - - When `false`: All enabled MCP servers load at startup (traditional behavior) - **mcp_servers**: MCP server configurations (filtered by enabled flag) - **agents**: Named agent configurations - **disallowed_tools**: Tool filtering - 
**permission_mode**: Permission handling mode -MCP server prompts are automatically appended to the system prompt. - -### MCP Server Inference - -When `mcp_server_inference: true` (default): - -1. **Fast Boot**: App starts without connecting to any MCP servers -2. **Smart Detection**: Before each query, Haiku analyzes which servers are needed -3. **Dynamic Loading**: Only connects to newly required servers -4. **Session Preservation**: Maintains conversation history when reconnecting with new servers -5. **Performance**: ~1-3s inference latency after initial boot (first query ~8-12s) - -Example config: -```yaml -mcp_server_inference: true # or false to disable - -mcp_servers: - github: - description: "Search code, PRs, issues" - enabled: true - # ... rest of config -``` +MCP server prompts are automatically appended to the system prompt. All enabled MCP servers are loaded at startup. ## Tool Permission System diff --git a/src/agent_chat_cli/utils/config.py b/src/agent_chat_cli/utils/config.py index 6a113f7..e780a85 100644 --- a/src/agent_chat_cli/utils/config.py +++ b/src/agent_chat_cli/utils/config.py @@ -24,7 +24,6 @@ class AgentChatConfig(BaseModel): system_prompt: str model: str include_partial_messages: bool = True - mcp_server_inference: bool = True agents: dict[str, AgentDefinition] = Field(default_factory=dict) mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict) disallowed_tools: list[str] = Field(default_factory=list) @@ -109,6 +108,4 @@ def get_available_servers( def get_sdk_config(config: AgentChatConfig) -> dict: - sdk_config = config.model_dump() - sdk_config.pop("mcp_server_inference", None) - return sdk_config + return config.model_dump() diff --git a/src/agent_chat_cli/utils/mcp_server_status.py b/src/agent_chat_cli/utils/mcp_server_status.py new file mode 100644 index 0000000..64de61e --- /dev/null +++ b/src/agent_chat_cli/utils/mcp_server_status.py @@ -0,0 +1,34 @@ +from typing import Any, Callable + + +class MCPServerStatus: + # 
After the first query is sent, the Claude Agent SDK sends back an init payload + # with various statuses. In it, we can see MCP connection success or failure. + _mcp_servers: list[dict[str, Any]] = [] + + # Register component callbacks that need access to the status + _callbacks: list[Callable[[], None]] = [] + + @classmethod + def update(cls, mcp_servers: list[dict[str, Any]]) -> None: + cls._mcp_servers = mcp_servers + + for callback in cls._callbacks: + callback() + + @classmethod + def is_connected(cls, server_name: str) -> bool: + for server in cls._mcp_servers: + if server.get("name") == server_name: + return server.get("status") == "connected" + + return False + + @classmethod + def subscribe(cls, callback: Callable[[], None]) -> None: + cls._callbacks.append(callback) + + @classmethod + def unsubscribe(cls, callback: Callable[[], None]) -> None: + if callback in cls._callbacks: + cls._callbacks.remove(callback) diff --git a/tests/test_mcp_inference.py b/tests/test_mcp_inference.py deleted file mode 100644 index f1014db..0000000 --- a/tests/test_mcp_inference.py +++ /dev/null @@ -1,60 +0,0 @@ -import asyncio -import time -from dotenv import load_dotenv - -from agent_chat_cli.utils.config import get_available_servers -from agent_chat_cli.core.mcp_inference import infer_mcp_servers, _inference_client - -load_dotenv() - -# TODO: This can be deleted, but keeping here to check if speed is related to anthropics -# servers or something else - - -async def test_inference(): - """ - Tests the overall return times for MCP inference - To run: uv run python tests/test_mcp_inference.py - """ - - print("=== MCP Server Inference Test ===\n") - - available_servers = get_available_servers() - print(f"Available servers: {list(available_servers.keys())}\n") - - inferred_servers: set[str] = set() - - test_queries = [ - "Show me my GitHub issues", - "Open a browser tab", - "What's the weather?", - "Search my GitHub for code related to authentication", - ] - - for user_message
in test_queries: - print(f"Query: {user_message}") - - start_time = time.time() - - result = await infer_mcp_servers( - user_message=user_message, - available_servers=available_servers, - inferred_servers=inferred_servers, - ) - - elapsed = time.time() - start_time - - print(f"Selected servers: {list(result['selected_servers'].keys())}") - print(f"New servers: {result['new_servers']}") - print(f"Time: {elapsed:.2f}s") - print(f"Inferred servers so far: {inferred_servers}") - print("-" * 50 + "\n") - - if _inference_client: - print("Disconnecting inference client...") - await _inference_client.disconnect() - print("Done!") - - -if __name__ == "__main__": - asyncio.run(test_inference())