diff --git a/client/client.go b/client/client.go index 5b04f88f..e1f90181 100644 --- a/client/client.go +++ b/client/client.go @@ -23,6 +23,7 @@ import ( "fmt" "sync" + lru "github.com/hashicorp/golang-lru/v2" "github.com/transparency-dev/formats/log" "github.com/transparency-dev/merkle/compact" "github.com/transparency-dev/merkle/proof" @@ -32,6 +33,8 @@ import ( "github.com/transparency-dev/tessera/internal/otel" "go.opentelemetry.io/otel/trace" "golang.org/x/mod/sumdb/note" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" ) var ( @@ -125,15 +128,7 @@ func FetchRangeNodes(ctx context.Context, s uint64, f TileFetcherFunc) ([][]byte nc := newNodeCache(f, s) nIDs := make([]compact.NodeID, 0, compact.RangeSize(0, s)) nIDs = compact.RangeNodes(0, s, nIDs) - hashes := make([][]byte, 0, len(nIDs)) - for _, n := range nIDs { - h, err := nc.GetNode(ctx, n) - if err != nil { - return nil, err - } - hashes = append(hashes, h) - } - return hashes, nil + return nc.GetNodes(ctx, nIDs) }) } @@ -180,7 +175,7 @@ func GetEntryBundle(ctx context.Context, f EntryBundleFetcherFunc, i, logSize ui // at a given tree size. type ProofBuilder struct { treeSize uint64 - nodeCache nodeCache + nodeCache *nodeCache } // NewProofBuilder creates a new ProofBuilder object for a given tree size. @@ -204,7 +199,7 @@ func (pb *ProofBuilder) InclusionProof(ctx context.Context, index uint64) ([][]b if err != nil { return nil, fmt.Errorf("failed to calculate inclusion proof node list: %v", err) } - return pb.fetchNodes(ctx, nodes) + return pb.materialiseProof(ctx, nodes) }) } @@ -221,22 +216,16 @@ func (pb *ProofBuilder) ConsistencyProof(ctx context.Context, smaller, larger ui if err != nil { return nil, fmt.Errorf("failed to calculate consistency proof node list: %v", err) } - return pb.fetchNodes(ctx, nodes) + return pb.materialiseProof(ctx, nodes) }) } -// fetchNodes retrieves the specified proof nodes via pb's nodeCache. 
-func (pb *ProofBuilder) fetchNodes(ctx context.Context, nodes proof.Nodes) ([][]byte, error) { - hashes := make([][]byte, 0, len(nodes.IDs)) - // TODO(al) parallelise this. - for _, id := range nodes.IDs { - h, err := pb.nodeCache.GetNode(ctx, id) - if err != nil { - return nil, fmt.Errorf("failed to get node (%v): %v", id, err) - } - hashes = append(hashes, h) +// materialiseProof retrieves the specified proof nodes via pb's nodeCache, recreating ephemeral nodes if necessary. +func (pb *ProofBuilder) materialiseProof(ctx context.Context, nodes proof.Nodes) ([][]byte, error) { + hashes, err := pb.nodeCache.GetNodes(ctx, nodes.IDs) + if err != nil { + return nil, err } - var err error if hashes, err = nodes.Rehash(hashes, hasher.HashChildren); err != nil { return nil, fmt.Errorf("failed to rehash proof: %v", err) } @@ -344,82 +333,123 @@ func (lst *LogStateTracker) Latest() log.Checkpoint { return lst.latestConsistent } -// tileKey is used as a key in nodeCache's tile map. -type tileKey struct { - tileLevel uint64 - tileIndex uint64 -} - // nodeCache hides the tiles abstraction away, and improves // performance by caching tiles it's seen. -// Not threadsafe, and intended to be only used throughout the course -// of a single request. +// Threadsafe. type nodeCache struct { - logSize uint64 - ephemeral map[compact.NodeID][]byte - tiles map[tileKey]api.HashTile - getTile TileFetcherFunc + logSize uint64 + nodes *lru.Cache[compact.NodeID, []byte] + getTile TileFetcherFunc + g singleflight.Group } // newNodeCache creates a new nodeCache instance for a given log size. 
-func newNodeCache(f TileFetcherFunc, logSize uint64) nodeCache { - return nodeCache{ - logSize: logSize, - ephemeral: make(map[compact.NodeID][]byte), - tiles: make(map[tileKey]api.HashTile), - getTile: f, +func newNodeCache(f TileFetcherFunc, logSize uint64) *nodeCache { + c, err := lru.New[compact.NodeID, []byte](64 << 10) + if err != nil { + panic(fmt.Errorf("lru.New: %v", err)) + } + return &nodeCache{ + logSize: logSize, + nodes: c, + getTile: f, } -} - -// SetEphemeralNode stored a derived "ephemeral" tree node. -func (n *nodeCache) SetEphemeralNode(id compact.NodeID, h []byte) { - n.ephemeral[id] = h } // GetNode returns the internal log tree node hash for the specified node ID. -// A previously set ephemeral node will be returned if id matches, otherwise -// the tile containing the requested node will be fetched and cached, and the -// node hash returned. +// The tile containing the node will be fetched if necessary. func (n *nodeCache) GetNode(ctx context.Context, id compact.NodeID) ([]byte, error) { return otel.Trace(ctx, "tessera.client.nodecache.GetNode", tracer, func(ctx context.Context, span trace.Span) ([]byte, error) { span.SetAttributes(indexKey.Int64(otel.Clamp64(id.Index)), levelKey.Int64(int64(id.Level))) - // First check for ephemeral nodes: - if e := n.ephemeral[id]; len(e) != 0 { + // Check if we've already cached this node, return it directly if so, otherwise we'll need to fetch it. + if e, ok := n.nodes.Get(id); ok { return e, nil } - // Otherwise look in fetched tiles: - tileLevel, tileIndex, nodeLevel, nodeIndex := layout.NodeCoordsToTileAddress(uint64(id.Level), uint64(id.Index)) - tKey := tileKey{tileLevel, tileIndex} - t, ok := n.tiles[tKey] - if !ok { + + // We now need to fetch the tile and use the contents to populate the cache. + // We'll fetch/parse the tile, and then reconstitute all the internal nodes into the + // node cache. 
We only want to do this once per tile, so use singleflight keyed by the _tile_ ID + // to make that happen. + tileLevel, tileIndex, _, _ := layout.NodeCoordsToTileAddress(uint64(id.Level), uint64(id.Index)) + k := fmt.Sprintf("%d/%d", tileLevel, tileIndex) + defer n.g.Forget(k) + _, err, _ := n.g.Do(k, func() (any, error) { + if n.nodes.Contains(id) { + return struct{}{}, nil + } span.AddEvent("cache miss") + p := layout.PartialTileSize(tileLevel, tileIndex, n.logSize) - tileRaw, err := n.getTile(ctx, tileLevel, tileIndex, p) - if err != nil { - return nil, fmt.Errorf("failed to fetch tile: %v", err) + if err := n.fetchTileAndPopulateCache(ctx, tileLevel, tileIndex, p); err != nil { + return struct{}{}, err } - var tile api.HashTile - if err := tile.UnmarshalText(tileRaw); err != nil { - return nil, fmt.Errorf("failed to parse tile: %v", err) + + return struct{}{}, nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fetch and populate node cache: %v", err) + } + if e, ok := n.nodes.Get(id); ok { + return e, nil + } + return nil, fmt.Errorf("internal error: missing node %+v", id) + }) +} + +// GetNodes returns the tree hashes at the provided locations. +func (n *nodeCache) GetNodes(ctx context.Context, nIDs []compact.NodeID) ([][]byte, error) { + hashes := make([][]byte, len(nIDs)) + g, ctx := errgroup.WithContext(ctx) + for i, id := range nIDs { + g.Go(func() error { + h, err := n.GetNode(ctx, id) + if err != nil { + return err } - t = tile - n.tiles[tKey] = tile + hashes[i] = h + return nil + }) + } + if err := g.Wait(); err != nil { + return nil, err + } + return hashes, nil +} + +// fetchTileAndPopulateCache fetches and parses the specified tile from storage, populating the nodes it contains into the node cache. 
+func (n *nodeCache) fetchTileAndPopulateCache(ctx context.Context, tileLevel, tileIndex uint64, p uint8) error { + return otel.TraceErr(ctx, "tessera.client.nodecache.fetchTileAndPopulateCache", tracer, func(ctx context.Context, span trace.Span) error { + tileRaw, err := n.getTile(ctx, tileLevel, tileIndex, p) + if err != nil { + return fmt.Errorf("failed to fetch tile: %v", err) } - // We've got the tile, now we need to look up (or calculate) the node inside of it - numLeaves := 1 << nodeLevel - firstLeaf := int(nodeIndex) * numLeaves - lastLeaf := firstLeaf + numLeaves - if lastLeaf > len(t.Nodes) { - return nil, fmt.Errorf("require leaf nodes [%d, %d) but only got %d leaves", firstLeaf, lastLeaf, len(t.Nodes)) + + var tile api.HashTile + if err := tile.UnmarshalText(tileRaw); err != nil { + return fmt.Errorf("failed to parse tile: %v", err) + } + + // visitFn is a visitor callback which populates the nodes cache. + // Used by the calls to compact range below. + visitFn := func(intID compact.NodeID, h []byte) { + // Figure out the "global" nodeID for the node intID in the requested tile. 
+ i := compact.NodeID{ + Level: uint(tileLevel*layout.TileHeight) + intID.Level, + Index: (tileIndex*layout.TileWidth)>>intID.Level + intID.Index, + } + _ = n.nodes.Add(i, h) } rf := compact.RangeFactory{Hash: hasher.HashChildren} r := rf.NewEmptyRange(0) - for _, l := range t.Nodes[firstLeaf:lastLeaf] { - if err := r.Append(l, nil); err != nil { - return nil, fmt.Errorf("failed to Append: %v", err) + for _, l := range tile.Nodes { + if err := r.Append(l, visitFn); err != nil { + return fmt.Errorf("failed to Append: %v", err) } } - return r.GetRootHash(nil) + if _, err := r.GetRootHash(visitFn); err != nil { + return fmt.Errorf("failed to visit all nodes: %v", err) + } + return nil }) } diff --git a/client/client_test.go b/client/client_test.go index e6a60616..1cba4f6b 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -334,7 +334,7 @@ func TestNodeFetcherAddressing(t *testing.T) { if err != nil { t.Fatalf("NewProofBuilder: %v", err) } - _, err = pb.fetchNodes(t.Context(), proof.Nodes{IDs: []compact.NodeID{compact.NewNodeID(test.nodeLevel, test.nodeIdx)}}) + _, err = pb.materialiseProof(t.Context(), proof.Nodes{IDs: []compact.NodeID{compact.NewNodeID(test.nodeLevel, test.nodeIdx)}}) if err != nil { t.Fatalf("fetchNodes: %v", err) } @@ -350,3 +350,84 @@ func TestNodeFetcherAddressing(t *testing.T) { }) } } + +func BenchmarkProofBuilder(b *testing.B) { + ctx := context.Background() + const treeSize = 1_000_000 + dummyHash := sha256.Sum256([]byte("dummy")) + + // Pre-generate a full tile to avoid marshaling in the loop. + fullTile := &api.HashTile{ + Nodes: make([][]byte, 256), + } + for i := range fullTile.Nodes { + fullTile.Nodes[i] = dummyHash[:] + } + fullTileBytes, err := fullTile.MarshalText() + if err != nil { + b.Fatalf("failed to marshal full tile: %v", err) + } + + // We'll ignore partial tiles for simplicity in this benchmark as they are rare in a 1M tree + // except for the very last tiles of each level. 
+ f := func(_ context.Context, _, _ uint64, p uint8) ([]byte, error) { + if p == 0 { + return fullTileBytes, nil + } + // Handle partial tiles just in case, though they might not be hit often. + partialTile := &api.HashTile{ + Nodes: make([][]byte, p), + } + for i := range partialTile.Nodes { + partialTile.Nodes[i] = dummyHash[:] + } + return partialTile.MarshalText() + } + + b.Run("InclusionProof", func(b *testing.B) { + b.Run("WarmCache", func(b *testing.B) { + pb, _ := NewProofBuilder(ctx, treeSize, f) + // Warm up the cache with some proofs. + for i := uint64(0); i < 100; i++ { + _, _ = pb.InclusionProof(ctx, i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = pb.InclusionProof(ctx, uint64(i%treeSize)) + } + }) + + b.Run("ColdCache", func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + pb, _ := NewProofBuilder(ctx, treeSize, f) + b.StartTimer() + _, _ = pb.InclusionProof(ctx, uint64(i%treeSize)) + } + }) + }) + + b.Run("ConsistencyProof", func(b *testing.B) { + b.Run("WarmCache", func(b *testing.B) { + pb, _ := NewProofBuilder(ctx, treeSize, f) + // Warm up. + for i := uint64(0); i < 100; i++ { + _, _ = pb.ConsistencyProof(ctx, i, i+1) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Benchmark consistency proof from a smaller size to the full tree size. + _, _ = pb.ConsistencyProof(ctx, uint64(i%treeSize), treeSize) + } + }) + + b.Run("ColdCache", func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + pb, _ := NewProofBuilder(ctx, treeSize, f) + b.StartTimer() + _, _ = pb.ConsistencyProof(ctx, uint64(i%treeSize), treeSize) + } + }) + }) +}