Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pkg/concurrent/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ func (s *Slice[V]) Update(index int, f func(V) V) bool {
return true
}

// FindAndUpdate replaces the first element for which predicate reports true
// with the result of applying f to that element. The scan and the write both
// happen while s.mu is held, so predicate and f are invoked under the lock.
// It reports whether a matching element was found and updated.
func (s *Slice[V]) FindAndUpdate(predicate func(V) bool, f func(V) V) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	for idx := range s.values {
		if !predicate(s.values[idx]) {
			continue
		}
		// Only the first match is rewritten; later matches are left alone.
		s.values[idx] = f(s.values[idx])
		return true
	}
	return false
}

func (s *Slice[V]) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down
132 changes: 76 additions & 56 deletions pkg/tools/builtin/todo.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,19 @@ type TodoStorage interface {
All() []Todo
// Len returns the number of todo items.
Len() int
// FindByID returns the index of the todo with the given ID, or -1 if not found.
FindByID(id string) int
// Update modifies the todo at the given index using the provided function.
Update(index int, fn func(Todo) Todo)
// NextID returns a unique, monotonically increasing ID for a new todo.
NextID() int64
// UpdateByID atomically finds a todo by ID and applies fn to it.
// It returns true if the todo was found and updated, false otherwise.
UpdateByID(id string, fn func(Todo) Todo) bool
// Clear removes all todo items.
Clear()
}

// MemoryTodoStorage is an in-memory, concurrency-safe implementation of TodoStorage.
type MemoryTodoStorage struct {
todos *concurrent.Slice[Todo]
todos *concurrent.Slice[Todo]
nextID atomic.Int64
}

func NewMemoryTodoStorage() *MemoryTodoStorage {
Expand All @@ -110,20 +112,23 @@ func (s *MemoryTodoStorage) Add(todo Todo) {
}

func (s *MemoryTodoStorage) All() []Todo {
return s.todos.All()
all := s.todos.All()
if all == nil {
return []Todo{}
}
return all
}

// Len reports how many todos the underlying slice currently holds.
func (s *MemoryTodoStorage) Len() int {
	count := s.todos.Length()
	return count
}

func (s *MemoryTodoStorage) FindByID(id string) int {
_, idx := s.todos.Find(func(t Todo) bool { return t.ID == id })
return idx
func (s *MemoryTodoStorage) NextID() int64 {
return s.nextID.Add(1)
}

func (s *MemoryTodoStorage) Update(index int, fn func(Todo) Todo) {
s.todos.Update(index, fn)
func (s *MemoryTodoStorage) UpdateByID(id string, fn func(Todo) Todo) bool {
return s.todos.FindAndUpdate(func(t Todo) bool { return t.ID == id }, fn)
}

func (s *MemoryTodoStorage) Clear() {
Expand All @@ -146,7 +151,6 @@ func WithStorage(storage TodoStorage) TodoOption {

type todoHandler struct {
storage TodoStorage
nextID atomic.Int64
}

var NewSharedTodoTool = sync.OnceValue(func() *TodoTool { return NewTodoTool() })
Expand Down Expand Up @@ -177,115 +181,131 @@ Track task progress with todos:
// addTodo creates a new todo and adds it to storage.
func (h *todoHandler) addTodo(description string) Todo {
todo := Todo{
ID: fmt.Sprintf("todo_%d", h.nextID.Add(1)),
ID: fmt.Sprintf("todo_%d", h.storage.NextID()),
Description: description,
Status: "pending",
}
h.storage.Add(todo)
return todo
}

// jsonResult builds a ToolCallResult with a JSON-serialized output and current storage as Meta.
func (h *todoHandler) jsonResult(v any) (*tools.ToolCallResult, error) {
// jsonResult builds a ToolCallResult with a JSON-serialized output and allTodos as Meta.
func (h *todoHandler) jsonResult(v any, allTodos []Todo) (*tools.ToolCallResult, error) {
out, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("marshaling todo output: %w", err)
}
return &tools.ToolCallResult{
Output: string(out),
Meta: h.storage.All(),
Meta: allTodos,
}, nil
}

func (h *todoHandler) createTodo(_ context.Context, params CreateTodoArgs) (*tools.ToolCallResult, error) {
created := h.addTodo(params.Description)
allTodos := h.storage.All()
return h.jsonResult(CreateTodoOutput{
Created: created,
AllTodos: h.storage.All(),
Reminder: h.incompleteReminder(),
})
AllTodos: allTodos,
Reminder: incompleteReminder(allTodos),
}, allTodos)
}

func (h *todoHandler) createTodos(_ context.Context, params CreateTodosArgs) (*tools.ToolCallResult, error) {
created := make([]Todo, 0, len(params.Descriptions))
for _, desc := range params.Descriptions {
created = append(created, h.addTodo(desc))
}
allTodos := h.storage.All()
return h.jsonResult(CreateTodosOutput{
Created: created,
AllTodos: h.storage.All(),
Reminder: h.incompleteReminder(),
})
AllTodos: allTodos,
Reminder: incompleteReminder(allTodos),
}, allTodos)
}

func (h *todoHandler) updateTodos(_ context.Context, params UpdateTodosArgs) (*tools.ToolCallResult, error) {
result := UpdateTodosOutput{}
// validTodoStatuses is the set of status strings accepted for a todo.
// Indexing it with any other string yields false.
var validTodoStatuses = func() map[string]bool {
	allowed := []string{"pending", "in-progress", "completed"}
	set := make(map[string]bool, len(allowed))
	for _, status := range allowed {
		set[status] = true
	}
	return set
}()

func (h *todoHandler) updateTodos(_ context.Context, params UpdateTodosArgs) (*tools.ToolCallResult, error) {
for _, update := range params.Updates {
idx := h.storage.FindByID(update.ID)
if idx == -1 {
result.NotFound = append(result.NotFound, update.ID)
if validTodoStatuses[update.Status] {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 CRITICAL: Inverted validation logic

The validation check is inverted. This code checks if the status IS valid and then continues, which means only INVALID statuses will reach the error-handling code below.

Current behavior: Valid statuses like "pending", "in-progress", "completed" will skip the error block and proceed. Invalid statuses like "done" will trigger the error.

Expected behavior: Invalid statuses should trigger the error, valid ones should proceed.

Fix: Add a negation operator:

if !validTodoStatuses[update.Status] {

This bug will cause TestTodoTool_UpdateTodos_InvalidStatus to fail, as the test expects invalid status "done" to be rejected.

continue
}
allTodos := h.storage.All()
res, err := h.jsonResult(UpdateTodosOutput{
AllTodos: allTodos,
Reminder: fmt.Sprintf("invalid status %q for todo %s: must be one of pending, in-progress, completed", update.Status, update.ID),
}, allTodos)
if err != nil {
return nil, err
}
res.IsError = true
return res, nil
}

h.storage.Update(idx, func(t Todo) Todo {
result := UpdateTodosOutput{}

for _, update := range params.Updates {
ok := h.storage.UpdateByID(update.ID, func(t Todo) Todo {
t.Status = update.Status
return t
})
if !ok {
result.NotFound = append(result.NotFound, update.ID)
continue
}
result.Updated = append(result.Updated, update)
}

allTodos := h.storage.All()
result.AllTodos = allTodos
result.Reminder = incompleteReminder(allTodos)

if len(result.NotFound) > 0 && len(result.Updated) == 0 {
res, err := h.jsonResult(result)
res, err := h.jsonResult(result, allTodos)
if err != nil {
return nil, err
}
res.IsError = true
return res, nil
}

result.AllTodos = h.storage.All()
result.Reminder = h.incompleteReminder()

return h.jsonResult(result)
return h.jsonResult(result, allTodos)
}

// incompleteReminder returns a reminder string listing any non-completed todos,
// or an empty string if all are completed (or storage is empty).
func (h *todoHandler) incompleteReminder() string {
all := h.storage.All()
var pending, inProgress []string
// or an empty string if all are completed (or the list is empty).
func incompleteReminder(all []Todo) string {
var b strings.Builder
for _, todo := range all {
var prefix string
switch todo.Status {
case "pending":
pending = append(pending, fmt.Sprintf("[%s] %s", todo.ID, todo.Description))
prefix = " (pending) "
case "in-progress":
inProgress = append(inProgress, fmt.Sprintf("[%s] %s", todo.ID, todo.Description))
prefix = " (in-progress) "
default:
continue
}
}
if len(pending) == 0 && len(inProgress) == 0 {
return ""
}

var b strings.Builder
b.WriteString("The following todos are still incomplete and MUST be completed:")
for _, s := range inProgress {
b.WriteString(" (in-progress) " + s)
}
for _, s := range pending {
b.WriteString(" (pending) " + s)
if b.Len() == 0 {
b.WriteString("The following todos are still incomplete and MUST be completed:")
}
b.WriteString(prefix)
fmt.Fprintf(&b, "[%s] %s", todo.ID, todo.Description)
}
return b.String()
}

func (h *todoHandler) listTodos(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) {
todos := h.storage.All()
if todos == nil {
todos = []Todo{}
}
out := ListTodosOutput{Todos: todos}
out.Reminder = h.incompleteReminder()
return h.jsonResult(out)
out.Reminder = incompleteReminder(todos)
return h.jsonResult(out, todos)
}

func (t *TodoTool) Tools(context.Context) ([]tools.Tool, error) {
Expand Down
28 changes: 28 additions & 0 deletions pkg/tools/builtin/todo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,34 @@ func TestTodoTool_UpdateTodos_AllNotFound(t *testing.T) {
assert.Equal(t, "nonexistent2", output.NotFound[1])
}

// TestTodoTool_UpdateTodos_InvalidStatus checks that an update carrying a
// status outside the accepted set comes back as an error result and leaves
// the stored todo untouched.
func TestTodoTool_UpdateTodos_InvalidStatus(t *testing.T) {
	store := NewMemoryTodoStorage()
	tool := NewTodoTool(WithStorage(store))
	ctx := t.Context()

	// Seed one todo; the update below addresses it as "todo_1".
	_, err := tool.handler.createTodo(ctx, CreateTodoArgs{Description: "Task"})
	require.NoError(t, err)

	args := UpdateTodosArgs{
		Updates: []TodoUpdate{{ID: "todo_1", Status: "done"}},
	}
	res, err := tool.handler.updateTodos(ctx, args)
	require.NoError(t, err)
	assert.True(t, res.IsError)

	var got UpdateTodosOutput
	require.NoError(t, json.Unmarshal([]byte(res.Output), &got))
	// The rejected status is echoed back in the reminder text.
	assert.Contains(t, got.Reminder, "done")
	assert.Empty(t, got.Updated)
	assert.Empty(t, got.NotFound)
	require.Len(t, got.AllTodos, 1)

	// Storage should be unchanged — no partial mutation.
	stored := store.All()
	require.Len(t, stored, 1)
	assert.Equal(t, "pending", stored[0].Status)
}

func TestTodoTool_UpdateTodos_AllCompleted_NoAutoRemoval(t *testing.T) {
storage := NewMemoryTodoStorage()
tool := NewTodoTool(WithStorage(storage))
Expand Down
Loading