diff --git a/pkg/concurrent/slice.go b/pkg/concurrent/slice.go index 539c523ff..bee11541d 100644 --- a/pkg/concurrent/slice.go +++ b/pkg/concurrent/slice.go @@ -92,6 +92,21 @@ func (s *Slice[V]) Update(index int, f func(V) V) bool { return true } +// FindAndUpdate atomically finds the first element matching the predicate +// and applies f to it. It returns true if an element was found and updated. +func (s *Slice[V]) FindAndUpdate(predicate func(V) bool, f func(V) V) bool { + s.mu.Lock() + defer s.mu.Unlock() + + for i, v := range s.values { + if predicate(v) { + s.values[i] = f(s.values[i]) + return true + } + } + return false +} + func (s *Slice[V]) Clear() { s.mu.Lock() defer s.mu.Unlock() diff --git a/pkg/tools/builtin/todo.go b/pkg/tools/builtin/todo.go index 6d412a716..1651cb231 100644 --- a/pkg/tools/builtin/todo.go +++ b/pkg/tools/builtin/todo.go @@ -86,17 +86,19 @@ type TodoStorage interface { All() []Todo // Len returns the number of todo items. Len() int - // FindByID returns the index of the todo with the given ID, or -1 if not found. - FindByID(id string) int - // Update modifies the todo at the given index using the provided function. - Update(index int, fn func(Todo) Todo) + // NextID returns a unique, monotonically increasing ID for a new todo. + NextID() int64 + // UpdateByID atomically finds a todo by ID and applies fn to it. + // It returns true if the todo was found and updated, false otherwise. + UpdateByID(id string, fn func(Todo) Todo) bool // Clear removes all todo items. Clear() } // MemoryTodoStorage is an in-memory, concurrency-safe implementation of TodoStorage. type MemoryTodoStorage struct { - todos *concurrent.Slice[Todo] + todos *concurrent.Slice[Todo] + nextID atomic.Int64 } func NewMemoryTodoStorage() *MemoryTodoStorage { @@ -110,20 +112,23 @@ func (s *MemoryTodoStorage) Add(todo Todo) { } func (s *MemoryTodoStorage) All() []Todo { - return s.todos.All() + all := s.todos.All() + if all == nil { + return []Todo{} + } + return all } func (s *MemoryTodoStorage) Len() int { return s.todos.Length() } -func (s *MemoryTodoStorage) FindByID(id string) int { - _, idx := s.todos.Find(func(t Todo) bool { return t.ID == id }) - return idx +func (s *MemoryTodoStorage) NextID() int64 { + return s.nextID.Add(1) } -func (s *MemoryTodoStorage) Update(index int, fn func(Todo) Todo) { - s.todos.Update(index, fn) +func (s *MemoryTodoStorage) UpdateByID(id string, fn func(Todo) Todo) bool { + return s.todos.FindAndUpdate(func(t Todo) bool { return t.ID == id }, fn) } func (s *MemoryTodoStorage) Clear() { @@ -146,7 +151,6 @@ func WithStorage(storage TodoStorage) TodoOption { type todoHandler struct { storage TodoStorage - nextID atomic.Int64 } var NewSharedTodoTool = sync.OnceValue(func() *TodoTool { return NewTodoTool() }) @@ -177,7 +181,7 @@ Track task progress with todos: // addTodo creates a new todo and adds it to storage. func (h *todoHandler) addTodo(description string) Todo { todo := Todo{ - ID: fmt.Sprintf("todo_%d", h.nextID.Add(1)), + ID: fmt.Sprintf("todo_%d", h.storage.NextID()), Description: description, Status: "pending", } @@ -185,25 +189,26 @@ func (h *todoHandler) addTodo(description string) Todo { return todo } -// jsonResult builds a ToolCallResult with a JSON-serialized output and current storage as Meta. -func (h *todoHandler) jsonResult(v any) (*tools.ToolCallResult, error) { +// jsonResult builds a ToolCallResult with a JSON-serialized output and allTodos as Meta. +func (h *todoHandler) jsonResult(v any, allTodos []Todo) (*tools.ToolCallResult, error) { out, err := json.Marshal(v) if err != nil { return nil, fmt.Errorf("marshaling todo output: %w", err) } return &tools.ToolCallResult{ Output: string(out), - Meta: h.storage.All(), + Meta: allTodos, }, nil } func (h *todoHandler) createTodo(_ context.Context, params CreateTodoArgs) (*tools.ToolCallResult, error) { created := h.addTodo(params.Description) + allTodos := h.storage.All() return h.jsonResult(CreateTodoOutput{ Created: created, - AllTodos: h.storage.All(), - Reminder: h.incompleteReminder(), - }) + AllTodos: allTodos, + Reminder: incompleteReminder(allTodos), + }, allTodos) } func (h *todoHandler) createTodos(_ context.Context, params CreateTodosArgs) (*tools.ToolCallResult, error) { @@ -211,32 +216,58 @@ func (h *todoHandler) createTodos(_ context.Context, params CreateTodosArgs) (*t for _, desc := range params.Descriptions { created = append(created, h.addTodo(desc)) } + allTodos := h.storage.All() return h.jsonResult(CreateTodosOutput{ Created: created, - AllTodos: h.storage.All(), - Reminder: h.incompleteReminder(), - }) + AllTodos: allTodos, + Reminder: incompleteReminder(allTodos), + }, allTodos) } -func (h *todoHandler) updateTodos(_ context.Context, params UpdateTodosArgs) (*tools.ToolCallResult, error) { - result := UpdateTodosOutput{} +// validTodoStatuses defines the set of allowed todo statuses. +var validTodoStatuses = map[string]bool{ + "pending": true, + "in-progress": true, + "completed": true, +} +func (h *todoHandler) updateTodos(_ context.Context, params UpdateTodosArgs) (*tools.ToolCallResult, error) { for _, update := range params.Updates { - idx := h.storage.FindByID(update.ID) - if idx == -1 { - result.NotFound = append(result.NotFound, update.ID) + if validTodoStatuses[update.Status] { continue } + allTodos := h.storage.All() + res, err := h.jsonResult(UpdateTodosOutput{ + AllTodos: allTodos, + Reminder: fmt.Sprintf("invalid status %q for todo %s: must be one of pending, in-progress, completed", update.Status, update.ID), + }, allTodos) + if err != nil { + return nil, err + } + res.IsError = true + return res, nil + } - h.storage.Update(idx, func(t Todo) Todo { + result := UpdateTodosOutput{} + + for _, update := range params.Updates { + ok := h.storage.UpdateByID(update.ID, func(t Todo) Todo { t.Status = update.Status return t }) + if !ok { + result.NotFound = append(result.NotFound, update.ID) + continue + } result.Updated = append(result.Updated, update) } + allTodos := h.storage.All() + result.AllTodos = allTodos + result.Reminder = incompleteReminder(allTodos) + if len(result.NotFound) > 0 && len(result.Updated) == 0 { - res, err := h.jsonResult(result) + res, err := h.jsonResult(result, allTodos) if err != nil { return nil, err } @@ -244,48 +275,37 @@ func (h *todoHandler) updateTodos(_ context.Context, params UpdateTodosArgs) (*t return res, nil } - result.AllTodos = h.storage.All() - result.Reminder = h.incompleteReminder() - - return h.jsonResult(result) + return h.jsonResult(result, allTodos) } // incompleteReminder returns a reminder string listing any non-completed todos, -// or an empty string if all are completed (or storage is empty). -func (h *todoHandler) incompleteReminder() string { - all := h.storage.All() - var pending, inProgress []string +// or an empty string if all are completed (or the list is empty). +func incompleteReminder(all []Todo) string { + var b strings.Builder for _, todo := range all { + var prefix string switch todo.Status { case "pending": - pending = append(pending, fmt.Sprintf("[%s] %s", todo.ID, todo.Description)) + prefix = " (pending) " case "in-progress": - inProgress = append(inProgress, fmt.Sprintf("[%s] %s", todo.ID, todo.Description)) + prefix = " (in-progress) " + default: + continue } - } - if len(pending) == 0 && len(inProgress) == 0 { - return "" - } - - var b strings.Builder - b.WriteString("The following todos are still incomplete and MUST be completed:") - for _, s := range inProgress { - b.WriteString(" (in-progress) " + s) - } - for _, s := range pending { - b.WriteString(" (pending) " + s) + if b.Len() == 0 { + b.WriteString("The following todos are still incomplete and MUST be completed:") + } + b.WriteString(prefix) + fmt.Fprintf(&b, "[%s] %s", todo.ID, todo.Description) } return b.String() } func (h *todoHandler) listTodos(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) { todos := h.storage.All() - if todos == nil { - todos = []Todo{} - } out := ListTodosOutput{Todos: todos} - out.Reminder = h.incompleteReminder() - return h.jsonResult(out) + out.Reminder = incompleteReminder(todos) + return h.jsonResult(out, todos) } func (t *TodoTool) Tools(context.Context) ([]tools.Tool, error) { diff --git a/pkg/tools/builtin/todo_test.go b/pkg/tools/builtin/todo_test.go index b17da8103..f35fd45b5 100644 --- a/pkg/tools/builtin/todo_test.go +++ b/pkg/tools/builtin/todo_test.go @@ -230,6 +230,34 @@ func TestTodoTool_UpdateTodos_AllNotFound(t *testing.T) { assert.Equal(t, "nonexistent2", output.NotFound[1]) } +func TestTodoTool_UpdateTodos_InvalidStatus(t *testing.T) { + storage := NewMemoryTodoStorage() + tool := NewTodoTool(WithStorage(storage)) + + _, err := tool.handler.createTodo(t.Context(), CreateTodoArgs{Description: "Task"}) + require.NoError(t, err) + + result, err := tool.handler.updateTodos(t.Context(), UpdateTodosArgs{ + Updates: []TodoUpdate{ + {ID: "todo_1", Status: "done"}, + }, + }) + require.NoError(t, err) + assert.True(t, result.IsError) + + var output UpdateTodosOutput + require.NoError(t, json.Unmarshal([]byte(result.Output), &output)) + assert.Contains(t, output.Reminder, "done") + assert.Empty(t, output.Updated) + assert.Empty(t, output.NotFound) + require.Len(t, output.AllTodos, 1) + + // Storage should be unchanged — no partial mutation. + todos := storage.All() + require.Len(t, todos, 1) + assert.Equal(t, "pending", todos[0].Status) +} + func TestTodoTool_UpdateTodos_AllCompleted_NoAutoRemoval(t *testing.T) { storage := NewMemoryTodoStorage() tool := NewTodoTool(WithStorage(storage))