diff --git a/src/ai/work.ts b/src/ai/work.ts index 4664e9e..0f034fb 100644 --- a/src/ai/work.ts +++ b/src/ai/work.ts @@ -17,77 +17,76 @@ async function work(props: WorkProps) { const { session, chatService, signal } = props; - const messages = await workTools({ ...props, session }); - - if (needsToolConfirmation(messages)) + while (true) { - return messages; - } + const messages = await workTools({ ...props, session }); - let aiMessage: AIMessageChunk | undefined = undefined; - let toolProgressMessages: ToolProgressMessage[] = []; - - if (session.streaming) - { - const { result: stream, error } = await tryCatch(chatService.stream(session, signal)); - - if (!stream) + if (needsToolConfirmation(messages)) { - messages.push(new ErrorMessage(`ERROR: ${error?.message || error?.toString() || "llm.stream(...) failed."}`, "error", error)); return messages; } - for await (const chunk of stream) + let aiMessage: AIMessageChunk | undefined = undefined; + let toolProgressMessages: ToolProgressMessage[] = []; + + if (session.streaming) { - aiMessage = aiMessage !== undefined ? concat(aiMessage, chunk) : chunk; - setMessageIsStreaming(aiMessage, true); + const { result: stream, error } = await tryCatch(chatService.stream(session, signal)); - if (!aiMessage?.tool_calls?.length && !aiMessage?.tool_call_chunks?.length) + if (!stream) { - session.setMessages([...messages, aiMessage]); + messages.push(new ErrorMessage(`ERROR: ${error?.message || error?.toString() || "llm.stream(...) failed."}`, "error", error)); + return messages; } - else + + for await (const chunk of stream) { - toolProgressMessages = ToolProgressMessage.createFromChunks(aiMessage.tool_call_chunks); + aiMessage = aiMessage !== undefined ? concat(aiMessage, chunk) : chunk; + setMessageIsStreaming(aiMessage, true); - if (toolProgressMessages.length) + if (!aiMessage?.tool_calls?.length && !aiMessage?.tool_call_chunks?.length) { - session.setMessages([...messages, aiMessage, ...toolProgressMessages]); + session.setMessages([...messages, aiMessage]); + } + else + { + toolProgressMessages = ToolProgressMessage.createFromChunks(aiMessage.tool_call_chunks); + + if (toolProgressMessages.length) + { + session.setMessages([...messages, aiMessage, ...toolProgressMessages]); + } } } } - } - else - { - const { result, error } = await tryCatch(chatService.generate(session, signal)); - - if (!result || error) + else { - messages.push(new ErrorMessage(`ERROR: ${error?.message || error?.toString() || "llm.generate(...) failed."}`, "error", error)); - return messages; - } + const { result, error } = await tryCatch(chatService.generate(session, signal)); - aiMessage = result; - } - - if (aiMessage) - { - setMessageIsStreaming(aiMessage, false); - messages.push(aiMessage); - } + if (!result || error) + { + messages.push(new ErrorMessage(`ERROR: ${error?.message || error?.toString() || "llm.generate(...) failed."}`, "error", error)); + return messages; + } - if (!aiMessage?.tool_calls?.length) - { - return messages; - } + aiMessage = result; + } - // add pending toolCall(s) - toolProgressMessages = aiMessage.tool_calls.map(toolCall => new ToolProgressMessage(toolCall)) || []; - session.setMessages([...messages, ...toolProgressMessages]); + if (aiMessage) + { + setMessageIsStreaming(aiMessage, false); + messages.push(aiMessage); + } - await session.flush(); + if (!aiMessage?.tool_calls?.length) + { + return messages; + } - return work(props); + // add pending toolCall(s) + toolProgressMessages = aiMessage.tool_calls.map(toolCall => new ToolProgressMessage(toolCall)) || []; + session.setMessages([...messages, ...toolProgressMessages]); + } } async function workTools(props: Pick)