Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 83 additions & 10 deletions backend/internal/ai/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"
Expand All @@ -20,6 +21,15 @@ import (
"github.com/Harsh-2002/Orva/backend/internal/database"
)

// maxRepeatedToolFailures is the loop-breaker budget: when the model re-issues
// the SAME tool call (identical name + arguments) and it fails this many times
// in consecutive iterations, the turn stops with an error instead of running
// out the full MaxIterations budget. Keying on the exact call keeps legitimate
// exploration alive (probing three different missing functions is three
// different calls) while catching the real pathology — a stuck model retrying
// one broken call, classically after its arguments were corrupted.
//
// A call's streak is dropped as soon as an iteration completes without that
// exact call failing, so only back-to-back failures count toward the budget.
const maxRepeatedToolFailures = 3

// Tool is one callable the agent may offer the model, plus the metadata that
// drives approval gating and UI grouping.
type Tool struct {
Expand Down Expand Up @@ -174,6 +184,10 @@ func (r *Runner) advance(ctx context.Context, sink Sink, convID, principalID str
return err
}

// Loop-breaker state: per-call-signature consecutive-failure streaks (see
// maxRepeatedToolFailures).
failStreak := map[string]int{}

for iter := 0; iter < r.cfg.MaxIterations; iter++ {
// Bail if the client disconnected (or the request was cancelled).
// Otherwise the loop would keep calling the billed provider and
Expand Down Expand Up @@ -220,7 +234,9 @@ func (r *Runner) advance(ctx context.Context, sink Sink, convID, principalID str
toolCalls = ev.ToolCalls
usage = ev.Usage
case llm.EventError:
r.finalizeAssistant(assistant, textB.String(), thinkB.String(), nil, nil)
// Keep any usage the provider reported before the cut so the
// turn's token accounting survives the error path.
r.finalizeAssistant(assistant, textB.String(), thinkB.String(), nil, ev.Usage)
_ = sink.Send("error", map[string]any{"message": ev.Err.Error()})
return ev.Err
}
Expand All @@ -247,11 +263,30 @@ func (r *Runner) advance(ctx context.Context, sink Sink, convID, principalID str

// Process the requested tool calls. Read-only / auto ones run now;
// anything gated pauses the turn for approval.
paused, results := r.processToolCalls(ctx, sink, convID, assistant.ID, toolCalls)
paused, results, failedSigs := r.processToolCalls(ctx, sink, convID, assistant.ID, toolCalls)
if paused {
_ = sink.Send("awaiting_approval", map[string]any{"conversation_id": convID})
return nil
}
// Advance the per-signature failure streaks: bump the calls that failed
// this round, reset every signature that didn't recur.
failedNow := make(map[string]bool, len(failedSigs))
for _, sig := range failedSigs {
failedNow[sig] = true
failStreak[sig]++
}
for sig := range failStreak {
if !failedNow[sig] {
delete(failStreak, sig)
}
}
for _, sig := range failedSigs {
if failStreak[sig] >= maxRepeatedToolFailures {
err := fmt.Errorf("stopped: the model re-issued the same failing tool call %d times in a row and is not making progress — retry the message or rephrase the request", failStreak[sig])
_ = sink.Send("error", map[string]any{"message": err.Error()})
return err
}
}
// Append the persisted tool results, then loop and let the model continue.
history = append(history, results...)
}
Expand All @@ -264,8 +299,10 @@ func (r *Runner) advance(ctx context.Context, sink Sink, convID, principalID str
// non-gated ones immediately, and returns true if any call is awaiting approval
// (the turn must pause). It also returns the tool-result messages it persisted,
// in order, so the caller can append them to the in-memory history without a
// re-read. When paused, results is irrelevant (the turn returns).
func (r *Runner) processToolCalls(ctx context.Context, sink Sink, convID, msgID string, calls []llm.ToolCall) (paused bool, results []llm.Message) {
// re-read, and the signatures (name + arguments) of the calls that failed —
// the caller's loop-breaker signal. When paused, results is irrelevant (the
// turn returns).
func (r *Runner) processToolCalls(ctx context.Context, sink Sink, convID, msgID string, calls []llm.ToolCall) (paused bool, results []llm.Message, failedSigs []string) {
for _, c := range calls {
meta, known := r.byName[c.Name]
requiresApproval := known && r.approvalNeeded(meta)
Expand All @@ -291,21 +328,38 @@ func (r *Runner) processToolCalls(ctx context.Context, sink Sink, convID, msgID
Destructive: destructive,
}
_ = r.store.InsertToolCall(row)
// Embed the args verbatim only when they're valid JSON. Invalid args
// (truncated stream) would fail the frame's own marshal and the sink
// degrades the whole payload to {} — losing the id and name the UI
// needs to render the (failing) call. Fall back to a JSON string,
// capped: the call is about to be failed anyway, so don't push a huge
// broken blob down the SSE pipe just to label it.
argsPayload := json.RawMessage(row.Args)
if !json.Valid(argsPayload) {
s := row.Args
if len(s) > 2048 {
s = s[:2048] + "…"
}
argsPayload = json.RawMessage(mustJSON(s))
}
_ = sink.Send("tool_call", map[string]any{
"id": row.ID, "call_id": c.ID, "name": c.Name, "group": group,
"args": json.RawMessage(row.Args), "requires_approval": requiresApproval,
"args": argsPayload, "requires_approval": requiresApproval,
})

if requiresApproval {
paused = true
continue // wait for the user; do NOT run it
}
r.runToolCall(ctx, sink, convID, row)
if row.Status == "failed" {
failedSigs = append(failedSigs, c.Name+"\x00"+row.Args)
}
// Record the result for the in-memory history (mirrors persistToolResult:
// a role=tool message keyed by the call id).
results = append(results, llm.Message{Role: llm.RoleTool, Content: row.Result, ToolCallID: row.CallID})
}
return paused, results
return paused, results, failedSigs
}

// runToolCall dispatches one (approved or auto) tool call, persists the
Expand All @@ -318,10 +372,22 @@ func (r *Runner) runToolCall(ctx context.Context, sink Sink, convID string, row
}
started := time.Now().UTC()
row.StartedAt = &started
row.Status = "running"
_ = r.store.UpdateToolCall(row)

out, err := r.dispatch(ctx, row.ToolName, json.RawMessage(emptyToObj(row.Args)))
// Refuse to dispatch arguments that aren't valid JSON. They can't be a
// real call — the most common cause is a provider stream truncated mid
// tool call — and the dispatcher would only fail later with a confusing
// tool-specific unmarshal error. Failing here gives the model (and the
// audit row) a precise, self-correctable signal. (The doomed call also
// skips the interim status=running write — it goes straight to failed.)
var out json.RawMessage
var err error
if raw := json.RawMessage(emptyToObj(row.Args)); !json.Valid(raw) {
err = errors.New("invalid tool arguments: not valid JSON (the model response may have been truncated) — re-issue the call with complete arguments")
} else {
row.Status = "running"
_ = r.store.UpdateToolCall(row)
out, err = r.dispatch(ctx, row.ToolName, raw)
}
var resultJSON string
if err != nil {
resultJSON = mustJSON(map[string]string{"error": err.Error()})
Expand All @@ -340,9 +406,16 @@ func (r *Runner) runToolCall(ctx context.Context, sink Sink, convID string, row
row.Result = resultJSON
_ = r.store.UpdateToolCall(row)
r.persistToolResult(convID, row, resultJSON)
// Same {}-degradation hazard as the tool_call frame: a dispatcher that
// returns non-JSON output would fail the frame's marshal and the sink
// would drop the id/status the UI needs to settle the call's spinner.
resultPayload := json.RawMessage(resultJSON)
if !json.Valid(resultPayload) {
resultPayload = json.RawMessage(mustJSON(resultJSON))
}
_ = sink.Send("tool_result", map[string]any{
"id": row.ID, "call_id": row.CallID, "status": row.Status,
"result": json.RawMessage(resultJSON),
"result": resultPayload,
})
}

Expand Down
155 changes: 155 additions & 0 deletions backend/internal/ai/agent/agent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package agent

import (
"context"
"encoding/json"
"errors"
"strings"
"testing"

"github.com/Harsh-2002/Orva/backend/internal/ai/llm"
"github.com/Harsh-2002/Orva/backend/internal/database"
)

// ─── fakes ───────────────────────────────────────────────────────────────────

// fakeStore is the in-memory store handed to the Runner under test. It records
// inserted messages and tool calls; every update is accepted and discarded.
type fakeStore struct {
	messages  []*database.AIMessage
	toolCalls []*database.AIToolCall
}

// InsertMessage appends msg to the recorded messages.
func (fs *fakeStore) InsertMessage(msg *database.AIMessage) error {
	fs.messages = append(fs.messages, msg)
	return nil
}

// UpdateMessage is a no-op that always succeeds.
func (fs *fakeStore) UpdateMessage(id, content, parts, tokenUsage string) error {
	return nil
}

// ListMessages returns every recorded message; the conversation and sequence
// filters are ignored.
func (fs *fakeStore) ListMessages(conversationID string, sinceSeq int) ([]*database.AIMessage, error) {
	return fs.messages, nil
}

// InsertToolCall appends call to the recorded tool calls.
func (fs *fakeStore) InsertToolCall(call *database.AIToolCall) error {
	fs.toolCalls = append(fs.toolCalls, call)
	return nil
}

// GetToolCall never finds anything: it always reports "not found".
func (fs *fakeStore) GetToolCall(id string) (*database.AIToolCall, error) {
	return nil, errors.New("not found")
}

// UpdateToolCall is a no-op that always succeeds.
func (fs *fakeStore) UpdateToolCall(call *database.AIToolCall) error {
	return nil
}

// ListToolCalls returns every recorded tool call; the conversation filter is
// ignored.
func (fs *fakeStore) ListToolCalls(conversationID string) ([]*database.AIToolCall, error) {
	return fs.toolCalls, nil
}

// TouchConversation is a no-op that always succeeds.
func (fs *fakeStore) TouchConversation(id string) error {
	return nil
}

// fakeSink captures the name of every event sent to it, in send order. The
// event payloads are dropped.
type fakeSink struct {
	events []string
}

// Send records the event name and always succeeds.
func (fk *fakeSink) Send(event string, data any) error {
	fk.events = append(fk.events, event)
	return nil
}

// testRunner builds a Runner wired to dispatch and the given tools. It is
// backed by a fresh fakeStore, passes nil as New's first argument, and uses
// the "auto" approval policy.
func testRunner(dispatch Dispatcher, tools ...Tool) *Runner {
	store := &fakeStore{}
	cfg := Config{ApprovalPolicy: "auto"}
	return New(nil, tools, dispatch, store, cfg)
}

// readOnlyTool returns a Tool called name in the "test" group, with "read"
// permission and the ReadOnly flag set.
func readOnlyTool(name string) Tool {
	def := llm.ToolDef{Name: name}
	return Tool{
		Def:      def,
		Group:    "test",
		Perm:     "read",
		ReadOnly: true,
	}
}

// ─── invalid-argument guard (truncated tool calls) ──────────────────────────

// TestRunToolCallRejectsInvalidArgsWithoutDispatch feeds runToolCall a call
// whose arguments are not valid JSON — what a provider stream cut off in the
// middle of a tool call looks like. The call must be marked failed with an
// error that names the invalid arguments, and the dispatcher must never run:
// its tool-specific unmarshal error would only hide the real cause.
func TestRunToolCallRejectsInvalidArgsWithoutDispatch(t *testing.T) {
	var reached bool
	spy := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) {
		reached = true
		return json.RawMessage(`{}`), nil
	}
	runner := testRunner(spy, readOnlyTool("list_functions"))

	call := &database.AIToolCall{
		ConversationID: "conv1",
		CallID:         "call_0",
		ToolName:       "list_functions",
		Args:           `{"name":"demo","runt`, // cut off mid-stream
	}
	runner.runToolCall(context.Background(), &fakeSink{}, "conv1", call)

	if reached {
		t.Error("dispatcher must not be called with invalid JSON arguments")
	}
	if got := call.Status; got != "failed" {
		t.Errorf("status = %q, want failed", got)
	}
	if got := call.Result; !strings.Contains(got, "invalid tool arguments") {
		t.Errorf("result should explain the invalid arguments, got %q", got)
	}
}

// TestRunToolCallValidArgsStillDispatch checks that the invalid-argument guard
// leaves well-formed calls alone. A JSON object, an empty string and a
// whitespace-only string must each reach the dispatcher and end as succeeded.
func TestRunToolCallValidArgsStillDispatch(t *testing.T) {
	inputs := []string{`{"limit":5}`, "", " "}
	for _, raw := range inputs {
		reached := false
		spy := func(context.Context, string, json.RawMessage) (json.RawMessage, error) {
			reached = true
			return json.RawMessage(`{"ok":true}`), nil
		}
		runner := testRunner(spy, readOnlyTool("list_functions"))

		call := &database.AIToolCall{
			ConversationID: "c",
			CallID:         "call_0",
			ToolName:       "list_functions",
			Args:           raw,
		}
		runner.runToolCall(context.Background(), &fakeSink{}, "c", call)

		if !reached {
			t.Errorf("args %q: dispatcher should have been called", raw)
		}
		if call.Status != "succeeded" {
			t.Errorf("args %q: status = %q, want succeeded", raw, call.Status)
		}
	}
}

// ─── loop-breaker signal ─────────────────────────────────────────────────────

// TestProcessToolCallsFailedSignatures pins down processToolCalls' third
// return value, the input to the agent's loop-breaker. It must list the
// signature (tool name, NUL byte, raw arguments) of exactly the calls that
// failed: a call that succeeds alongside a failing one must not hide it, and
// two different failing calls must not collapse into a single signature.
func TestProcessToolCallsFailedSignatures(t *testing.T) {
	// onlySucceeds yields a dispatcher that succeeds for winner and fails
	// every other tool with "boom".
	onlySucceeds := func(winner string) Dispatcher {
		return func(_ context.Context, tool string, _ json.RawMessage) (json.RawMessage, error) {
			if tool != winner {
				return nil, errors.New("boom")
			}
			return json.RawMessage(`{}`), nil
		}
	}

	batch := []llm.ToolCall{
		{ID: "c1", Name: "alpha", Arguments: `{"x":1}`},
		{ID: "c2", Name: "beta", Arguments: `{}`},
	}
	registry := []Tool{readOnlyTool("alpha"), readOnlyTool("beta")}

	// failedWith runs the batch through a fresh runner using d and returns
	// the failed-call signatures.
	failedWith := func(d Dispatcher) []string {
		runner := testRunner(d, registry...)
		_, _, failed := runner.processToolCalls(context.Background(), &fakeSink{}, "c", "m", batch)
		return failed
	}

	// Only alpha fails.
	sigs := failedWith(onlySucceeds("beta"))
	if len(sigs) != 1 {
		t.Fatalf("exactly the failing call must be reported, got %d signatures", len(sigs))
	}
	if want := "alpha\x00" + `{"x":1}`; sigs[0] != want {
		t.Errorf("signature must be name+args, got %q", sigs[0])
	}

	// Only beta fails.
	sigs = failedWith(onlySucceeds("alpha"))
	if len(sigs) != 1 || sigs[0] != "beta\x00{}" {
		t.Errorf("expected only beta's signature, got %v", sigs)
	}

	// Nothing fails.
	alwaysOK := Dispatcher(func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) {
		return json.RawMessage(`{}`), nil
	})
	if sigs = failedWith(alwaysOK); len(sigs) != 0 {
		t.Errorf("no failures → no signatures, got %v", sigs)
	}
}
Loading
Loading