From ea28a47f00e37c56e8564b7a4aa4d38be61ce8d3 Mon Sep 17 00:00:00 2001 From: Urmzd Mukhammadnaim Date: Sun, 31 May 2026 15:15:37 -0500 Subject: [PATCH] feat(agent): persist conversation tree via Store seam + in-memory Store Wire AGENT TREE PERSISTENCE behind the existing types.Store seam, testable without Postgres. - Add AgentConfig.Store + WithStore option (default nil = today's in-memory-only behavior, fully backward compatible). - runLoop persists each new node (and branch tip) to the Store as it is added, via Store.Tx so the tip never points at an unsaved node. Best-effort: errors are logged, never fatal. - NewAgent persists the root node + main branch up front when a Store is configured, giving LoadTreeFromStore an anchor before the first Invoke. - Add LoadTreeFromStore helper (Store.LoadTree + tree.FromStore) for the read/resume path. - New package agent/store/memstore: in-memory types.Store implementing the full interface (SaveNode/LoadNode/LoadChildren/LoadPath/SaveBranch/ LoadBranch/ListBranches/SaveCheckpoint/LoadCheckpoint/LoadTree/Tx) with atomic buffered transactions. Tests: memstore unit tests (round-trip, children order, path, branches, checkpoints, reachable-subtree LoadTree, Tx commit/rollback); agent multi-turn Invoke -> reconstruct tree from memstore -> assert full message history round-trips; backward-compat (nil Store) and root-on-construction. --- agent/agent.go | 90 ++++++++- agent/agent_store_test.go | 186 ++++++++++++++++++ agent/store/memstore/memstore.go | 271 ++++++++++++++++++++++++++ agent/store/memstore/memstore_test.go | 238 ++++++++++++++++++++++ 4 files changed, 780 insertions(+), 5 deletions(-) create mode 100644 agent/agent_store_test.go create mode 100644 agent/store/memstore/memstore.go create mode 100644 agent/store/memstore/memstore_test.go diff --git a/agent/agent.go b/agent/agent.go index d69cd95..bb4e5fa 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -35,6 +35,13 @@ type AgentConfig struct { // today's streaming behavior). A DBOS-backed runner lives in agent/durable/dbos. StepRunner types.StepRunner + // Store persists the conversation tree. Defaults to nil, which preserves + // today's in-memory-only behavior (the tree lives only for the agent's + // lifetime). When set, each node added during a run is written to the Store + // best-effort (errors are logged, never fatal), and NewAgent can hydrate the + // tree from the Store via WithStore (see WithStore / LoadTreeFromStore). + Store types.Store + // File pipeline configuration. Resolvers map[string]types.Resolver // URI scheme → Resolver (e.g. "file", "https", "s3") Extractors map[types.MediaType]types.Extractor // MediaType → Extractor for non-native types @@ -105,6 +112,15 @@ func WithStepRunner(r types.StepRunner) AgentOption { return func(c *AgentConfig) { c.StepRunner = r } } +// WithStore configures a types.Store so the conversation tree is persisted. +// With no Store (the default) the tree is in-memory only — fully backward +// compatible. When a Store is set, each node added during a run is written +// best-effort; loading a previously-persisted tree is done explicitly via +// LoadTreeFromStore before NewAgent (pass the rebuilt tree with WithTree). +func WithStore(s types.Store) AgentOption { + return func(c *AgentConfig) { c.Store = s } +} + // Agent runs an LLM agent loop with tool execution. // All conversations are backed by a Tree. type Agent struct { @@ -153,6 +169,16 @@ func NewAgent(cfg AgentConfig, opts ...AgentOption) *Agent { a := &Agent{cfg: cfg, tools: tools} + // When a Store is configured, persist the tree's root node + main branch tip + // up front so a later LoadTreeFromStore has an anchor even before the first + // Invoke. Best-effort: a failure here is logged, never fatal (mirrors the + // per-node persistence in runLoop). + if cfg.Store != nil { + if root := cfg.Tree.Root(); root != nil { + a.persistNode(context.Background(), root) + } + } + // Build the handoff group (if any). The entry agent shares this tree; each // member's handoff_to_ tools are wired into its registry. if len(cfg.Handoffs) > 0 { @@ -601,6 +627,50 @@ func uriScheme(uri string) string { return u.Scheme } +// ── Store persistence ──────────────────────────────────────────────── + +// persistNode writes a freshly-added node and its branch tip to the configured +// Store. It is best-effort: a nil Store is a no-op, and any error is logged but +// never propagated, so persistence failures cannot break a live agent run. When +// the Store exposes a transaction, the node + branch tip are committed together +// so a reader never observes a tip pointing at an unsaved node. +func (a *Agent) persistNode(ctx context.Context, node *types.Node) { + if a.cfg.Store == nil || node == nil { + return + } + err := a.cfg.Store.Tx(ctx, func(tx types.StoreTx) error { + if err := tx.SaveNode(ctx, node); err != nil { + return err + } + return tx.SaveBranch(ctx, node.BranchID, node.ID) + }) + if err != nil { + a.cfg.Logger.Warn("store persist failed", + "agent", a.cfg.Name, "node", node.ID, "branch", node.BranchID, "error", err) + } +} + +// LoadTreeFromStore reconstructs a conversation tree from a Store by loading the +// subtree rooted at rootID and rebuilding it via tree.FromStore. The returned +// tree can be passed to NewAgent via WithTree to resume a persisted session. The +// active branch defaults to "main" unless overridden. +// +// This is the read counterpart to WithStore's write path. It is a free function +// (not a method) so a tree can be hydrated before an Agent exists. +func LoadTreeFromStore(ctx context.Context, store types.Store, rootID types.NodeID, active types.BranchID) (*tree.Tree, error) { + if store == nil { + return nil, fmt.Errorf("agent: nil store") + } + nodes, branches, err := store.LoadTree(ctx, rootID) + if err != nil { + return nil, fmt.Errorf("load tree: %w", err) + } + if active == "" { + active = types.BranchID("main") + } + return tree.FromStore(nodes, branches, nil, rootID, active) +} + // ── Run loop ───────────────────────────────────────────────────────── func (a *Agent) runLoop(ctx context.Context, stream *EventStream, input []types.Message, branch types.BranchID) { @@ -683,10 +753,12 @@ func (a *Agent) runLoop(ctx context.Context, stream *EventStream, input []types. stream.send(types.ErrorDelta{Error: err}) return } - if _, err := tr.AddChild(ctx, tip.ID, *msg); err != nil { + assistantNode, err := tr.AddChild(ctx, tip.ID, *msg) + if err != nil { stream.send(types.ErrorDelta{Error: err}) return } + a.persistNode(ctx, assistantNode) toolCalls := assistantToolCalls(msg) if len(toolCalls) == 0 { @@ -721,10 +793,12 @@ func (a *Agent) appendInput(ctx context.Context, tr *tree.Tree, stream *EventStr stream.send(types.ErrorDelta{Error: err}) return false } - if _, err := tr.AddChild(ctx, tip.ID, msg); err != nil { + node, err := tr.AddChild(ctx, tip.ID, msg) + if err != nil { stream.send(types.ErrorDelta{Error: err}) return false } + a.persistNode(ctx, node) } return true } @@ -793,8 +867,12 @@ func (a *Agent) persistToolResults(ctx context.Context, tr *tree.Tree, branch ty if err != nil { return err } - _, err = tr.AddChild(ctx, tip.ID, types.NewToolResultMessage(contents...)) - return err + node, err := tr.AddChild(ctx, tip.ID, types.NewToolResultMessage(contents...)) + if err != nil { + return err + } + a.persistNode(ctx, node) + return nil } // applyHandoff applies the first handoff signal in results, if any, appending a @@ -828,10 +906,12 @@ func (a *Agent) applyHandoff(ctx context.Context, tr *tree.Tree, stream *EventSt stream.send(types.ErrorDelta{Error: err}) return true } - if _, err := tr.AddChild(ctx, tip.ID, overlay); err != nil { + node, err := tr.AddChild(ctx, tip.ID, overlay) + if err != nil { stream.send(types.ErrorDelta{Error: err}) return true } + a.persistNode(ctx, node) return false } diff --git a/agent/agent_store_test.go b/agent/agent_store_test.go new file mode 100644 index 0000000..4f1b48a --- /dev/null +++ b/agent/agent_store_test.go @@ -0,0 +1,186 @@ +package agent + +import ( + "context" + "strings" + "testing" + + "github.com/urmzd/saige/agent/agenttest" + "github.com/urmzd/saige/agent/store/memstore" + "github.com/urmzd/saige/agent/types" +) + +// messageText projects any message into a comparable text string, concatenating +// every text-bearing content block. It is good enough to assert that the +// flattened conversation round-trips through a Store. +func messageText(msg types.Message) string { + var sb strings.Builder + sb.WriteString(string(msg.Role())) + sb.WriteString(":") + switch m := msg.(type) { + case types.SystemMessage: + for _, c := range m.Content { + if tc, ok := c.(types.TextContent); ok { + sb.WriteString(tc.Text) + } + if trc, ok := c.(types.ToolResultContent); ok { + sb.WriteString("tool_result(") + sb.WriteString(trc.Text) + sb.WriteString(")") + } + } + case types.UserMessage: + for _, c := range m.Content { + if tc, ok := c.(types.TextContent); ok { + sb.WriteString(tc.Text) + } + } + case types.AssistantMessage: + for _, c := range m.Content { + if tc, ok := c.(types.TextContent); ok { + sb.WriteString(tc.Text) + } + } + } + return sb.String() +} + +func flattenedText(t *testing.T, msgs []types.Message) []string { + t.Helper() + out := make([]string, len(msgs)) + for i, m := range msgs { + out[i] = messageText(m) + } + return out +} + +// TestStoreMultiTurnRoundTrip builds an agent backed by a memstore, runs two +// Invoke turns, then reconstructs the tree purely from the Store (LoadTree + +// FromStore) and asserts the full message history round-trips byte-for-byte. +func TestStoreMultiTurnRoundTrip(t *testing.T) { + ctx := context.Background() + store := memstore.New() + + provider := &agenttest.ScriptedProvider{ + Responses: [][]types.Delta{ + agenttest.TextResponse("hello there"), + agenttest.TextResponse("second answer"), + }, + } + + ag := NewAgent(AgentConfig{ + Name: "store-agent", + SystemPrompt: "you are helpful", + Provider: provider, + }, WithStore(store)) + + rootID := ag.Tree().Root().ID + + // Turn 1. + stream := ag.Invoke(ctx, []types.Message{types.NewUserMessage("first question")}) + if err := stream.Wait(); err != nil { + t.Fatalf("turn 1 invoke: %v", err) + } + // Turn 2. + stream = ag.Invoke(ctx, []types.Message{types.NewUserMessage("second question")}) + if err := stream.Wait(); err != nil { + t.Fatalf("turn 2 invoke: %v", err) + } + + branch := ag.Tree().Active() + liveMsgs, err := ag.Tree().FlattenBranch(branch) + if err != nil { + t.Fatalf("flatten live branch: %v", err) + } + if len(liveMsgs) < 5 { + t.Fatalf("expected at least system+2*(user+assistant) messages, got %d", len(liveMsgs)) + } + + // Reconstruct the tree from the Store only. + restored, err := LoadTreeFromStore(ctx, store, rootID, branch) + if err != nil { + t.Fatalf("LoadTreeFromStore: %v", err) + } + + restoredMsgs, err := restored.FlattenBranch(branch) + if err != nil { + t.Fatalf("flatten restored branch: %v", err) + } + + live := flattenedText(t, liveMsgs) + got := flattenedText(t, restoredMsgs) + if len(live) != len(got) { + t.Fatalf("message count mismatch: live=%d restored=%d\nlive=%v\nrestored=%v", + len(live), len(got), live, got) + } + for i := range live { + if live[i] != got[i] { + t.Fatalf("message[%d] mismatch:\n live=%q\n got =%q", i, live[i], got[i]) + } + } + + // Spot-check the actual conversation made it through. + joined := strings.Join(got, "|") + for _, want := range []string{"first question", "hello there", "second question", "second answer"} { + if !strings.Contains(joined, want) { + t.Fatalf("restored history missing %q: %v", want, got) + } + } +} + +// TestStoreNilIsBackwardCompatible verifies that with no Store configured the +// agent behaves exactly as before — Invoke works and the tree is populated. +func TestStoreNilIsBackwardCompatible(t *testing.T) { + ctx := context.Background() + provider := &agenttest.ScriptedProvider{ + Responses: [][]types.Delta{agenttest.TextResponse("ok")}, + } + ag := NewAgent(AgentConfig{ + Name: "no-store", + SystemPrompt: "sys", + Provider: provider, + }) + if ag.cfg.Store != nil { + t.Fatal("expected nil Store by default") + } + stream := ag.Invoke(ctx, []types.Message{types.NewUserMessage("hi")}) + if err := stream.Wait(); err != nil { + t.Fatalf("invoke: %v", err) + } + msgs, err := ag.Tree().FlattenBranch(ag.Tree().Active()) + if err != nil { + t.Fatal(err) + } + if len(msgs) < 3 { + t.Fatalf("expected system+user+assistant, got %d", len(msgs)) + } +} + +// TestStorePersistsRootOnConstruction verifies the root node + main branch are +// written to the Store as soon as the agent is constructed, so a reload anchor +// exists before the first Invoke. +func TestStorePersistsRootOnConstruction(t *testing.T) { + ctx := context.Background() + store := memstore.New() + ag := NewAgent(AgentConfig{ + Name: "root-persist", + SystemPrompt: "sys prompt", + Provider: &agenttest.ScriptedProvider{}, + }, WithStore(store)) + + rootID := ag.Tree().Root().ID + got, err := store.LoadNode(ctx, rootID) + if err != nil { + t.Fatalf("root not persisted on construction: %v", err) + } + if got.ID != rootID { + t.Fatalf("persisted root mismatch: %s != %s", got.ID, rootID) + } + tip, err := store.LoadBranch(ctx, "main") + if err != nil { + t.Fatalf("main branch not persisted: %v", err) + } + if tip != rootID { + t.Fatalf("main tip = %s, want root %s", tip, rootID) + } +} diff --git a/agent/store/memstore/memstore.go b/agent/store/memstore/memstore.go new file mode 100644 index 0000000..9497f9d --- /dev/null +++ b/agent/store/memstore/memstore.go @@ -0,0 +1,271 @@ +// Package memstore implements agent/types.Store using in-memory maps. +// +// It is a drop-in, dependency-free Store suitable for tests and single-process +// use. Unlike pgstore it offers no crash durability — data lives only for the +// lifetime of the process — but it mirrors the full types.Store contract so +// persistence and tree-reconstruction paths can be exercised without Postgres. +package memstore + +import ( + "context" + "fmt" + "sort" + "sync" + + "github.com/urmzd/saige/agent/types" +) + +var ( + _ types.Store = (*Store)(nil) + _ types.StoreTx = (*storeTx)(nil) +) + +// Store is an in-memory implementation of types.Store. +// +// All operations are safe for concurrent use. Nodes are copied on the way in +// and out so callers cannot mutate stored state by holding a returned pointer. +type Store struct { + mu sync.RWMutex + nodes map[types.NodeID]*types.Node + childOrder map[types.NodeID][]types.NodeID // parent -> children in insertion order + branches map[types.BranchID]types.NodeID + checkpoints map[types.CheckpointID]types.Checkpoint +} + +// New creates an empty in-memory Store. +func New() *Store { + return &Store{ + nodes: make(map[types.NodeID]*types.Node), + childOrder: make(map[types.NodeID][]types.NodeID), + branches: make(map[types.BranchID]types.NodeID), + checkpoints: make(map[types.CheckpointID]types.Checkpoint), + } +} + +// cloneNode returns a defensive copy of a node so stored state is immutable +// from the caller's perspective. The Message itself is treated as immutable +// (the tree builds a fresh message per node), so it is shared by reference. +func cloneNode(n *types.Node) *types.Node { + cp := *n + if n.SummaryOf != nil { + cp.SummaryOf = append([]types.NodeID(nil), n.SummaryOf...) + } + if n.ArchivedAt != nil { + at := *n.ArchivedAt + cp.ArchivedAt = &at + } + return &cp +} + +// saveNode is the shared write path used by both Store and storeTx. +func (s *Store) saveNode(node *types.Node) error { + if node.ID == "" { + return fmt.Errorf("memstore: node ID is empty") + } + _, existed := s.nodes[node.ID] + s.nodes[node.ID] = cloneNode(node) + if !existed && node.ParentID != "" { + s.childOrder[node.ParentID] = append(s.childOrder[node.ParentID], node.ID) + } + return nil +} + +// SaveNode persists a node, preserving child insertion order on first write. +func (s *Store) SaveNode(_ context.Context, node *types.Node) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.saveNode(node) +} + +// LoadNode retrieves a single node by ID. +func (s *Store) LoadNode(_ context.Context, id types.NodeID) (*types.Node, error) { + s.mu.RLock() + defer s.mu.RUnlock() + n, ok := s.nodes[id] + if !ok { + return nil, fmt.Errorf("node not found: %s", id) + } + return cloneNode(n), nil +} + +// LoadChildren returns direct children of a node in insertion order. +func (s *Store) LoadChildren(_ context.Context, parentID types.NodeID) ([]*types.Node, error) { + s.mu.RLock() + defer s.mu.RUnlock() + ids := s.childOrder[parentID] + out := make([]*types.Node, 0, len(ids)) + for _, id := range ids { + if n, ok := s.nodes[id]; ok { + out = append(out, cloneNode(n)) + } + } + return out, nil +} + +// LoadPath returns all nodes from root to the given node, root-first. +func (s *Store) LoadPath(_ context.Context, toNodeID types.NodeID) ([]*types.Node, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var path []*types.Node + current := toNodeID + for current != "" { + n, ok := s.nodes[current] + if !ok { + return nil, fmt.Errorf("node not found: %s", current) + } + path = append(path, cloneNode(n)) + current = n.ParentID + } + // Reverse to root-first order. + for i, j := 0, len(path)-1; i < j; i, j = i+1, j-1 { + path[i], path[j] = path[j], path[i] + } + return path, nil +} + +// SaveBranch persists a branch-to-tip mapping. +func (s *Store) SaveBranch(_ context.Context, branch types.BranchID, tipID types.NodeID) error { + s.mu.Lock() + defer s.mu.Unlock() + s.branches[branch] = tipID + return nil +} + +// LoadBranch retrieves the tip node ID for a branch. +func (s *Store) LoadBranch(_ context.Context, branch types.BranchID) (types.NodeID, error) { + s.mu.RLock() + defer s.mu.RUnlock() + tip, ok := s.branches[branch] + if !ok { + return "", fmt.Errorf("branch not found: %s", branch) + } + return tip, nil +} + +// ListBranches returns a copy of all branch-to-tip mappings. +func (s *Store) ListBranches(_ context.Context) (map[types.BranchID]types.NodeID, error) { + s.mu.RLock() + defer s.mu.RUnlock() + out := make(map[types.BranchID]types.NodeID, len(s.branches)) + for k, v := range s.branches { + out[k] = v + } + return out, nil +} + +// SaveCheckpoint persists a checkpoint. +func (s *Store) SaveCheckpoint(_ context.Context, cp types.Checkpoint) error { + s.mu.Lock() + defer s.mu.Unlock() + s.checkpoints[cp.ID] = cp + return nil +} + +// LoadCheckpoint retrieves a checkpoint by ID. +func (s *Store) LoadCheckpoint(_ context.Context, id types.CheckpointID) (types.Checkpoint, error) { + s.mu.RLock() + defer s.mu.RUnlock() + cp, ok := s.checkpoints[id] + if !ok { + return types.Checkpoint{}, fmt.Errorf("checkpoint not found: %s", id) + } + return cp, nil +} + +// LoadTree returns every node reachable from rootID plus all branch tips. +// Nodes are ordered root-first (ascending depth), matching pgstore semantics +// closely enough for tree.FromStore, which rebuilds the child map from parent +// pointers regardless of slice order. +func (s *Store) LoadTree(_ context.Context, rootID types.NodeID) ([]*types.Node, map[types.BranchID]types.NodeID, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if _, ok := s.nodes[rootID]; !ok { + return nil, nil, fmt.Errorf("root not found: %s", rootID) + } + + // BFS from the root over the child-order map so only the reachable subtree + // is returned (mirrors pgstore's recursive descendant query). + var nodes []*types.Node + queue := []types.NodeID{rootID} + for len(queue) > 0 { + id := queue[0] + queue = queue[1:] + n, ok := s.nodes[id] + if !ok { + continue + } + nodes = append(nodes, cloneNode(n)) + queue = append(queue, s.childOrder[id]...) + } + + // Stable order: ascending depth, then created time, for deterministic output. + sort.SliceStable(nodes, func(i, j int) bool { + if nodes[i].Depth != nodes[j].Depth { + return nodes[i].Depth < nodes[j].Depth + } + return nodes[i].CreatedAt.Before(nodes[j].CreatedAt) + }) + + branches := make(map[types.BranchID]types.NodeID, len(s.branches)) + for k, v := range s.branches { + branches[k] = v + } + return nodes, branches, nil +} + +// storeTx implements types.StoreTx by buffering writes and applying them +// atomically on commit. If fn returns an error, no buffered writes are applied. +type storeTx struct { + store *Store + nodes []*types.Node + branches map[types.BranchID]types.NodeID + checkpoints []types.Checkpoint +} + +func (t *storeTx) SaveNode(_ context.Context, node *types.Node) error { + if node.ID == "" { + return fmt.Errorf("memstore: node ID is empty") + } + t.nodes = append(t.nodes, cloneNode(node)) + return nil +} + +func (t *storeTx) SaveBranch(_ context.Context, branch types.BranchID, tipID types.NodeID) error { + t.branches[branch] = tipID + return nil +} + +func (t *storeTx) SaveCheckpoint(_ context.Context, cp types.Checkpoint) error { + t.checkpoints = append(t.checkpoints, cp) + return nil +} + +// Tx runs fn against a buffered transaction, applying its writes atomically on +// success. On error nothing is persisted, giving all-or-nothing semantics that +// match pgstore's database transaction. +func (s *Store) Tx(ctx context.Context, fn func(types.StoreTx) error) error { + tx := &storeTx{ + store: s, + branches: make(map[types.BranchID]types.NodeID), + } + if err := fn(tx); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + for _, n := range tx.nodes { + if err := s.saveNode(n); err != nil { + return err + } + } + for b, tip := range tx.branches { + s.branches[b] = tip + } + for _, cp := range tx.checkpoints { + s.checkpoints[cp.ID] = cp + } + return nil +} diff --git a/agent/store/memstore/memstore_test.go b/agent/store/memstore/memstore_test.go new file mode 100644 index 0000000..3d5ec2e --- /dev/null +++ b/agent/store/memstore/memstore_test.go @@ -0,0 +1,238 @@ +package memstore_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/urmzd/saige/agent/store/memstore" + "github.com/urmzd/saige/agent/types" +) + +func node(id, parent string, depth int, branch types.BranchID) *types.Node { + return &types.Node{ + ID: types.NodeID(id), + ParentID: types.NodeID(parent), + Message: types.NewUserMessage(id), + State: types.NodeActive, + Version: 1, + Depth: depth, + BranchID: branch, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } +} + +func TestStoreSaveLoadNode(t *testing.T) { + ctx := context.Background() + s := memstore.New() + + root := node("root", "", 0, "main") + if err := s.SaveNode(ctx, root); err != nil { + t.Fatalf("SaveNode: %v", err) + } + + got, err := s.LoadNode(ctx, "root") + if err != nil { + t.Fatalf("LoadNode: %v", err) + } + if got.ID != "root" || got.BranchID != "main" { + t.Fatalf("round-trip mismatch: %+v", got) + } + + if _, err := s.LoadNode(ctx, "missing"); err == nil { + t.Fatal("expected error loading missing node") + } +} + +func TestStoreLoadNodeIsACopy(t *testing.T) { + ctx := context.Background() + s := memstore.New() + if err := s.SaveNode(ctx, node("n", "", 0, "main")); err != nil { + t.Fatal(err) + } + got, _ := s.LoadNode(ctx, "n") + got.BranchID = "tampered" + + fresh, _ := s.LoadNode(ctx, "n") + if fresh.BranchID != "main" { + t.Fatalf("stored node mutated via returned pointer: %s", fresh.BranchID) + } +} + +func TestStoreChildrenAndPath(t *testing.T) { + ctx := context.Background() + s := memstore.New() + + for _, n := range []*types.Node{ + node("root", "", 0, "main"), + node("a", "root", 1, "main"), + node("b", "root", 1, "main"), + node("c", "a", 2, "main"), + } { + if err := s.SaveNode(ctx, n); err != nil { + t.Fatal(err) + } + } + + children, err := s.LoadChildren(ctx, "root") + if err != nil { + t.Fatal(err) + } + if len(children) != 2 || children[0].ID != "a" || children[1].ID != "b" { + t.Fatalf("children order wrong: %+v", children) + } + + path, err := s.LoadPath(ctx, "c") + if err != nil { + t.Fatal(err) + } + want := []types.NodeID{"root", "a", "c"} + if len(path) != len(want) { + t.Fatalf("path length: got %d want %d", len(path), len(want)) + } + for i, id := range want { + if path[i].ID != id { + t.Fatalf("path[%d]=%s want %s", i, path[i].ID, id) + } + } +} + +func TestStoreBranches(t *testing.T) { + ctx := context.Background() + s := memstore.New() + + if err := s.SaveBranch(ctx, "main", "tip1"); err != nil { + t.Fatal(err) + } + if err := s.SaveBranch(ctx, "feature", "tip2"); err != nil { + t.Fatal(err) + } + + tip, err := s.LoadBranch(ctx, "main") + if err != nil { + t.Fatal(err) + } + if tip != "tip1" { + t.Fatalf("LoadBranch main: got %s", tip) + } + + all, err := s.ListBranches(ctx) + if err != nil { + t.Fatal(err) + } + if len(all) != 2 || all["feature"] != "tip2" { + t.Fatalf("ListBranches: %+v", all) + } + + if _, err := s.LoadBranch(ctx, "nope"); err == nil { + t.Fatal("expected error for missing branch") + } +} + +func TestStoreCheckpoints(t *testing.T) { + ctx := context.Background() + s := memstore.New() + + cp := types.Checkpoint{ID: "cp1", Branch: "main", NodeID: "n1", Name: "snap", CreatedAt: time.Now()} + if err := s.SaveCheckpoint(ctx, cp); err != nil { + t.Fatal(err) + } + got, err := s.LoadCheckpoint(ctx, "cp1") + if err != nil { + t.Fatal(err) + } + if got.Name != "snap" || got.NodeID != "n1" { + t.Fatalf("checkpoint round-trip: %+v", got) + } + if _, err := s.LoadCheckpoint(ctx, "missing"); err == nil { + t.Fatal("expected error for missing checkpoint") + } +} + +func TestStoreLoadTree(t *testing.T) { + ctx := context.Background() + s := memstore.New() + + nodes := []*types.Node{ + node("root", "", 0, "main"), + node("a", "root", 1, "main"), + node("b", "a", 2, "main"), + // A node on a different subtree not reachable from "a"'s root. + node("orphan", "", 0, "other"), + } + for _, n := range nodes { + if err := s.SaveNode(ctx, n); err != nil { + t.Fatal(err) + } + } + if err := s.SaveBranch(ctx, "main", "b"); err != nil { + t.Fatal(err) + } + + got, branches, err := s.LoadTree(ctx, "root") + if err != nil { + t.Fatal(err) + } + // Only root, a, b are reachable from root — orphan is excluded. + if len(got) != 3 { + t.Fatalf("LoadTree returned %d nodes, want 3: %+v", len(got), got) + } + if got[0].ID != "root" { + t.Fatalf("LoadTree not root-first: %s", got[0].ID) + } + if branches["main"] != "b" { + t.Fatalf("LoadTree branches: %+v", branches) + } + + if _, _, err := s.LoadTree(ctx, "ghost"); err == nil { + t.Fatal("expected error for unknown root") + } +} + +func TestStoreTxAtomicCommit(t *testing.T) { + ctx := context.Background() + s := memstore.New() + + err := s.Tx(ctx, func(tx types.StoreTx) error { + if err := tx.SaveNode(ctx, node("root", "", 0, "main")); err != nil { + return err + } + return tx.SaveBranch(ctx, "main", "root") + }) + if err != nil { + t.Fatalf("Tx: %v", err) + } + + if _, err := s.LoadNode(ctx, "root"); err != nil { + t.Fatalf("committed node missing: %v", err) + } + tip, err := s.LoadBranch(ctx, "main") + if err != nil || tip != "root" { + t.Fatalf("committed branch wrong: tip=%s err=%v", tip, err) + } +} + +func TestStoreTxRollbackOnError(t *testing.T) { + ctx := context.Background() + s := memstore.New() + + sentinel := errors.New("boom") + err := s.Tx(ctx, func(tx types.StoreTx) error { + _ = tx.SaveNode(ctx, node("root", "", 0, "main")) + _ = tx.SaveBranch(ctx, "main", "root") + return sentinel + }) + if !errors.Is(err, sentinel) { + t.Fatalf("Tx error: got %v want %v", err, sentinel) + } + + // Nothing should have been applied. + if _, err := s.LoadNode(ctx, "root"); err == nil { + t.Fatal("rolled-back node was persisted") + } + if _, err := s.LoadBranch(ctx, "main"); err == nil { + t.Fatal("rolled-back branch was persisted") + } +}