diff --git a/backend/internal/ai/agent/agent.go b/backend/internal/ai/agent/agent.go index f66d906..24ac1ff 100644 --- a/backend/internal/ai/agent/agent.go +++ b/backend/internal/ai/agent/agent.go @@ -11,6 +11,7 @@ package agent import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -20,6 +21,15 @@ import ( "github.com/Harsh-2002/Orva/backend/internal/database" ) +// maxRepeatedToolFailures is the loop-breaker budget: when the model re-issues +// the SAME tool call (identical name + arguments) and it fails this many times +// in consecutive iterations, the turn stops with an error instead of running +// out the full MaxIterations budget. Keying on the exact call keeps legitimate +// exploration alive (probing three different missing functions is three +// different calls) while catching the real pathology — a stuck model retrying +// one broken call, classically after its arguments were corrupted. +const maxRepeatedToolFailures = 3 + // Tool is one callable the agent may offer the model, plus the metadata that // drives approval gating and UI grouping. type Tool struct { @@ -174,6 +184,10 @@ func (r *Runner) advance(ctx context.Context, sink Sink, convID, principalID str return err } + // Loop-breaker state: per-call-signature consecutive-failure streaks (see + // maxRepeatedToolFailures). + failStreak := map[string]int{} + for iter := 0; iter < r.cfg.MaxIterations; iter++ { // Bail if the client disconnected (or the request was cancelled). // Otherwise the loop would keep calling the billed provider and @@ -220,7 +234,9 @@ func (r *Runner) advance(ctx context.Context, sink Sink, convID, principalID str toolCalls = ev.ToolCalls usage = ev.Usage case llm.EventError: - r.finalizeAssistant(assistant, textB.String(), thinkB.String(), nil, nil) + // Keep any usage the provider reported before the cut so the + // turn's token accounting survives the error path. + r.finalizeAssistant(assistant, textB.String(), thinkB.String(), nil, ev.Usage) _ = sink.Send("error", map[string]any{"message": ev.Err.Error()}) return ev.Err } @@ -247,11 +263,30 @@ func (r *Runner) advance(ctx context.Context, sink Sink, convID, principalID str // Process the requested tool calls. Read-only / auto ones run now; // anything gated pauses the turn for approval. - paused, results := r.processToolCalls(ctx, sink, convID, assistant.ID, toolCalls) + paused, results, failedSigs := r.processToolCalls(ctx, sink, convID, assistant.ID, toolCalls) if paused { _ = sink.Send("awaiting_approval", map[string]any{"conversation_id": convID}) return nil } + // Advance the per-signature failure streaks: bump the calls that failed + // this round, reset every signature that didn't recur. + failedNow := make(map[string]bool, len(failedSigs)) + for _, sig := range failedSigs { + failedNow[sig] = true + failStreak[sig]++ + } + for sig := range failStreak { + if !failedNow[sig] { + delete(failStreak, sig) + } + } + for _, sig := range failedSigs { + if failStreak[sig] >= maxRepeatedToolFailures { + err := fmt.Errorf("stopped: the model re-issued the same failing tool call %d times in a row and is not making progress — retry the message or rephrase the request", failStreak[sig]) + _ = sink.Send("error", map[string]any{"message": err.Error()}) + return err + } + } // Append the persisted tool results, then loop and let the model continue. history = append(history, results...) } @@ -264,8 +299,10 @@ func (r *Runner) advance(ctx context.Context, sink Sink, convID, principalID str // non-gated ones immediately, and returns true if any call is awaiting approval // (the turn must pause). It also returns the tool-result messages it persisted, // in order, so the caller can append them to the in-memory history without a -// re-read. When paused, results is irrelevant (the turn returns). -func (r *Runner) processToolCalls(ctx context.Context, sink Sink, convID, msgID string, calls []llm.ToolCall) (paused bool, results []llm.Message) { +// re-read, and the signatures (name + arguments) of the calls that failed — +// the caller's loop-breaker signal. When paused, results is irrelevant (the +// turn returns). +func (r *Runner) processToolCalls(ctx context.Context, sink Sink, convID, msgID string, calls []llm.ToolCall) (paused bool, results []llm.Message, failedSigs []string) { for _, c := range calls { meta, known := r.byName[c.Name] requiresApproval := known && r.approvalNeeded(meta) @@ -291,9 +328,23 @@ func (r *Runner) processToolCalls(ctx context.Context, sink Sink, convID, msgID Destructive: destructive, } _ = r.store.InsertToolCall(row) + // Embed the args verbatim only when they're valid JSON. Invalid args + // (truncated stream) would fail the frame's own marshal and the sink + // degrades the whole payload to {} — losing the id and name the UI + // needs to render the (failing) call. Fall back to a JSON string, + // capped: the call is about to be failed anyway, so don't push a huge + // broken blob down the SSE pipe just to label it. + argsPayload := json.RawMessage(row.Args) + if !json.Valid(argsPayload) { + s := row.Args + if len(s) > 2048 { + s = s[:2048] + "…" + } + argsPayload = json.RawMessage(mustJSON(s)) + } _ = sink.Send("tool_call", map[string]any{ "id": row.ID, "call_id": c.ID, "name": c.Name, "group": group, - "args": json.RawMessage(row.Args), "requires_approval": requiresApproval, + "args": argsPayload, "requires_approval": requiresApproval, }) if requiresApproval { @@ -301,11 +352,14 @@ func (r *Runner) processToolCalls(ctx context.Context, sink Sink, convID, msgID continue // wait for the user; do NOT run it } r.runToolCall(ctx, sink, convID, row) + if row.Status == "failed" { + failedSigs = append(failedSigs, c.Name+"\x00"+row.Args) + } // Record the result for the in-memory history (mirrors persistToolResult: // a role=tool message keyed by the call id). results = append(results, llm.Message{Role: llm.RoleTool, Content: row.Result, ToolCallID: row.CallID}) } - return paused, results + return paused, results, failedSigs } // runToolCall dispatches one (approved or auto) tool call, persists the @@ -318,10 +372,22 @@ func (r *Runner) runToolCall(ctx context.Context, sink Sink, convID string, row } started := time.Now().UTC() row.StartedAt = &started - row.Status = "running" - _ = r.store.UpdateToolCall(row) - out, err := r.dispatch(ctx, row.ToolName, json.RawMessage(emptyToObj(row.Args))) + // Refuse to dispatch arguments that aren't valid JSON. They can't be a + // real call — the most common cause is a provider stream truncated mid + // tool call — and the dispatcher would only fail later with a confusing + // tool-specific unmarshal error. Failing here gives the model (and the + // audit row) a precise, self-correctable signal. (The doomed call also + // skips the interim status=running write — it goes straight to failed.) + var out json.RawMessage + var err error + if raw := json.RawMessage(emptyToObj(row.Args)); !json.Valid(raw) { + err = errors.New("invalid tool arguments: not valid JSON (the model response may have been truncated) — re-issue the call with complete arguments") + } else { + row.Status = "running" + _ = r.store.UpdateToolCall(row) + out, err = r.dispatch(ctx, row.ToolName, raw) + } var resultJSON string if err != nil { resultJSON = mustJSON(map[string]string{"error": err.Error()}) @@ -340,9 +406,16 @@ func (r *Runner) runToolCall(ctx context.Context, sink Sink, convID string, row row.Result = resultJSON _ = r.store.UpdateToolCall(row) r.persistToolResult(convID, row, resultJSON) + // Same {}-degradation hazard as the tool_call frame: a dispatcher that + // returns non-JSON output would fail the frame's marshal and the sink + // would drop the id/status the UI needs to settle the call's spinner. + resultPayload := json.RawMessage(resultJSON) + if !json.Valid(resultPayload) { + resultPayload = json.RawMessage(mustJSON(resultJSON)) + } _ = sink.Send("tool_result", map[string]any{ "id": row.ID, "call_id": row.CallID, "status": row.Status, - "result": json.RawMessage(resultJSON), + "result": resultPayload, }) } diff --git a/backend/internal/ai/agent/agent_test.go b/backend/internal/ai/agent/agent_test.go new file mode 100644 index 0000000..dae8a3f --- /dev/null +++ b/backend/internal/ai/agent/agent_test.go @@ -0,0 +1,155 @@ +package agent + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/Harsh-2002/Orva/backend/internal/ai/llm" + "github.com/Harsh-2002/Orva/backend/internal/database" +) + +// ─── fakes ─────────────────────────────────────────────────────────────────── + +type fakeStore struct { + messages []*database.AIMessage + toolCalls []*database.AIToolCall +} + +func (s *fakeStore) InsertMessage(m *database.AIMessage) error { + s.messages = append(s.messages, m) + return nil +} +func (s *fakeStore) UpdateMessage(id, content, parts, tokenUsage string) error { return nil } +func (s *fakeStore) ListMessages(conversationID string, sinceSeq int) ([]*database.AIMessage, error) { + return s.messages, nil +} +func (s *fakeStore) InsertToolCall(t *database.AIToolCall) error { + s.toolCalls = append(s.toolCalls, t) + return nil +} +func (s *fakeStore) GetToolCall(id string) (*database.AIToolCall, error) { + return nil, errors.New("not found") +} +func (s *fakeStore) UpdateToolCall(t *database.AIToolCall) error { return nil } +func (s *fakeStore) ListToolCalls(conversationID string) ([]*database.AIToolCall, error) { + return s.toolCalls, nil +} +func (s *fakeStore) TouchConversation(id string) error { return nil } + +type fakeSink struct{ events []string } + +func (s *fakeSink) Send(event string, data any) error { + s.events = append(s.events, event) + return nil +} + +func testRunner(dispatch Dispatcher, tools ...Tool) *Runner { + return New(nil, tools, dispatch, &fakeStore{}, Config{ApprovalPolicy: "auto"}) +} + +func readOnlyTool(name string) Tool { + return Tool{Def: llm.ToolDef{Name: name}, Group: "test", Perm: "read", ReadOnly: true} +} + +// ─── invalid-argument guard (truncated tool calls) ────────────────────────── + +// TestRunToolCallRejectsInvalidArgsWithoutDispatch: arguments that aren't +// valid JSON (the signature of a provider stream truncated mid tool call) +// must fail fast with a self-correctable error — and never reach the +// dispatcher, whose tool-specific unmarshal error would obscure the cause. +func TestRunToolCallRejectsInvalidArgsWithoutDispatch(t *testing.T) { + dispatched := false + r := testRunner(func(ctx context.Context, name string, args json.RawMessage) (json.RawMessage, error) { + dispatched = true + return json.RawMessage(`{}`), nil + }, readOnlyTool("list_functions")) + + row := &database.AIToolCall{ + ConversationID: "conv1", CallID: "call_0", + ToolName: "list_functions", + Args: `{"name":"demo","runt`, // truncated mid-stream + } + r.runToolCall(context.Background(), &fakeSink{}, "conv1", row) + + if dispatched { + t.Error("dispatcher must not be called with invalid JSON arguments") + } + if row.Status != "failed" { + t.Errorf("status = %q, want failed", row.Status) + } + if !strings.Contains(row.Result, "invalid tool arguments") { + t.Errorf("result should explain the invalid arguments, got %q", row.Result) + } +} + +// TestRunToolCallValidArgsStillDispatch: the guard must not get in the way of +// well-formed calls (including empty args, which normalize to {}). +func TestRunToolCallValidArgsStillDispatch(t *testing.T) { + for _, args := range []string{`{"limit":5}`, "", " "} { + dispatched := false + r := testRunner(func(ctx context.Context, name string, args json.RawMessage) (json.RawMessage, error) { + dispatched = true + return json.RawMessage(`{"ok":true}`), nil + }, readOnlyTool("list_functions")) + + row := &database.AIToolCall{ConversationID: "c", CallID: "call_0", ToolName: "list_functions", Args: args} + r.runToolCall(context.Background(), &fakeSink{}, "c", row) + + if !dispatched { + t.Errorf("args %q: dispatcher should have been called", args) + } + if row.Status != "succeeded" { + t.Errorf("args %q: status = %q, want succeeded", args, row.Status) + } + } +} + +// ─── loop-breaker signal ───────────────────────────────────────────────────── + +// TestProcessToolCallsFailedSignatures: the third return value feeds the +// agent's loop-breaker; it must carry the signature (name + args) of exactly +// the calls that failed, so the breaker can track per-call streaks — a +// succeeding companion call must not mask a repeatedly failing one, and +// distinct failing probes must not be conflated into one signature. +func TestProcessToolCallsFailedSignatures(t *testing.T) { + succeedFor := func(okName string) Dispatcher { + return func(ctx context.Context, name string, args json.RawMessage) (json.RawMessage, error) { + if name == okName { + return json.RawMessage(`{}`), nil + } + return nil, errors.New("boom") + } + } + calls := []llm.ToolCall{ + {ID: "c1", Name: "alpha", Arguments: `{"x":1}`}, + {ID: "c2", Name: "beta", Arguments: `{}`}, + } + tools := []Tool{readOnlyTool("alpha"), readOnlyTool("beta")} + + r := testRunner(succeedFor("beta"), tools...) + _, _, sigs := r.processToolCalls(context.Background(), &fakeSink{}, "c", "m", calls) + if len(sigs) != 1 { + t.Fatalf("exactly the failing call must be reported, got %d signatures", len(sigs)) + } + if sigs[0] != "alpha\x00"+`{"x":1}` { + t.Errorf("signature must be name+args, got %q", sigs[0]) + } + + r = testRunner(succeedFor("alpha"), tools...) + _, _, sigs = r.processToolCalls(context.Background(), &fakeSink{}, "c", "m", calls) + if len(sigs) != 1 || sigs[0] != "beta\x00{}" { + t.Errorf("expected only beta's signature, got %v", sigs) + } + + allOK := func(ctx context.Context, name string, args json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`{}`), nil + } + r = testRunner(allOK, tools...) + _, _, sigs = r.processToolCalls(context.Background(), &fakeSink{}, "c", "m", calls) + if len(sigs) != 0 { + t.Errorf("no failures → no signatures, got %v", sigs) + } +} diff --git a/backend/internal/ai/llm/llm.go b/backend/internal/ai/llm/llm.go index bf81c49..7eddced 100644 --- a/backend/internal/ai/llm/llm.go +++ b/backend/internal/ai/llm/llm.go @@ -73,73 +73,105 @@ func (c *Client) Stream(ctx context.Context, req Request) (<-chan Event, error) go func() { defer close(out) defer cancel() + pump(ctx, stream, out) + }() + + return out, nil +} - // Tool calls arrive incrementally (per OpenAI streaming): each delta - // carries a tool-call index plus a fragment of the name/arguments. We - // accumulate by index and assemble the complete calls at finish. - byIndex := map[int]*toolAcc{} - var order []int - var finish string - var usage *Usage +// pump consumes one provider stream and translates it into neutral events. +// It always terminates the event stream with exactly one EventDone or +// EventError. Split out of Stream so the termination logic is unit-testable +// with a synthetic chunk channel. +func pump(ctx context.Context, stream chan *schemas.BifrostStreamChunk, out chan<- Event) { + // Tool calls arrive incrementally (per OpenAI streaming): each delta + // carries a tool-call index plus a fragment of the name/arguments. We + // accumulate by index and assemble the complete calls at finish. + byIndex := map[int]*toolAcc{} + var order []int + var finish string + var usage *Usage - for chunk := range stream { - if chunk == nil { - continue + for chunk := range stream { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + // When the request context is already cancelled, the provider error + // is just the wreckage of our own abort (Bifrost may surface the + // cancellation as an error chunk instead of a bare channel close). + // Surface the context error so callers can keep filtering routine + // disconnects out of the error log with errors.Is. + err := error(errors.New(bifrostErr(chunk.BifrostError))) + if ctxErr := ctx.Err(); ctxErr != nil { + err = ctxErr + } + out <- Event{Type: EventError, Err: err} + return + } + resp := chunk.BifrostChatResponse + if resp == nil { + continue + } + if resp.Usage != nil { + usage = &Usage{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, } - if chunk.BifrostError != nil { - out <- Event{Type: EventError, Err: errors.New(bifrostErr(chunk.BifrostError))} - return + } + for i := range resp.Choices { + choice := resp.Choices[i] + if choice.FinishReason != nil && *choice.FinishReason != "" { + finish = *choice.FinishReason } - resp := chunk.BifrostChatResponse - if resp == nil { + sc := choice.ChatStreamResponseChoice + if sc == nil || sc.Delta == nil { continue } - if resp.Usage != nil { - usage = &Usage{ - PromptTokens: resp.Usage.PromptTokens, - CompletionTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, - } + d := sc.Delta + if d.Content != nil && *d.Content != "" { + out <- Event{Type: EventText, Text: *d.Content} } - for i := range resp.Choices { - choice := resp.Choices[i] - if choice.FinishReason != nil && *choice.FinishReason != "" { - finish = *choice.FinishReason - } - sc := choice.ChatStreamResponseChoice - if sc == nil || sc.Delta == nil { - continue - } - d := sc.Delta - if d.Content != nil && *d.Content != "" { - out <- Event{Type: EventText, Text: *d.Content} + if d.Reasoning != nil && *d.Reasoning != "" { + out <- Event{Type: EventThinking, Text: *d.Reasoning} + } + for _, tc := range d.ToolCalls { + idx := int(tc.Index) + a := byIndex[idx] + if a == nil { + a = &toolAcc{} + byIndex[idx] = a + order = append(order, idx) } - if d.Reasoning != nil && *d.Reasoning != "" { - out <- Event{Type: EventThinking, Text: *d.Reasoning} + if tc.ID != nil && *tc.ID != "" { + a.id = *tc.ID } - for _, tc := range d.ToolCalls { - idx := int(tc.Index) - a := byIndex[idx] - if a == nil { - a = &toolAcc{} - byIndex[idx] = a - order = append(order, idx) - } - if tc.ID != nil && *tc.ID != "" { - a.id = *tc.ID - } - if tc.Function.Name != nil && *tc.Function.Name != "" { - a.name = *tc.Function.Name - } - a.args.WriteString(tc.Function.Arguments) + if tc.Function.Name != nil && *tc.Function.Name != "" { + a.name = *tc.Function.Name } + a.args.WriteString(tc.Function.Arguments) } } + } - out <- Event{Type: EventDone, ToolCalls: assembleToolCalls(order, byIndex), FinishReason: finish, Usage: usage} - }() + // A healthy stream always reports a finish reason ("stop", "tool_calls", …) + // before the provider closes it. The channel closing without one means the + // upstream connection was reset or the response was truncated mid-stream — + // reporting Done here would hand the caller a half answer (or half a tool + // call) marked as success. + if finish == "" { + err := ctx.Err() + if err == nil { + err = errors.New("provider stream ended unexpectedly (connection reset or truncated response)") + } + // Usage may have arrived before the cut — pass it along so the turn's + // token accounting survives the error path. + out <- Event{Type: EventError, Err: err, Usage: usage} + return + } - return out, nil + out <- Event{Type: EventDone, ToolCalls: assembleToolCalls(order, byIndex), FinishReason: finish, Usage: usage} } // toolAcc accumulates the streamed fragments of one tool call (its id, name, and @@ -259,12 +291,22 @@ func toBifrostMessage(m Message) schemas.ChatMessage { for _, tc := range m.ToolCalls { id := tc.ID name := tc.Name + // Replay only valid-JSON arguments. A persisted call whose args + // were cut mid-stream (or emitted malformed by the model) would + // otherwise be replayed verbatim on every subsequent iteration + // AND every future turn of the conversation — strict providers + // reject the whole request, permanently bricking the chat. The + // model already saw the failure in the call's tool result. + args := tc.Arguments + if strings.TrimSpace(args) == "" || !json.Valid([]byte(args)) { + args = "{}" + } calls = append(calls, schemas.ChatAssistantMessageToolCall{ Type: ptrStr("function"), ID: &id, Function: schemas.ChatAssistantMessageToolCallFunction{ Name: &name, - Arguments: tc.Arguments, + Arguments: args, }, }) } diff --git a/backend/internal/ai/llm/llm_test.go b/backend/internal/ai/llm/llm_test.go index fa3a434..c5f0ed5 100644 --- a/backend/internal/ai/llm/llm_test.go +++ b/backend/internal/ai/llm/llm_test.go @@ -1,8 +1,12 @@ package llm import ( + "context" + "errors" "strings" "testing" + + "github.com/maximhq/bifrost/core/schemas" ) func acc(id, name, args string) *toolAcc { @@ -39,6 +43,152 @@ func TestAssembleToolCallsSynthesizesID(t *testing.T) { } } +// ─── pump (stream termination) ─────────────────────────────────────────────── + +func textChunk(s string) *schemas.BifrostStreamChunk { + return &schemas.BifrostStreamChunk{ + BifrostChatResponse: &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{{ + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: &schemas.ChatStreamResponseChoiceDelta{Content: &s}, + }, + }}, + }, + } +} + +func finishChunk(reason string) *schemas.BifrostStreamChunk { + return &schemas.BifrostStreamChunk{ + BifrostChatResponse: &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{{FinishReason: &reason}}, + }, + } +} + +func toolFragmentChunk(idx int, id, name, argsFragment string) *schemas.BifrostStreamChunk { + tc := schemas.ChatAssistantMessageToolCall{ + Function: schemas.ChatAssistantMessageToolCallFunction{Arguments: argsFragment}, + } + tc.Index = uint16(idx) + if id != "" { + tc.ID = &id + } + if name != "" { + tc.Function.Name = &name + } + return &schemas.BifrostStreamChunk{ + BifrostChatResponse: &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{{ + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: &schemas.ChatStreamResponseChoiceDelta{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{tc}, + }, + }, + }}, + }, + } +} + +// runPump feeds the given chunks through pump and returns every emitted event. +func runPump(ctx context.Context, chunks ...*schemas.BifrostStreamChunk) []Event { + stream := make(chan *schemas.BifrostStreamChunk, len(chunks)) + for _, c := range chunks { + stream <- c + } + close(stream) + out := make(chan Event, 64) + go func() { + defer close(out) + pump(ctx, stream, out) + }() + var events []Event + for ev := range out { + events = append(events, ev) + } + return events +} + +// TestPumpCleanCompletion: a stream that reports a finish reason ends with +// EventDone carrying it. +func TestPumpCleanCompletion(t *testing.T) { + events := runPump(context.Background(), textChunk("hello "), textChunk("world"), finishChunk("stop")) + last := events[len(events)-1] + if last.Type != EventDone { + t.Fatalf("expected EventDone, got %s (%v)", last.Type, last.Err) + } + if last.FinishReason != "stop" { + t.Errorf("finish reason = %q, want stop", last.FinishReason) + } +} + +// TestPumpRejectsStreamWithoutFinishReason covers the silent-truncation bug: +// a provider reset closes the stream with no finish reason, which must surface +// as EventError — not as a clean EventDone wrapping a half answer. +func TestPumpRejectsStreamWithoutFinishReason(t *testing.T) { + events := runPump(context.Background(), textChunk("half an ans")) // reset mid-response + last := events[len(events)-1] + if last.Type != EventError { + t.Fatalf("truncated stream must end with EventError, got %s", last.Type) + } + if last.Err == nil || !strings.Contains(last.Err.Error(), "stream ended unexpectedly") { + t.Errorf("unexpected error: %v", last.Err) + } + for _, ev := range events { + if ev.Type == EventDone { + t.Error("truncated stream must not emit EventDone") + } + } +} + +// TestPumpRejectsTruncatedToolCall: a reset mid tool call (fragmented JSON +// arguments, no finish reason) must error rather than hand the agent a +// half-assembled call — the trigger for the "stuck in tool calling" loop. +func TestPumpRejectsTruncatedToolCall(t *testing.T) { + events := runPump(context.Background(), + toolFragmentChunk(0, "call_1", "create_function", `{"name":"demo","runt`)) // cut mid-args + last := events[len(events)-1] + if last.Type != EventError { + t.Fatalf("truncated tool-call stream must end with EventError, got %s", last.Type) + } +} + +// TestPumpAssemblesFragmentedToolCall: tool-call arguments arrive as indexed +// fragments across deltas (id/name only on the first); a properly finished +// stream must reassemble them into one complete call. +func TestPumpAssemblesFragmentedToolCall(t *testing.T) { + events := runPump(context.Background(), + toolFragmentChunk(0, "call_1", "create_function", `{"na`), + toolFragmentChunk(0, "", "", `me":"demo"}`), + finishChunk("tool_calls")) + last := events[len(events)-1] + if last.Type != EventDone { + t.Fatalf("expected EventDone, got %s (%v)", last.Type, last.Err) + } + if len(last.ToolCalls) != 1 { + t.Fatalf("expected 1 assembled call, got %d", len(last.ToolCalls)) + } + tc := last.ToolCalls[0] + if tc.ID != "call_1" || tc.Name != "create_function" || tc.Arguments != `{"name":"demo"}` { + t.Errorf("fragments not reassembled: %+v", tc) + } +} + +// TestPumpCancelledContextClassifiedAsCancellation: when the request context +// is cancelled (client disconnect), the terminating error must be the context +// error so callers can keep filtering cancellations out of the error log. +func TestPumpCancelledContextClassifiedAsCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + events := runPump(ctx, textChunk("partial")) + last := events[len(events)-1] + if last.Type != EventError { + t.Fatalf("expected EventError, got %s", last.Type) + } + if !errors.Is(last.Err, context.Canceled) { + t.Errorf("expected context.Canceled, got %v", last.Err) + } +} + // TestIsReasoningError gates the thinking graceful-fallback: only reasoning- // related upstream errors should trigger the strip-and-retry; auth / model / // network failures must surface immediately instead of being masked. diff --git a/backend/internal/ai/llm/types.go b/backend/internal/ai/llm/types.go index 72134e2..7f60268 100644 --- a/backend/internal/ai/llm/types.go +++ b/backend/internal/ai/llm/types.go @@ -88,7 +88,9 @@ type Event struct { // requested this turn (empty if the model produced only text). ToolCalls []ToolCall - // FinishReason and Usage are set on EventDone when the provider reports them. + // FinishReason and Usage are set on EventDone when the provider reports + // them. Usage may also accompany an EventError when the provider reported + // it before the stream was cut, so token accounting survives the error path. FinishReason string Usage *Usage diff --git a/cli/commands/chat.go b/cli/commands/chat.go index fa14eed..3f7b2ae 100644 --- a/cli/commands/chat.go +++ b/cli/commands/chat.go @@ -437,7 +437,9 @@ func (s *chatSession) runTurn(parent context.Context, content string) error { fmt.Fprintln(s.errOut, s.styles.Muted.Render("("+res.note+")")) } if res.errMsg != "" { - s.printError(res.errMsg) + // Return without printing: the REPL prints returned errors itself and + // the one-shot path surfaces them via cobra — printing here too showed + // every stream error twice. return errors.New(res.errMsg) } return nil @@ -607,6 +609,10 @@ func (s *chatSession) drive(resp *http.Response) (turnResult, error) { } _ = json.Unmarshal([]byte(data), &d) res.errMsg = d.Message + // The error path sends no message_end — close out whatever + // streamed so the error doesn't land on the same row as a + // truncated half-answer. + s.finishMessage(text.String(), textStarted && !interleaved) return true, nil } return false, nil diff --git a/frontend/src/stores/ai.js b/frontend/src/stores/ai.js index 63014a8..975f7c6 100644 --- a/frontend/src/stores/ai.js +++ b/frontend/src/stores/ai.js @@ -81,7 +81,11 @@ export const useAIStore = defineStore('ai', () => { if (!activeId.value) return try { const { data } = await apiClient.get(`/ai/conversations/${activeId.value}`) - timeline.value = buildTimeline(data) + // Error items live only in the optimistic timeline (the server has no + // notion of them) — carry them across the rebuild or the ErrorCard for + // a just-failed turn vanishes the moment the post-turn refresh lands. + const errors = timeline.value.filter((it) => it.kind === 'error') + timeline.value = [...buildTimeline(data), ...errors] } catch { /* keep the optimistic timeline on a refresh failure */ } } @@ -205,6 +209,16 @@ export const useAIStore = defineStore('ai', () => { case 'error': pushError(data.message || 'stream error', data.code) streaming.value = false + // The error path sends no message_end, so settle the streaming + // assistant message here too: stop the thinking timer/shimmer and + // release the index so the next turn starts a fresh message. + if (curIdx >= 0) { + patchAssistant((m) => { + const i = m.parts.findIndex((p) => p.type === 'thinking') + if (i >= 0) m.parts[i] = { ...m.parts[i], streaming: false } + }) + } + curIdx = -1 break } }