Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 85 additions & 5 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ type AgentConfig struct {
// today's streaming behavior). A DBOS-backed runner lives in agent/durable/dbos.
StepRunner types.StepRunner

// Store persists the conversation tree. Defaults to nil, which preserves
// today's in-memory-only behavior (the tree lives only for the agent's
// lifetime). When set, each node added during a run is written to the Store
// best-effort (errors are logged, never fatal), and NewAgent can hydrate the
// tree from the Store via WithStore (see WithStore / LoadTreeFromStore).
Store types.Store

// File pipeline configuration.
Resolvers map[string]types.Resolver // URI scheme → Resolver (e.g. "file", "https", "s3")
Extractors map[types.MediaType]types.Extractor // MediaType → Extractor for non-native types
Expand Down Expand Up @@ -105,6 +112,15 @@ func WithStepRunner(r types.StepRunner) AgentOption {
return func(c *AgentConfig) { c.StepRunner = r }
}

// WithStore configures a types.Store so the conversation tree is persisted.
// With no Store (the default) the tree is in-memory only — fully backward
// compatible. When a Store is set, each node added during a run is written
// best-effort; loading a previously-persisted tree is done explicitly via
// LoadTreeFromStore before NewAgent (pass the rebuilt tree with WithTree).
func WithStore(s types.Store) AgentOption {
return func(c *AgentConfig) { c.Store = s }
}

// Agent runs an LLM agent loop with tool execution.
// All conversations are backed by a Tree.
type Agent struct {
Expand Down Expand Up @@ -153,6 +169,16 @@ func NewAgent(cfg AgentConfig, opts ...AgentOption) *Agent {

a := &Agent{cfg: cfg, tools: tools}

// When a Store is configured, persist the tree's root node + main branch tip
// up front so a later LoadTreeFromStore has an anchor even before the first
// Invoke. Best-effort: a failure here is logged, never fatal (mirrors the
// per-node persistence in runLoop).
if cfg.Store != nil {
if root := cfg.Tree.Root(); root != nil {
a.persistNode(context.Background(), root)
}
}

// Build the handoff group (if any). The entry agent shares this tree; each
// member's handoff_to_<target> tools are wired into its registry.
if len(cfg.Handoffs) > 0 {
Expand Down Expand Up @@ -601,6 +627,50 @@ func uriScheme(uri string) string {
return u.Scheme
}

// ── Store persistence ────────────────────────────────────────────────

// persistNode writes a freshly-added node and its branch tip to the configured
// Store. It is best-effort: a nil Store is a no-op, and any error is logged but
// never propagated, so persistence failures cannot break a live agent run. When
// the Store exposes a transaction, the node + branch tip are committed together
// so a reader never observes a tip pointing at an unsaved node.
func (a *Agent) persistNode(ctx context.Context, node *types.Node) {
if a.cfg.Store == nil || node == nil {
return
}
err := a.cfg.Store.Tx(ctx, func(tx types.StoreTx) error {
if err := tx.SaveNode(ctx, node); err != nil {
return err
}
return tx.SaveBranch(ctx, node.BranchID, node.ID)
})
if err != nil {
a.cfg.Logger.Warn("store persist failed",
"agent", a.cfg.Name, "node", node.ID, "branch", node.BranchID, "error", err)
}
}

// LoadTreeFromStore reconstructs a conversation tree from a Store by loading the
// subtree rooted at rootID and rebuilding it via tree.FromStore. The returned
// tree can be passed to NewAgent via WithTree to resume a persisted session. The
// active branch defaults to "main" unless overridden.
//
// This is the read counterpart to WithStore's write path. It is a free function
// (not a method) so a tree can be hydrated before an Agent exists.
func LoadTreeFromStore(ctx context.Context, store types.Store, rootID types.NodeID, active types.BranchID) (*tree.Tree, error) {
if store == nil {
return nil, fmt.Errorf("agent: nil store")
}
nodes, branches, err := store.LoadTree(ctx, rootID)
if err != nil {
return nil, fmt.Errorf("load tree: %w", err)
}
if active == "" {
active = types.BranchID("main")
}
return tree.FromStore(nodes, branches, nil, rootID, active)
}

// ── Run loop ─────────────────────────────────────────────────────────

func (a *Agent) runLoop(ctx context.Context, stream *EventStream, input []types.Message, branch types.BranchID) {
Expand Down Expand Up @@ -683,10 +753,12 @@ func (a *Agent) runLoop(ctx context.Context, stream *EventStream, input []types.
stream.send(types.ErrorDelta{Error: err})
return
}
if _, err := tr.AddChild(ctx, tip.ID, *msg); err != nil {
assistantNode, err := tr.AddChild(ctx, tip.ID, *msg)
if err != nil {
stream.send(types.ErrorDelta{Error: err})
return
}
a.persistNode(ctx, assistantNode)

toolCalls := assistantToolCalls(msg)
if len(toolCalls) == 0 {
Expand Down Expand Up @@ -721,10 +793,12 @@ func (a *Agent) appendInput(ctx context.Context, tr *tree.Tree, stream *EventStr
stream.send(types.ErrorDelta{Error: err})
return false
}
if _, err := tr.AddChild(ctx, tip.ID, msg); err != nil {
node, err := tr.AddChild(ctx, tip.ID, msg)
if err != nil {
stream.send(types.ErrorDelta{Error: err})
return false
}
a.persistNode(ctx, node)
}
return true
}
Expand Down Expand Up @@ -793,8 +867,12 @@ func (a *Agent) persistToolResults(ctx context.Context, tr *tree.Tree, branch ty
if err != nil {
return err
}
_, err = tr.AddChild(ctx, tip.ID, types.NewToolResultMessage(contents...))
return err
node, err := tr.AddChild(ctx, tip.ID, types.NewToolResultMessage(contents...))
if err != nil {
return err
}
a.persistNode(ctx, node)
return nil
}

// applyHandoff applies the first handoff signal in results, if any, appending a
Expand Down Expand Up @@ -828,10 +906,12 @@ func (a *Agent) applyHandoff(ctx context.Context, tr *tree.Tree, stream *EventSt
stream.send(types.ErrorDelta{Error: err})
return true
}
if _, err := tr.AddChild(ctx, tip.ID, overlay); err != nil {
node, err := tr.AddChild(ctx, tip.ID, overlay)
if err != nil {
stream.send(types.ErrorDelta{Error: err})
return true
}
a.persistNode(ctx, node)
return false
}

Expand Down
186 changes: 186 additions & 0 deletions agent/agent_store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package agent

import (
"context"
"strings"
"testing"

"github.com/urmzd/saige/agent/agenttest"
"github.com/urmzd/saige/agent/store/memstore"
"github.com/urmzd/saige/agent/types"
)

// messageText projects any message into a comparable text string, concatenating
// every text-bearing content block. It is good enough to assert that the
// flattened conversation round-trips through a Store.
func messageText(msg types.Message) string {
var sb strings.Builder
sb.WriteString(string(msg.Role()))
sb.WriteString(":")
switch m := msg.(type) {
case types.SystemMessage:
for _, c := range m.Content {
if tc, ok := c.(types.TextContent); ok {
sb.WriteString(tc.Text)
}
if trc, ok := c.(types.ToolResultContent); ok {
sb.WriteString("tool_result(")
sb.WriteString(trc.Text)
sb.WriteString(")")
}
}
case types.UserMessage:
for _, c := range m.Content {
if tc, ok := c.(types.TextContent); ok {
sb.WriteString(tc.Text)
}
}
case types.AssistantMessage:
for _, c := range m.Content {
if tc, ok := c.(types.TextContent); ok {
sb.WriteString(tc.Text)
}
}
}
return sb.String()
}

func flattenedText(t *testing.T, msgs []types.Message) []string {
t.Helper()
out := make([]string, len(msgs))
for i, m := range msgs {
out[i] = messageText(m)
}
return out
}

// TestStoreMultiTurnRoundTrip builds an agent backed by a memstore, runs two
// Invoke turns, then reconstructs the tree purely from the Store (LoadTree +
// FromStore) and asserts the full message history round-trips byte-for-byte.
func TestStoreMultiTurnRoundTrip(t *testing.T) {
ctx := context.Background()
store := memstore.New()

provider := &agenttest.ScriptedProvider{
Responses: [][]types.Delta{
agenttest.TextResponse("hello there"),
agenttest.TextResponse("second answer"),
},
}

ag := NewAgent(AgentConfig{
Name: "store-agent",
SystemPrompt: "you are helpful",
Provider: provider,
}, WithStore(store))

rootID := ag.Tree().Root().ID

// Turn 1.
stream := ag.Invoke(ctx, []types.Message{types.NewUserMessage("first question")})
if err := stream.Wait(); err != nil {
t.Fatalf("turn 1 invoke: %v", err)
}
// Turn 2.
stream = ag.Invoke(ctx, []types.Message{types.NewUserMessage("second question")})
if err := stream.Wait(); err != nil {
t.Fatalf("turn 2 invoke: %v", err)
}

branch := ag.Tree().Active()
liveMsgs, err := ag.Tree().FlattenBranch(branch)
if err != nil {
t.Fatalf("flatten live branch: %v", err)
}
if len(liveMsgs) < 5 {
t.Fatalf("expected at least system+2*(user+assistant) messages, got %d", len(liveMsgs))
}

// Reconstruct the tree from the Store only.
restored, err := LoadTreeFromStore(ctx, store, rootID, branch)
if err != nil {
t.Fatalf("LoadTreeFromStore: %v", err)
}

restoredMsgs, err := restored.FlattenBranch(branch)
if err != nil {
t.Fatalf("flatten restored branch: %v", err)
}

live := flattenedText(t, liveMsgs)
got := flattenedText(t, restoredMsgs)
if len(live) != len(got) {
t.Fatalf("message count mismatch: live=%d restored=%d\nlive=%v\nrestored=%v",
len(live), len(got), live, got)
}
for i := range live {
if live[i] != got[i] {
t.Fatalf("message[%d] mismatch:\n live=%q\n got =%q", i, live[i], got[i])
}
}

// Spot-check the actual conversation made it through.
joined := strings.Join(got, "|")
for _, want := range []string{"first question", "hello there", "second question", "second answer"} {
if !strings.Contains(joined, want) {
t.Fatalf("restored history missing %q: %v", want, got)
}
}
}

// TestStoreNilIsBackwardCompatible verifies that with no Store configured the
// agent behaves exactly as before — Invoke works and the tree is populated.
func TestStoreNilIsBackwardCompatible(t *testing.T) {
ctx := context.Background()
provider := &agenttest.ScriptedProvider{
Responses: [][]types.Delta{agenttest.TextResponse("ok")},
}
ag := NewAgent(AgentConfig{
Name: "no-store",
SystemPrompt: "sys",
Provider: provider,
})
if ag.cfg.Store != nil {
t.Fatal("expected nil Store by default")
}
stream := ag.Invoke(ctx, []types.Message{types.NewUserMessage("hi")})
if err := stream.Wait(); err != nil {
t.Fatalf("invoke: %v", err)
}
msgs, err := ag.Tree().FlattenBranch(ag.Tree().Active())
if err != nil {
t.Fatal(err)
}
if len(msgs) < 3 {
t.Fatalf("expected system+user+assistant, got %d", len(msgs))
}
}

// TestStorePersistsRootOnConstruction verifies the root node + main branch are
// written to the Store as soon as the agent is constructed, so a reload anchor
// exists before the first Invoke.
func TestStorePersistsRootOnConstruction(t *testing.T) {
ctx := context.Background()
store := memstore.New()
ag := NewAgent(AgentConfig{
Name: "root-persist",
SystemPrompt: "sys prompt",
Provider: &agenttest.ScriptedProvider{},
}, WithStore(store))

rootID := ag.Tree().Root().ID
got, err := store.LoadNode(ctx, rootID)
if err != nil {
t.Fatalf("root not persisted on construction: %v", err)
}
if got.ID != rootID {
t.Fatalf("persisted root mismatch: %s != %s", got.ID, rootID)
}
tip, err := store.LoadBranch(ctx, "main")
if err != nil {
t.Fatalf("main branch not persisted: %v", err)
}
if tip != rootID {
t.Fatalf("main tip = %s, want root %s", tip, rootID)
}
}
Loading