From e84e00dba59f657ac91d5a8d7882f7524660a119 Mon Sep 17 00:00:00 2001 From: Christopher Pappas Date: Mon, 10 Nov 2025 16:05:03 -0800 Subject: [PATCH] refactor: simplify agent loop / abort --- src/__tests__/store.test.tsx | 2 +- src/components/AgentChat.tsx | 24 +++-- src/components/UserInput.tsx | 6 ++ src/hooks/useAgent.ts | 96 +++++++++++------ src/mcp/getAgentStatus.ts | 13 +-- src/mcp/runStandaloneAgentLoop.ts | 13 +-- src/store.ts | 6 +- src/utils/MessageQueue.ts | 7 ++ src/utils/mcpServerSelectionAgent.ts | 12 +-- src/utils/runAgentLoop.ts | 149 ++++++++++----------------- 10 files changed, 169 insertions(+), 159 deletions(-) diff --git a/src/__tests__/store.test.tsx b/src/__tests__/store.test.tsx index a465eb4..b48c866 100644 --- a/src/__tests__/store.test.tsx +++ b/src/__tests__/store.test.tsx @@ -48,7 +48,7 @@ describe("Store", () => { expect(getState().sessionId).toBeUndefined() expect(getState().stats).toBeUndefined() expect(getState().pendingToolPermission).toBeUndefined() - expect(getState().abortController).toBeInstanceOf(AbortController) + expect(getState().abortController).toBeUndefined() }) test("should have MessageQueue instance", () => { diff --git a/src/components/AgentChat.tsx b/src/components/AgentChat.tsx index c9c099d..67388ed 100644 --- a/src/components/AgentChat.tsx +++ b/src/components/AgentChat.tsx @@ -1,5 +1,3 @@ -import { Box, Text, useInput } from "ink" -import Spinner from "ink-spinner" import { ChatHeader } from "components/ChatHeader" import { Markdown } from "components/Markdown" import { Stats } from "components/Stats" @@ -8,6 +6,8 @@ import { ToolUses } from "components/ToolUses" import { UserInput } from "components/UserInput" import { useAgent } from "hooks/useAgent" import { useMcpClient } from "hooks/useMcpClient" +import { Box, Text, useInput } from "ink" +import Spinner from "ink-spinner" import { AgentStore } from "store" export const AgentChat: React.FC = () => { @@ -92,21 +92,23 @@ export const AgentChat: React.FC = () => { case state.isProcessing: { return ( - - - + <> + + + + + {" Agent is thinking..."} - {" Agent is thinking..."} - - ) - } - default: { - return + + + ) } } })()} + + ) diff --git a/src/components/UserInput.tsx b/src/components/UserInput.tsx index fbf81e4..b680026 100644 --- a/src/components/UserInput.tsx +++ b/src/components/UserInput.tsx @@ -35,6 +35,12 @@ export const UserInput: React.FC = () => { }) actions.sendMessage(value) + + // Slight delay just in case user has aborted request via second message + setTimeout(() => { + actions.setIsProcessing(true) + }, 100) + reset() } diff --git a/src/hooks/useAgent.ts b/src/hooks/useAgent.ts index 34bb9f9..2c02bbf 100644 --- a/src/hooks/useAgent.ts +++ b/src/hooks/useAgent.ts @@ -1,48 +1,64 @@ -import { useEffect, useRef } from "react" +import { useCallback, useEffect, useRef } from "react" import { AgentStore } from "store" +import { log } from "utils/logger" import { messageTypes, runAgentLoop } from "utils/runAgentLoop" export function useAgent() { const messageQueue = AgentStore.useStoreState((state) => state.messageQueue) - const sessionId = AgentStore.useStoreState((state) => state.sessionId) const config = AgentStore.useStoreState((state) => state.config) - const abortController = AgentStore.useStoreState( - (state) => state.abortController - ) const actions = AgentStore.useStoreActions((actions) => actions) const currentAssistantMessageRef = useRef("") - const abortControllerRef = useRef(abortController) + const sessionIdRef = useRef(undefined) + const abortControllerRef = useRef(undefined) + const connectedServersRef = useRef>(new Set()) - // Update ref when abort controller changes - abortControllerRef.current = abortController + const runQuery = useCallback( + async (userMessage: string) => { + if (abortControllerRef.current) { + log("[useAgent] Aborting existing query for new message:", userMessage) - useEffect(() => { - const streamEnabled = config.stream ?? false - - const runAgent = async () => { - const { agentLoop } = await runAgentLoop({ - messageQueue, - sessionId, - config, - abortControllerRef, - onToolPermissionRequest: (toolName, input) => { - actions.setPendingToolPermission({ toolName, input }) - }, - onServerConnection: (status) => { - actions.addChatHistoryEntry({ - type: "message", - role: "system", - content: status, - }) - }, - setIsProcessing: actions.setIsProcessing, - }) + // When a new message comes in, always abort the old one and start fresh + abortControllerRef.current.abort() + } + + // Create fresh abort controller for this query + const abortController = new AbortController() + abortControllerRef.current = abortController + actions.setAbortController(abortController) + + const streamEnabled = config.stream ?? false try { + const agentLoop = runAgentLoop({ + abortController, + config, + connectedServers: connectedServersRef.current, + messageQueue, + onToolPermissionRequest: (toolName, input) => { + actions.setPendingToolPermission({ toolName, input }) + }, + onServerConnection: (status) => { + actions.addChatHistoryEntry({ + type: "message", + role: "system", + content: status, + }) + }, + sessionId: sessionIdRef.current, + setIsProcessing: actions.setIsProcessing, + userMessage, + }) + for await (const message of agentLoop) { + if (abortController.signal.aborted) { + log("[useAgent] Query was aborted, stopping message processing") + return + } + switch (true) { case message.type === messageTypes.SYSTEM && message.subtype === messageTypes.INIT: { + sessionIdRef.current = message.session_id actions.setSessionId(message.session_id) actions.handleMcpServerStatus(message.mcp_servers) @@ -127,6 +143,12 @@ export function useAgent() { } } } catch (error) { + if (error instanceof Error && error.name === "AbortError") { + actions.setIsProcessing(false) + return + } + + // Handle other errors if ( error instanceof Error && !error.message.includes("process aborted by user") @@ -136,8 +158,18 @@ export function useAgent() { actions.setIsProcessing(false) } - } + }, + [config, messageQueue, actions] + ) - runAgent() - }, []) + // Start listening for new messages from input + useEffect(() => { + const unsubscribe = messageQueue.subscribe((userMessage) => { + setTimeout(() => { + runQuery(userMessage) + }, 0) + }) + + return unsubscribe + }, [messageQueue, runQuery]) } diff --git a/src/mcp/getAgentStatus.ts b/src/mcp/getAgentStatus.ts index eb04b6d..451192a 100644 --- a/src/mcp/getAgentStatus.ts +++ b/src/mcp/getAgentStatus.ts @@ -6,16 +6,17 @@ import { messageTypes, runAgentLoop } from "utils/runAgentLoop" export const getAgentStatus = async (mcpServer?: McpServer) => { const config = await loadConfig() const messageQueue = new MessageQueue() + const abortController = new AbortController() + const connectedServers = new Set() - const { agentLoop } = await runAgentLoop({ - messageQueue, + const agentLoop = runAgentLoop({ + abortController, config, + connectedServers, + messageQueue, + userMessage: "status", }) - await new Promise((resolve) => setTimeout(resolve, 0)) - - messageQueue.sendMessage("status") - for await (const message of agentLoop) { if ( message.type === messageTypes.SYSTEM && diff --git a/src/mcp/runStandaloneAgentLoop.ts b/src/mcp/runStandaloneAgentLoop.ts index 2836f31..0d5978e 100644 --- a/src/mcp/runStandaloneAgentLoop.ts +++ b/src/mcp/runStandaloneAgentLoop.ts @@ -25,12 +25,17 @@ export const runStandaloneAgentLoop = async ({ const messageQueue = new MessageQueue() const streamEnabled = config.stream ?? false - const { agentLoop, connectedServers } = await runAgentLoop({ + const connectedServers = existingConnectedServers ?? new Set() + const abortController = new AbortController() + + const agentLoop = runAgentLoop({ + abortController, additionalSystemPrompt, config, - existingConnectedServers, + connectedServers, messageQueue, sessionId, + userMessage: prompt, onServerConnection: async (status) => { await mcpServer.sendLoggingMessage({ level: "info", @@ -42,10 +47,6 @@ export const runStandaloneAgentLoop = async ({ }, }) - await new Promise((resolve) => setTimeout(resolve, 0)) - - messageQueue.sendMessage(prompt) - let finalResponse = "" let assistantMessage = "" diff --git a/src/store.ts b/src/store.ts index 51b8011..d4000b1 100644 --- a/src/store.ts +++ b/src/store.ts @@ -112,7 +112,7 @@ export interface StoreModel { } export const AgentStore = createContextStore({ - abortController: new AbortController(), + abortController: undefined, chatHistory: [], config: null as unknown as AgentChatConfig, currentAssistantMessage: "", @@ -142,7 +142,9 @@ export const AgentStore = createContextStore({ // Actions abortRequest: action((state) => { state.abortController?.abort() - state.abortController = new AbortController() + state.abortController = undefined + state.currentAssistantMessage = "" + state.stats = "User aborted the request." state.isProcessing = false }), diff --git a/src/utils/MessageQueue.ts b/src/utils/MessageQueue.ts index 9dec3e9..9950269 100644 --- a/src/utils/MessageQueue.ts +++ b/src/utils/MessageQueue.ts @@ -33,4 +33,11 @@ export class MessageQueue extends EventEmitter { hasPendingRequests(): boolean { return this.listenerCount("message") > 0 } + + subscribe(callback: (message: string) => void): () => void { + this.on("message", callback) + return () => { + this.off("message", callback) + } + } } diff --git a/src/utils/mcpServerSelectionAgent.ts b/src/utils/mcpServerSelectionAgent.ts index 2a8073e..d663152 100644 --- a/src/utils/mcpServerSelectionAgent.ts +++ b/src/utils/mcpServerSelectionAgent.ts @@ -12,7 +12,7 @@ import { messageTypes } from "./runAgentLoop" interface SelectMcpServersOptions { abortController?: AbortController agents?: Record - alreadyConnectedServers?: Set + connectedServers?: Set enabledMcpServers: Record | undefined onServerConnection?: (status: string) => void sessionId?: string @@ -22,7 +22,7 @@ interface SelectMcpServersOptions { export const selectMcpServers = async ({ abortController, agents, - alreadyConnectedServers = new Set(), + connectedServers = new Set(), enabledMcpServers, onServerConnection, sessionId, @@ -38,7 +38,7 @@ export const selectMcpServers = async ({ log( "[mcpServerSelectionAgent] Already connected:", - Array.from(alreadyConnectedServers).join(", ") || "none" + Array.from(connectedServers).join(", ") || "none" ) const serverCapabilities = Object.entries(enabledMcpServers) @@ -159,7 +159,7 @@ Examples: log("[mcpServerSelectionAgent] Selected MCP servers:", selectedServers) const newServers = selectedServers.filter( - (server) => !alreadyConnectedServers.has(server.toLowerCase()) + (server) => !connectedServers.has(server.toLowerCase()) ) if (newServers.length > 0) { @@ -172,7 +172,7 @@ Examples: } const allServers = new Set([ - ...Array.from(alreadyConnectedServers), + ...Array.from(connectedServers), ...selectedServers, ]) @@ -212,7 +212,7 @@ Examples: // Update the connected servers set with new servers newServers.forEach((server) => { - alreadyConnectedServers.add(server.toLowerCase()) + connectedServers.add(server.toLowerCase()) }) return { diff --git a/src/utils/runAgentLoop.ts b/src/utils/runAgentLoop.ts index 9d2e0b7..dae5334 100644 --- a/src/utils/runAgentLoop.ts +++ b/src/utils/runAgentLoop.ts @@ -1,7 +1,6 @@ -import { query, type SDKUserMessage } from "@anthropic-ai/claude-agent-sdk" +import { query } from "@anthropic-ai/claude-agent-sdk" import type { AgentChatConfig } from "store" import { createCanUseTool } from "utils/canUseTool" -import { createSDKAgents } from "utils/createAgent" import { getEnabledMcpServers } from "utils/getEnabledMcpServers" import { buildSystemPrompt } from "utils/getPrompt" import { getDisallowedTools } from "utils/getToolInfo" @@ -23,28 +22,32 @@ export const contentTypes = { } as const export interface RunAgentLoopOptions { - abortControllerRef?: { current: AbortController | undefined } + abortController: AbortController additionalSystemPrompt?: string config: AgentChatConfig - existingConnectedServers?: Set + connectedServers: Set messageQueue: MessageQueue onServerConnection?: (status: string) => void onToolPermissionRequest?: (toolName: string, input: any) => void sessionId?: string setIsProcessing?: (value: boolean) => void + userMessage: string } -export const runAgentLoop = async ({ - abortControllerRef, +export async function* runAgentLoop({ + abortController, additionalSystemPrompt, config, - existingConnectedServers, + connectedServers, messageQueue, onServerConnection, onToolPermissionRequest, - sessionId: initialSessionId, + sessionId, setIsProcessing, -}: RunAgentLoopOptions) => { + userMessage, +}: RunAgentLoopOptions) { + log("\n[runAgentLoop] USER:", userMessage, "\n") + const canUseTool = createCanUseTool({ messageQueue, onToolPermissionRequest, @@ -54,97 +57,53 @@ export const runAgentLoop = async ({ const disallowedTools = getDisallowedTools(config) const enabledMcpServers = getEnabledMcpServers(config.mcpServers) - let currentSessionId = initialSessionId - const connectedServers = existingConnectedServers ?? new Set() - - async function* agentLoop() { - while (true) { - const userMessage = await messageQueue.waitForMessage() - - log("\n[runAgentLoop] USER:", userMessage, "\n") - - if (userMessage.toLowerCase() === "exit") { - break - } - - if (!userMessage.trim()) { - continue - } - - const systemPrompt = await buildSystemPrompt({ - config, - additionalSystemPrompt, - connectedServers, - }) - - const { mcpServers } = await selectMcpServers({ - abortController: abortControllerRef?.current, - agents: config.agents, - alreadyConnectedServers: connectedServers, - enabledMcpServers, - onServerConnection, - sessionId: currentSessionId, - userMessage, - }) - - const agents = await createSDKAgents(config.agents) - - try { - const turnResponse = query({ - prompt: (async function* () { - yield { - type: "user" as const, - session_id: currentSessionId || "", - message: { - role: "user" as const, - content: userMessage, - }, - } as SDKUserMessage - })(), - options: { - model: config.model ?? "haiku", - permissionMode: config.permissionMode ?? "default", - includePartialMessages: config.stream ?? false, - mcpServers, - agents, - abortController: abortControllerRef?.current, - canUseTool, - systemPrompt, - disallowedTools, - resume: currentSessionId, - }, - }) + const { mcpServers } = await selectMcpServers({ + abortController, + agents: config.agents, + connectedServers, + enabledMcpServers, + onServerConnection, + sessionId, + userMessage, + }) - for await (const message of turnResponse) { - if ( - message.type === messageTypes.SYSTEM && - message.subtype === messageTypes.INIT - ) { - log( - "[runAgentLoop] [messageTypes.INIT]:", - JSON.stringify(message, null, 2) - ) + const systemPrompt = await buildSystemPrompt({ + additionalSystemPrompt, + config, + connectedServers, + }) - currentSessionId = message.session_id - } + const turnResponse = query({ + prompt: userMessage, + options: { + abortController, + canUseTool, + disallowedTools, + includePartialMessages: config.stream ?? false, + mcpServers, + model: config.model ?? "haiku", + permissionMode: config.permissionMode ?? "default", + resume: sessionId, + systemPrompt, + }, + }) - yield message + for await (const message of turnResponse) { + if ( + message.type === messageTypes.SYSTEM && + message.subtype === messageTypes.INIT + ) { + log( + "[runAgentLoop] [messageTypes.INIT]:", + JSON.stringify(message, null, 2) + ) + } - // If we hit a RESULT, this turn is complete - if (message.type === messageTypes.RESULT) { - break - } - } - } catch (error) { - log("[ERROR] [runAgentLoop] Query aborted or failed:", error) + yield message - // Continue to next message - } + // If we hit a RESULT, this turn is complete + if (message.type === messageTypes.RESULT) { + break } } - - return { - agentLoop: agentLoop(), - connectedServers, - } }