Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/__tests__/store.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ describe("Store", () => {
expect(getState().sessionId).toBeUndefined()
expect(getState().stats).toBeUndefined()
expect(getState().pendingToolPermission).toBeUndefined()
expect(getState().abortController).toBeInstanceOf(AbortController)
expect(getState().abortController).toBeUndefined()
})

test("should have MessageQueue instance", () => {
Expand Down
24 changes: 13 additions & 11 deletions src/components/AgentChat.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import { Box, Text, useInput } from "ink"
import Spinner from "ink-spinner"
import { ChatHeader } from "components/ChatHeader"
import { Markdown } from "components/Markdown"
import { Stats } from "components/Stats"
Expand All @@ -8,6 +6,8 @@ import { ToolUses } from "components/ToolUses"
import { UserInput } from "components/UserInput"
import { useAgent } from "hooks/useAgent"
import { useMcpClient } from "hooks/useMcpClient"
import { Box, Text, useInput } from "ink"
import Spinner from "ink-spinner"
import { AgentStore } from "store"

export const AgentChat: React.FC = () => {
Expand Down Expand Up @@ -92,21 +92,23 @@ export const AgentChat: React.FC = () => {

case state.isProcessing: {
return (
<Text dimColor>
<Text color="cyan">
<Spinner type="balloon" />
<>
<Text dimColor>
<Text color="cyan">
<Spinner type="balloon" />
</Text>
{" Agent is thinking..."}
</Text>
{" Agent is thinking..."}
</Text>
)
}

default: {
return <UserInput />
<Box marginBottom={1} />
</>
)
}
}
})()}

<UserInput />

<Box marginTop={1} />
</Box>
)
Expand Down
6 changes: 6 additions & 0 deletions src/components/UserInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ export const UserInput: React.FC = () => {
})

actions.sendMessage(value)

// Slight delay just in case user has aborted request via second message
setTimeout(() => {
actions.setIsProcessing(true)
}, 100)

reset()
}

Expand Down
96 changes: 64 additions & 32 deletions src/hooks/useAgent.ts
Original file line number Diff line number Diff line change
@@ -1,48 +1,64 @@
import { useEffect, useRef } from "react"
import { useCallback, useEffect, useRef } from "react"
import { AgentStore } from "store"
import { log } from "utils/logger"
import { messageTypes, runAgentLoop } from "utils/runAgentLoop"

export function useAgent() {
const messageQueue = AgentStore.useStoreState((state) => state.messageQueue)
const sessionId = AgentStore.useStoreState((state) => state.sessionId)
const config = AgentStore.useStoreState((state) => state.config)
const abortController = AgentStore.useStoreState(
(state) => state.abortController
)
const actions = AgentStore.useStoreActions((actions) => actions)
const currentAssistantMessageRef = useRef("")
const abortControllerRef = useRef(abortController)
const sessionIdRef = useRef<string | undefined>(undefined)
const abortControllerRef = useRef<AbortController | undefined>(undefined)
const connectedServersRef = useRef<Set<string>>(new Set())

// Update ref when abort controller changes
abortControllerRef.current = abortController
const runQuery = useCallback(
async (userMessage: string) => {
if (abortControllerRef.current) {
log("[useAgent] Aborting existing query for new message:", userMessage)

useEffect(() => {
const streamEnabled = config.stream ?? false

const runAgent = async () => {
const { agentLoop } = await runAgentLoop({
messageQueue,
sessionId,
config,
abortControllerRef,
onToolPermissionRequest: (toolName, input) => {
actions.setPendingToolPermission({ toolName, input })
},
onServerConnection: (status) => {
actions.addChatHistoryEntry({
type: "message",
role: "system",
content: status,
})
},
setIsProcessing: actions.setIsProcessing,
})
// When a new message comes in, always abort the old one and start fresh
abortControllerRef.current.abort()
}

// Create fresh abort controller for this query
const abortController = new AbortController()
abortControllerRef.current = abortController
actions.setAbortController(abortController)

const streamEnabled = config.stream ?? false

try {
const agentLoop = runAgentLoop({
abortController,
config,
connectedServers: connectedServersRef.current,
messageQueue,
onToolPermissionRequest: (toolName, input) => {
actions.setPendingToolPermission({ toolName, input })
},
onServerConnection: (status) => {
actions.addChatHistoryEntry({
type: "message",
role: "system",
content: status,
})
},
sessionId: sessionIdRef.current,
setIsProcessing: actions.setIsProcessing,
userMessage,
})

for await (const message of agentLoop) {
if (abortController.signal.aborted) {
log("[useAgent] Query was aborted, stopping message processing")
return
}

switch (true) {
case message.type === messageTypes.SYSTEM &&
message.subtype === messageTypes.INIT: {
sessionIdRef.current = message.session_id
actions.setSessionId(message.session_id)
actions.handleMcpServerStatus(message.mcp_servers)

Expand Down Expand Up @@ -127,6 +143,12 @@ export function useAgent() {
}
}
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
actions.setIsProcessing(false)
return
}

// Handle other errors
if (
error instanceof Error &&
!error.message.includes("process aborted by user")
Expand All @@ -136,8 +158,18 @@ export function useAgent() {

actions.setIsProcessing(false)
}
}
},
[config, messageQueue, actions]
)

runAgent()
}, [])
// Start listening for new messages from input
useEffect(() => {
const unsubscribe = messageQueue.subscribe((userMessage) => {
setTimeout(() => {
runQuery(userMessage)
}, 0)
})

return unsubscribe
}, [messageQueue, runQuery])
}
13 changes: 7 additions & 6 deletions src/mcp/getAgentStatus.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ import { messageTypes, runAgentLoop } from "utils/runAgentLoop"
export const getAgentStatus = async (mcpServer?: McpServer) => {
const config = await loadConfig()
const messageQueue = new MessageQueue()
const abortController = new AbortController()
const connectedServers = new Set<string>()

const { agentLoop } = await runAgentLoop({
messageQueue,
const agentLoop = runAgentLoop({
abortController,
config,
connectedServers,
messageQueue,
userMessage: "status",
})

await new Promise((resolve) => setTimeout(resolve, 0))

messageQueue.sendMessage("status")

for await (const message of agentLoop) {
if (
message.type === messageTypes.SYSTEM &&
Expand Down
13 changes: 7 additions & 6 deletions src/mcp/runStandaloneAgentLoop.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@ export const runStandaloneAgentLoop = async ({
const messageQueue = new MessageQueue()
const streamEnabled = config.stream ?? false

const { agentLoop, connectedServers } = await runAgentLoop({
const connectedServers = existingConnectedServers ?? new Set<string>()
const abortController = new AbortController()

const agentLoop = runAgentLoop({
abortController,
additionalSystemPrompt,
config,
existingConnectedServers,
connectedServers,
messageQueue,
sessionId,
userMessage: prompt,
onServerConnection: async (status) => {
await mcpServer.sendLoggingMessage({
level: "info",
Expand All @@ -42,10 +47,6 @@ export const runStandaloneAgentLoop = async ({
},
})

await new Promise((resolve) => setTimeout(resolve, 0))

messageQueue.sendMessage(prompt)

let finalResponse = ""
let assistantMessage = ""

Expand Down
6 changes: 4 additions & 2 deletions src/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ export interface StoreModel {
}

export const AgentStore = createContextStore<StoreModel>({
abortController: new AbortController(),
abortController: undefined,
chatHistory: [],
config: null as unknown as AgentChatConfig,
currentAssistantMessage: "",
Expand Down Expand Up @@ -142,7 +142,9 @@ export const AgentStore = createContextStore<StoreModel>({
// Actions
abortRequest: action((state) => {
state.abortController?.abort()
state.abortController = new AbortController()
state.abortController = undefined
state.currentAssistantMessage = ""
state.stats = "User aborted the request."
state.isProcessing = false
}),

Expand Down
7 changes: 7 additions & 0 deletions src/utils/MessageQueue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,11 @@ export class MessageQueue extends EventEmitter {
hasPendingRequests(): boolean {
return this.listenerCount("message") > 0
}

subscribe(callback: (message: string) => void): () => void {
this.on("message", callback)
return () => {
this.off("message", callback)
}
}
}
12 changes: 6 additions & 6 deletions src/utils/mcpServerSelectionAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { messageTypes } from "./runAgentLoop"
interface SelectMcpServersOptions {
abortController?: AbortController
agents?: Record<string, AgentConfig>
alreadyConnectedServers?: Set<string>
connectedServers?: Set<string>
enabledMcpServers: Record<string, any> | undefined
onServerConnection?: (status: string) => void
sessionId?: string
Expand All @@ -22,7 +22,7 @@ interface SelectMcpServersOptions {
export const selectMcpServers = async ({
abortController,
agents,
alreadyConnectedServers = new Set(),
connectedServers = new Set(),
enabledMcpServers,
onServerConnection,
sessionId,
Expand All @@ -38,7 +38,7 @@ export const selectMcpServers = async ({

log(
"[mcpServerSelectionAgent] Already connected:",
Array.from(alreadyConnectedServers).join(", ") || "none"
Array.from(connectedServers).join(", ") || "none"
)

const serverCapabilities = Object.entries(enabledMcpServers)
Expand Down Expand Up @@ -159,7 +159,7 @@ Examples:
log("[mcpServerSelectionAgent] Selected MCP servers:", selectedServers)

const newServers = selectedServers.filter(
(server) => !alreadyConnectedServers.has(server.toLowerCase())
(server) => !connectedServers.has(server.toLowerCase())
)

if (newServers.length > 0) {
Expand All @@ -172,7 +172,7 @@ Examples:
}

const allServers = new Set([
...Array.from(alreadyConnectedServers),
...Array.from(connectedServers),
...selectedServers,
])

Expand Down Expand Up @@ -212,7 +212,7 @@ Examples:

// Update the connected servers set with new servers
newServers.forEach((server) => {
alreadyConnectedServers.add(server.toLowerCase())
connectedServers.add(server.toLowerCase())
})

return {
Expand Down
Loading