Skip to content

Commit 41ff86c

Browse files
authored
Improve node cache (#921)
Improves the implementation of the node cache under client. With these changes, nodeCache is now thread-safe, and supports concurrent fetching of arbitrary nodes, potentially speeding up the process of building trees and proofs if the underlying storage has high latency.
1 parent 5523fb5 commit 41ff86c

2 files changed

Lines changed: 240 additions & 74 deletions

File tree

client/client.go

Lines changed: 158 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"fmt"
2424
"sync"
2525

26+
lru "github.com/hashicorp/golang-lru/v2"
2627
"github.com/transparency-dev/formats/log"
2728
"github.com/transparency-dev/merkle/compact"
2829
"github.com/transparency-dev/merkle/proof"
@@ -32,6 +33,7 @@ import (
3233
"github.com/transparency-dev/tessera/internal/otel"
3334
"go.opentelemetry.io/otel/trace"
3435
"golang.org/x/mod/sumdb/note"
36+
"golang.org/x/sync/errgroup"
3537
)
3638

3739
var (
@@ -125,15 +127,7 @@ func FetchRangeNodes(ctx context.Context, s uint64, f TileFetcherFunc) ([][]byte
125127
nc := newNodeCache(f, s)
126128
nIDs := make([]compact.NodeID, 0, compact.RangeSize(0, s))
127129
nIDs = compact.RangeNodes(0, s, nIDs)
128-
hashes := make([][]byte, 0, len(nIDs))
129-
for _, n := range nIDs {
130-
h, err := nc.GetNode(ctx, n)
131-
if err != nil {
132-
return nil, err
133-
}
134-
hashes = append(hashes, h)
135-
}
136-
return hashes, nil
130+
return nc.GetNodes(ctx, nIDs)
137131
})
138132
}
139133

@@ -180,12 +174,12 @@ func GetEntryBundle(ctx context.Context, f EntryBundleFetcherFunc, i, logSize ui
180174
// at a given tree size.
181175
type ProofBuilder struct {
182176
treeSize uint64
183-
nodeCache nodeCache
177+
nodeCache *nodeCache
184178
}
185179

186180
// NewProofBuilder creates a new ProofBuilder object for a given tree size.
187-
// The returned ProofBuilder can be re-used for proofs related to a given tree size, but
188-
// it is not thread-safe and should not be accessed concurrently.
181+
// The returned ProofBuilder can be re-used for proofs related to a given tree size, and is
182+
// thread-safe.
189183
func NewProofBuilder(ctx context.Context, treeSize uint64, f TileFetcherFunc) (*ProofBuilder, error) {
190184
pb := &ProofBuilder{
191185
treeSize: treeSize,
@@ -204,7 +198,7 @@ func (pb *ProofBuilder) InclusionProof(ctx context.Context, index uint64) ([][]b
204198
if err != nil {
205199
return nil, fmt.Errorf("failed to calculate inclusion proof node list: %v", err)
206200
}
207-
return pb.fetchNodes(ctx, nodes)
201+
return pb.materialiseProof(ctx, nodes)
208202
})
209203
}
210204

@@ -221,22 +215,16 @@ func (pb *ProofBuilder) ConsistencyProof(ctx context.Context, smaller, larger ui
221215
if err != nil {
222216
return nil, fmt.Errorf("failed to calculate consistency proof node list: %v", err)
223217
}
224-
return pb.fetchNodes(ctx, nodes)
218+
return pb.materialiseProof(ctx, nodes)
225219
})
226220
}
227221

228-
// fetchNodes retrieves the specified proof nodes via pb's nodeCache.
229-
func (pb *ProofBuilder) fetchNodes(ctx context.Context, nodes proof.Nodes) ([][]byte, error) {
230-
hashes := make([][]byte, 0, len(nodes.IDs))
231-
// TODO(al) parallelise this.
232-
for _, id := range nodes.IDs {
233-
h, err := pb.nodeCache.GetNode(ctx, id)
234-
if err != nil {
235-
return nil, fmt.Errorf("failed to get node (%v): %v", id, err)
236-
}
237-
hashes = append(hashes, h)
222+
// materialiseProof retrieves the specified proof nodes via pb's nodeCache, recreating ephemeral nodes if necessary.
223+
func (pb *ProofBuilder) materialiseProof(ctx context.Context, nodes proof.Nodes) ([][]byte, error) {
224+
hashes, err := pb.nodeCache.GetNodes(ctx, nodes.IDs)
225+
if err != nil {
226+
return nil, err
238227
}
239-
var err error
240228
if hashes, err = nodes.Rehash(hashes, hasher.HashChildren); err != nil {
241229
return nil, fmt.Errorf("failed to rehash proof: %v", err)
242230
}
@@ -344,82 +332,179 @@ func (lst *LogStateTracker) Latest() log.Checkpoint {
344332
return lst.latestConsistent
345333
}
346334

347-
// tileKey is used as a key in nodeCache's tile map.
348-
type tileKey struct {
349-
tileLevel uint64
350-
tileIndex uint64
351-
}
352-
353335
// nodeCache hides the tiles abstraction away, and improves
354336
// performance by caching tiles it's seen.
355-
// Not threadsafe, and intended to be only used throughout the course
356-
// of a single request.
337+
// Threadsafe.
357338
type nodeCache struct {
358339
logSize uint64
359-
ephemeral map[compact.NodeID][]byte
360-
tiles map[tileKey]api.HashTile
340+
nodes *lru.Cache[compact.NodeID, []byte]
361341
getTile TileFetcherFunc
342+
tileLocks *shardedMutex[compact.NodeID]
362343
}
363344

364345
// newNodeCache creates a new nodeCache instance for a given log size.
365-
func newNodeCache(f TileFetcherFunc, logSize uint64) nodeCache {
366-
return nodeCache{
346+
func newNodeCache(f TileFetcherFunc, logSize uint64) *nodeCache {
347+
c, err := lru.New[compact.NodeID, []byte](64 << 10)
348+
if err != nil {
349+
panic(fmt.Errorf("lru.New: %v", err))
350+
}
351+
return &nodeCache{
367352
logSize: logSize,
368-
ephemeral: make(map[compact.NodeID][]byte),
369-
tiles: make(map[tileKey]api.HashTile),
353+
nodes: c,
370354
getTile: f,
355+
tileLocks: newShardedMutex[compact.NodeID](),
371356
}
372357
}
373358

374-
// SetEphemeralNode stored a derived "ephemeral" tree node.
375-
func (n *nodeCache) SetEphemeralNode(id compact.NodeID, h []byte) {
376-
n.ephemeral[id] = h
377-
}
378-
379359
// GetNode returns the internal log tree node hash for the specified node ID.
380-
// A previously set ephemeral node will be returned if id matches, otherwise
381-
// the tile containing the requested node will be fetched and cached, and the
382-
// node hash returned.
360+
// The tile containing the node will be fetched if necessary.
383361
func (n *nodeCache) GetNode(ctx context.Context, id compact.NodeID) ([]byte, error) {
384362
return otel.Trace(ctx, "tessera.client.nodecache.GetNode", tracer, func(ctx context.Context, span trace.Span) ([]byte, error) {
385363
span.SetAttributes(indexKey.Int64(otel.Clamp64(id.Index)), levelKey.Int64(int64(id.Level)))
364+
// Fast-path: check to see we have this node in the cache and return it directly if so, otherwise we'll need to fetch it.
365+
if e, ok := n.nodes.Get(id); ok {
366+
return e, nil
367+
}
386368

387-
// First check for ephemeral nodes:
388-
if e := n.ephemeral[id]; len(e) != 0 {
369+
// No dice, so we need to fetch the tile and use the contents to populate the cache.
370+
// We only want to do this once per tile, so lock keyed by the _tile_ ID here.
371+
tileLevel, tileIndex, _, _ := layout.NodeCoordsToTileAddress(uint64(id.Level), uint64(id.Index))
372+
k := compact.NodeID{Level: uint(tileLevel), Index: tileIndex}
373+
n.tileLocks.Lock(k)
374+
defer n.tileLocks.Unlock(k)
375+
// Re-check if we have the node cached - since we're under lock here it's possible that another goroutine
376+
// managed to get into this section before us and populate the cache.
377+
if e, ok := n.nodes.Get(id); ok {
389378
return e, nil
390379
}
391-
// Otherwise look in fetched tiles:
392-
tileLevel, tileIndex, nodeLevel, nodeIndex := layout.NodeCoordsToTileAddress(uint64(id.Level), uint64(id.Index))
393-
tKey := tileKey{tileLevel, tileIndex}
394-
t, ok := n.tiles[tKey]
395-
if !ok {
396-
span.AddEvent("cache miss")
397-
p := layout.PartialTileSize(tileLevel, tileIndex, n.logSize)
398-
tileRaw, err := n.getTile(ctx, tileLevel, tileIndex, p)
380+
span.AddEvent("cache miss")
381+
382+
p := layout.PartialTileSize(tileLevel, tileIndex, n.logSize)
383+
nodes, err := n.fetchTileNodes(ctx, tileLevel, tileIndex, p)
384+
if err != nil {
385+
return nil, fmt.Errorf("failed to fetch and populate node cache: %v", err)
386+
}
387+
for k, v := range nodes {
388+
n.nodes.Add(k, v)
389+
}
390+
if e, ok := nodes[id]; ok {
391+
return e, nil
392+
}
393+
return nil, fmt.Errorf("internal error: missing node %+v", id)
394+
})
395+
}
396+
397+
// GetNodes returns the tree hashes at the provided locations.
398+
func (n *nodeCache) GetNodes(ctx context.Context, nIDs []compact.NodeID) ([][]byte, error) {
399+
hashes := make([][]byte, len(nIDs))
400+
g, ctx := errgroup.WithContext(ctx)
401+
for i, id := range nIDs {
402+
g.Go(func() error {
403+
h, err := n.GetNode(ctx, id)
399404
if err != nil {
400-
return nil, fmt.Errorf("failed to fetch tile: %v", err)
401-
}
402-
var tile api.HashTile
403-
if err := tile.UnmarshalText(tileRaw); err != nil {
404-
return nil, fmt.Errorf("failed to parse tile: %v", err)
405+
return err
405406
}
406-
t = tile
407-
n.tiles[tKey] = tile
407+
hashes[i] = h
408+
return nil
409+
})
410+
}
411+
if err := g.Wait(); err != nil {
412+
return nil, err
413+
}
414+
return hashes, nil
415+
}
416+
417+
// fetchTileNodes retrieves the specified tile, parses it, and returns a map of tree-space-coordinate to node hash.
418+
func (n *nodeCache) fetchTileNodes(ctx context.Context, tileLevel, tileIndex uint64, p uint8) (map[compact.NodeID][]byte, error) {
419+
return otel.Trace(ctx, "tessera.client.nodecache.fetchTileNodes", tracer, func(ctx context.Context, span trace.Span) (map[compact.NodeID][]byte, error) {
420+
tileRaw, err := n.getTile(ctx, tileLevel, tileIndex, p)
421+
if err != nil {
422+
return nil, fmt.Errorf("failed to fetch tile: %v", err)
408423
}
409-
// We've got the tile, now we need to look up (or calculate) the node inside of it
410-
numLeaves := 1 << nodeLevel
411-
firstLeaf := int(nodeIndex) * numLeaves
412-
lastLeaf := firstLeaf + numLeaves
413-
if lastLeaf > len(t.Nodes) {
414-
return nil, fmt.Errorf("require leaf nodes [%d, %d) but only got %d leaves", firstLeaf, lastLeaf, len(t.Nodes))
424+
425+
var tile api.HashTile
426+
if err := tile.UnmarshalText(tileRaw); err != nil {
427+
return nil, fmt.Errorf("failed to parse tile: %v", err)
428+
}
429+
430+
ret := make(map[compact.NodeID][]byte, 256*2-1)
431+
// visitFn is a visitor callback which populates the nodes cache.
432+
// Used by the calls to compact range below.
433+
visitFn := func(intID compact.NodeID, h []byte) {
434+
// Figure out the "global" nodeID for the node intID in the requested tile.
435+
i := compact.NodeID{
436+
Level: uint(tileLevel*layout.TileHeight) + intID.Level,
437+
Index: (tileIndex*layout.TileWidth)>>intID.Level + intID.Index,
438+
}
439+
ret[i] = h
415440
}
416441
rf := compact.RangeFactory{Hash: hasher.HashChildren}
417442
r := rf.NewEmptyRange(0)
418-
for _, l := range t.Nodes[firstLeaf:lastLeaf] {
419-
if err := r.Append(l, nil); err != nil {
443+
for _, l := range tile.Nodes {
444+
if err := r.Append(l, visitFn); err != nil {
420445
return nil, fmt.Errorf("failed to Append: %v", err)
421446
}
422447
}
423-
return r.GetRootHash(nil)
448+
if _, err := r.GetRootHash(visitFn); err != nil {
449+
return nil, fmt.Errorf("failed to visit all nodes: %v", err)
450+
}
451+
return ret, nil
424452
})
425453
}
454+
455+
// cLock is a mutex which keeps track of the number of goroutines attempting to acquire a lock.
456+
type cLock struct {
457+
sync.Mutex
458+
n int64
459+
}
460+
461+
// shardedMutex is a set of mutexes sharded by key.
462+
//
463+
// For a given key, it acts as a regular mutex with the exception that it also tracks the number
464+
// of blocked goroutines waiting to acquire the lock.
465+
//
466+
// If a mutex doesn't exist for a given key at the point that Lock is called, one will be created.
467+
// To help guard against unbounded growth, mutexes with zero pending waiters at the point they're unlocked are deleted.
468+
type shardedMutex[K comparable] struct {
469+
// m protects the locks map and the waiter counts it contains.
470+
m *sync.Mutex
471+
locks map[K]*cLock
472+
}
473+
474+
// newShardedMutex creates a new shardedLock instance.
475+
func newShardedMutex[K comparable]() *shardedMutex[K] {
476+
return &shardedMutex[K]{
477+
m: &sync.Mutex{},
478+
locks: make(map[K]*cLock),
479+
}
480+
}
481+
482+
// Lock locks the given key.
483+
func (sl *shardedMutex[K]) Lock(k K) {
484+
sl.m.Lock()
485+
l, ok := sl.locks[k]
486+
if !ok {
487+
l = &cLock{}
488+
sl.locks[k] = l
489+
}
490+
l.n++
491+
sl.m.Unlock()
492+
493+
l.Lock()
494+
}
495+
496+
// Unlock unlocks the given key.
497+
func (sl *shardedMutex[K]) Unlock(k K) {
498+
sl.m.Lock()
499+
l, ok := sl.locks[k]
500+
if !ok {
501+
panic("unlock on non-existent key")
502+
}
503+
l.n--
504+
if l.n == 0 {
505+
delete(sl.locks, k)
506+
}
507+
sl.m.Unlock()
508+
509+
l.Unlock()
510+
}

0 commit comments

Comments
 (0)