Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 104 additions & 74 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"sync"

lru "github.com/hashicorp/golang-lru/v2"
"github.com/transparency-dev/formats/log"
"github.com/transparency-dev/merkle/compact"
"github.com/transparency-dev/merkle/proof"
Expand All @@ -32,6 +33,8 @@ import (
"github.com/transparency-dev/tessera/internal/otel"
"go.opentelemetry.io/otel/trace"
"golang.org/x/mod/sumdb/note"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/singleflight"
)

var (
Expand Down Expand Up @@ -125,15 +128,7 @@ func FetchRangeNodes(ctx context.Context, s uint64, f TileFetcherFunc) ([][]byte
nc := newNodeCache(f, s)
nIDs := make([]compact.NodeID, 0, compact.RangeSize(0, s))
nIDs = compact.RangeNodes(0, s, nIDs)
hashes := make([][]byte, 0, len(nIDs))
for _, n := range nIDs {
h, err := nc.GetNode(ctx, n)
if err != nil {
return nil, err
}
hashes = append(hashes, h)
}
return hashes, nil
return nc.GetNodes(ctx, nIDs)
})
}

Expand Down Expand Up @@ -180,7 +175,7 @@ func GetEntryBundle(ctx context.Context, f EntryBundleFetcherFunc, i, logSize ui
// at a given tree size.
type ProofBuilder struct {
treeSize uint64
nodeCache nodeCache
nodeCache *nodeCache
}

// NewProofBuilder creates a new ProofBuilder object for a given tree size.
Expand All @@ -204,7 +199,7 @@ func (pb *ProofBuilder) InclusionProof(ctx context.Context, index uint64) ([][]b
if err != nil {
return nil, fmt.Errorf("failed to calculate inclusion proof node list: %v", err)
}
return pb.fetchNodes(ctx, nodes)
return pb.materialiseProof(ctx, nodes)
})
}

Expand All @@ -221,22 +216,16 @@ func (pb *ProofBuilder) ConsistencyProof(ctx context.Context, smaller, larger ui
if err != nil {
return nil, fmt.Errorf("failed to calculate consistency proof node list: %v", err)
}
return pb.fetchNodes(ctx, nodes)
return pb.materialiseProof(ctx, nodes)
})
}

// fetchNodes retrieves the specified proof nodes via pb's nodeCache.
func (pb *ProofBuilder) fetchNodes(ctx context.Context, nodes proof.Nodes) ([][]byte, error) {
hashes := make([][]byte, 0, len(nodes.IDs))
// TODO(al) parallelise this.
for _, id := range nodes.IDs {
h, err := pb.nodeCache.GetNode(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get node (%v): %v", id, err)
}
hashes = append(hashes, h)
// materialiseProof retrieves the specified proof nodes via pb's nodeCache, recreating ephemeral nodes if necessary.
func (pb *ProofBuilder) materialiseProof(ctx context.Context, nodes proof.Nodes) ([][]byte, error) {
hashes, err := pb.nodeCache.GetNodes(ctx, nodes.IDs)
if err != nil {
return nil, err
}
var err error
if hashes, err = nodes.Rehash(hashes, hasher.HashChildren); err != nil {
return nil, fmt.Errorf("failed to rehash proof: %v", err)
}
Expand Down Expand Up @@ -344,82 +333,123 @@ func (lst *LogStateTracker) Latest() log.Checkpoint {
return lst.latestConsistent
}

// tileKey is used as a key in nodeCache's tile map.
type tileKey struct {
tileLevel uint64
tileIndex uint64
}

// nodeCache hides the tiles abstraction away, and improves
// performance by caching tiles it's seen.
// Not threadsafe, and intended to be only used throughout the course
// of a single request.
// Threadsafe.
type nodeCache struct {
logSize uint64
ephemeral map[compact.NodeID][]byte
tiles map[tileKey]api.HashTile
getTile TileFetcherFunc
logSize uint64
nodes *lru.Cache[compact.NodeID, []byte]
getTile TileFetcherFunc
g singleflight.Group
}

// newNodeCache creates a new nodeCache instance for a given log size.
func newNodeCache(f TileFetcherFunc, logSize uint64) nodeCache {
return nodeCache{
logSize: logSize,
ephemeral: make(map[compact.NodeID][]byte),
tiles: make(map[tileKey]api.HashTile),
getTile: f,
func newNodeCache(f TileFetcherFunc, logSize uint64) *nodeCache {
c, err := lru.New[compact.NodeID, []byte](64 << 10)
if err != nil {
panic(fmt.Errorf("lru.New: %v", err))
}
return &nodeCache{
logSize: logSize,
nodes: c,
getTile: f,
}
}

// SetEphemeralNode stored a derived "ephemeral" tree node.
func (n *nodeCache) SetEphemeralNode(id compact.NodeID, h []byte) {
n.ephemeral[id] = h
}

// GetNode returns the internal log tree node hash for the specified node ID.
// A previously set ephemeral node will be returned if id matches, otherwise
// the tile containing the requested node will be fetched and cached, and the
// node hash returned.
// The tile containing the node will be fetched if necessary.
func (n *nodeCache) GetNode(ctx context.Context, id compact.NodeID) ([]byte, error) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a race condition in GetNode() when the LRU cache is full.

Scenario:

  1. The first goroutine calls GetNode(). It's a cache miss, so the goroutine fetches the tile and adds the node to the LRU cache.
  2. Meanwhile, there are many other goroutines fetching other tiles and adding the nodes into the LRU cache.
  3. The LRU cache is full and evicts the node from the first goroutine.
  4. The first goroutine calls n.nodes.Get(id), but cannot find the node.
  5. The unexpected error internal error: missing node %+v is returned.

return otel.Trace(ctx, "tessera.client.nodecache.GetNode", tracer, func(ctx context.Context, span trace.Span) ([]byte, error) {
span.SetAttributes(indexKey.Int64(otel.Clamp64(id.Index)), levelKey.Int64(int64(id.Level)))

// First check for ephemeral nodes:
if e := n.ephemeral[id]; len(e) != 0 {
// Check if we've already cached this node, return it directly if so, otherwise we'll need to fetch it.
if e, ok := n.nodes.Get(id); ok {
return e, nil
}
// Otherwise look in fetched tiles:
tileLevel, tileIndex, nodeLevel, nodeIndex := layout.NodeCoordsToTileAddress(uint64(id.Level), uint64(id.Index))
tKey := tileKey{tileLevel, tileIndex}
t, ok := n.tiles[tKey]
if !ok {

// We now need to fetch the tile and use the contents to populate the cache.
// We'll fetch/parse the tile, and then reconsistute all the internal nodes into the
// node cache. We only want to do this once per tile, so use singleflight keyed by the _tile_ ID
// to make that happen.
tileLevel, tileIndex, _, _ := layout.NodeCoordsToTileAddress(uint64(id.Level), uint64(id.Index))
k := fmt.Sprintf("%d/%d", tileLevel, tileIndex)
defer n.g.Forget(k)
_, err, _ := n.g.Do(k, func() (any, error) {
if n.nodes.Contains(id) {
return struct{}{}, nil
}
span.AddEvent("cache miss")

p := layout.PartialTileSize(tileLevel, tileIndex, n.logSize)
tileRaw, err := n.getTile(ctx, tileLevel, tileIndex, p)
if err != nil {
return nil, fmt.Errorf("failed to fetch tile: %v", err)
if err := n.fetchTileAndPopulateCache(ctx, tileLevel, tileIndex, p); err != nil {
return struct{}{}, err
}
var tile api.HashTile
if err := tile.UnmarshalText(tileRaw); err != nil {
return nil, fmt.Errorf("failed to parse tile: %v", err)

return struct{}{}, nil
})
if err != nil {
return nil, fmt.Errorf("failed to fetch and populate node cache: %v", err)
}
if e, ok := n.nodes.Get(id); ok {
return e, nil
}
return nil, fmt.Errorf("internal error: missing node %+v", id)
})
}

// GetNodes returns the tree hashes at the provided locations.
func (n *nodeCache) GetNodes(ctx context.Context, nIDs []compact.NodeID) ([][]byte, error) {
hashes := make([][]byte, len(nIDs))
g, ctx := errgroup.WithContext(ctx)
for i, id := range nIDs {
g.Go(func() error {
h, err := n.GetNode(ctx, id)
if err != nil {
return err
}
t = tile
n.tiles[tKey] = tile
hashes[i] = h
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return hashes, nil
}

// fetchTileAndPopulateCache fetches and parses the specified tile from storage, populating the nodes it contains into the node cache.
func (n *nodeCache) fetchTileAndPopulateCache(ctx context.Context, tileLevel, tileIndex uint64, p uint8) error {
return otel.TraceErr(ctx, "tessera.client.nodecache.fetchTileAndPopulateCache", tracer, func(ctx context.Context, span trace.Span) error {
tileRaw, err := n.getTile(ctx, tileLevel, tileIndex, p)
if err != nil {
return fmt.Errorf("failed to fetch tile: %v", err)
}
// We've got the tile, now we need to look up (or calculate) the node inside of it
numLeaves := 1 << nodeLevel
firstLeaf := int(nodeIndex) * numLeaves
lastLeaf := firstLeaf + numLeaves
if lastLeaf > len(t.Nodes) {
return nil, fmt.Errorf("require leaf nodes [%d, %d) but only got %d leaves", firstLeaf, lastLeaf, len(t.Nodes))

var tile api.HashTile
if err := tile.UnmarshalText(tileRaw); err != nil {
return fmt.Errorf("failed to parse tile: %v", err)
}

// visitFn is a visitor callback which populates the nodes cache.
// Used by the calls to compact range below.
visitFn := func(intID compact.NodeID, h []byte) {
// Figure out the "global" nodeID for the node intID in the requested tile.
i := compact.NodeID{
Level: uint(tileLevel*layout.TileHeight) + intID.Level,
Index: (tileIndex*layout.TileWidth)>>intID.Level + intID.Index,
}
_ = n.nodes.Add(i, h)
}
rf := compact.RangeFactory{Hash: hasher.HashChildren}
r := rf.NewEmptyRange(0)
for _, l := range t.Nodes[firstLeaf:lastLeaf] {
if err := r.Append(l, nil); err != nil {
return nil, fmt.Errorf("failed to Append: %v", err)
for _, l := range tile.Nodes {
if err := r.Append(l, visitFn); err != nil {
return fmt.Errorf("failed to Append: %v", err)
}
}
return r.GetRootHash(nil)
if _, err := r.GetRootHash(visitFn); err != nil {
return fmt.Errorf("failed to visit all nodes: %v", err)
}
return nil
})
}
83 changes: 82 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ func TestNodeFetcherAddressing(t *testing.T) {
if err != nil {
t.Fatalf("NewProofBuilder: %v", err)
}
_, err = pb.fetchNodes(t.Context(), proof.Nodes{IDs: []compact.NodeID{compact.NewNodeID(test.nodeLevel, test.nodeIdx)}})
_, err = pb.materialiseProof(t.Context(), proof.Nodes{IDs: []compact.NodeID{compact.NewNodeID(test.nodeLevel, test.nodeIdx)}})
if err != nil {
t.Fatalf("fetchNodes: %v", err)
}
Expand All @@ -350,3 +350,84 @@ func TestNodeFetcherAddressing(t *testing.T) {
})
}
}

// BenchmarkProofBuilder measures the cost of building inclusion and
// consistency proofs against a synthetic 1M-entry tree, with both a warm
// and a cold node cache. Tile fetches are served from memory so the
// benchmark isolates the proof-construction path.
func BenchmarkProofBuilder(b *testing.B) {
	ctx := context.Background()
	const treeSize = 1_000_000
	dummyHash := sha256.Sum256([]byte("dummy"))

	// Pre-generate a full tile to avoid marshaling in the loop.
	fullTile := &api.HashTile{
		Nodes: make([][]byte, 256),
	}
	for i := range fullTile.Nodes {
		fullTile.Nodes[i] = dummyHash[:]
	}
	fullTileBytes, err := fullTile.MarshalText()
	if err != nil {
		b.Fatalf("failed to marshal full tile: %v", err)
	}

	// f serves the canned full tile, and synthesises partial tiles on demand.
	// Partial tiles are rare in a 1M tree (only the right-hand edge of each
	// level), so building them per call is acceptable here.
	f := func(_ context.Context, _, _ uint64, p uint8) ([]byte, error) {
		if p == 0 {
			return fullTileBytes, nil
		}
		partialTile := &api.HashTile{
			Nodes: make([][]byte, p),
		}
		for i := range partialTile.Nodes {
			partialTile.Nodes[i] = dummyHash[:]
		}
		return partialTile.MarshalText()
	}

	b.Run("InclusionProof", func(b *testing.B) {
		b.Run("WarmCache", func(b *testing.B) {
			pb, err := NewProofBuilder(ctx, treeSize, f)
			if err != nil {
				b.Fatalf("NewProofBuilder: %v", err)
			}
			// Warm up the cache with some proofs.
			for i := uint64(0); i < 100; i++ {
				if _, err := pb.InclusionProof(ctx, i); err != nil {
					b.Fatalf("InclusionProof: %v", err)
				}
			}
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				if _, err := pb.InclusionProof(ctx, uint64(i%treeSize)); err != nil {
					b.Fatalf("InclusionProof: %v", err)
				}
			}
		})

		b.Run("ColdCache", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				b.StopTimer()
				pb, err := NewProofBuilder(ctx, treeSize, f)
				if err != nil {
					b.Fatalf("NewProofBuilder: %v", err)
				}
				b.StartTimer()
				if _, err := pb.InclusionProof(ctx, uint64(i%treeSize)); err != nil {
					b.Fatalf("InclusionProof: %v", err)
				}
			}
		})
	})

	b.Run("ConsistencyProof", func(b *testing.B) {
		b.Run("WarmCache", func(b *testing.B) {
			pb, err := NewProofBuilder(ctx, treeSize, f)
			if err != nil {
				b.Fatalf("NewProofBuilder: %v", err)
			}
			// Warm up.
			for i := uint64(0); i < 100; i++ {
				if _, err := pb.ConsistencyProof(ctx, i, i+1); err != nil {
					b.Fatalf("ConsistencyProof: %v", err)
				}
			}
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				// Benchmark consistency proof from a smaller size to the full tree size.
				if _, err := pb.ConsistencyProof(ctx, uint64(i%treeSize), treeSize); err != nil {
					b.Fatalf("ConsistencyProof: %v", err)
				}
			}
		})

		b.Run("ColdCache", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				b.StopTimer()
				pb, err := NewProofBuilder(ctx, treeSize, f)
				if err != nil {
					b.Fatalf("NewProofBuilder: %v", err)
				}
				b.StartTimer()
				if _, err := pb.ConsistencyProof(ctx, uint64(i%treeSize), treeSize); err != nil {
					b.Fatalf("ConsistencyProof: %v", err)
				}
			}
		})
	})
}
Loading