@@ -23,6 +23,7 @@ import (
2323 "fmt"
2424 "sync"
2525
26+ lru "github.com/hashicorp/golang-lru/v2"
2627 "github.com/transparency-dev/formats/log"
2728 "github.com/transparency-dev/merkle/compact"
2829 "github.com/transparency-dev/merkle/proof"
@@ -32,6 +33,7 @@ import (
3233 "github.com/transparency-dev/tessera/internal/otel"
3334 "go.opentelemetry.io/otel/trace"
3435 "golang.org/x/mod/sumdb/note"
36+ "golang.org/x/sync/errgroup"
3537)
3638
3739var (
@@ -125,15 +127,7 @@ func FetchRangeNodes(ctx context.Context, s uint64, f TileFetcherFunc) ([][]byte
125127 nc := newNodeCache (f , s )
126128 nIDs := make ([]compact.NodeID , 0 , compact .RangeSize (0 , s ))
127129 nIDs = compact .RangeNodes (0 , s , nIDs )
128- hashes := make ([][]byte , 0 , len (nIDs ))
129- for _ , n := range nIDs {
130- h , err := nc .GetNode (ctx , n )
131- if err != nil {
132- return nil , err
133- }
134- hashes = append (hashes , h )
135- }
136- return hashes , nil
130+ return nc .GetNodes (ctx , nIDs )
137131 })
138132}
139133
@@ -180,12 +174,12 @@ func GetEntryBundle(ctx context.Context, f EntryBundleFetcherFunc, i, logSize ui
180174// at a given tree size.
181175type ProofBuilder struct {
182176 treeSize uint64
183- nodeCache nodeCache
177+ nodeCache * nodeCache
184178}
185179
186180// NewProofBuilder creates a new ProofBuilder object for a given tree size.
187- // The returned ProofBuilder can be re-used for proofs related to a given tree size, but
188- // it is not thread-safe and should not be accessed concurrently .
181+ // The returned ProofBuilder can be re-used for proofs related to a given tree size, and is
182+ // thread-safe.
189183func NewProofBuilder (ctx context.Context , treeSize uint64 , f TileFetcherFunc ) (* ProofBuilder , error ) {
190184 pb := & ProofBuilder {
191185 treeSize : treeSize ,
@@ -204,7 +198,7 @@ func (pb *ProofBuilder) InclusionProof(ctx context.Context, index uint64) ([][]b
204198 if err != nil {
205199 return nil , fmt .Errorf ("failed to calculate inclusion proof node list: %v" , err )
206200 }
207- return pb .fetchNodes (ctx , nodes )
201+ return pb .materialiseProof (ctx , nodes )
208202 })
209203}
210204
@@ -221,22 +215,16 @@ func (pb *ProofBuilder) ConsistencyProof(ctx context.Context, smaller, larger ui
221215 if err != nil {
222216 return nil , fmt .Errorf ("failed to calculate consistency proof node list: %v" , err )
223217 }
224- return pb .fetchNodes (ctx , nodes )
218+ return pb .materialiseProof (ctx , nodes )
225219 })
226220}
227221
228- // fetchNodes retrieves the specified proof nodes via pb's nodeCache.
229- func (pb * ProofBuilder ) fetchNodes (ctx context.Context , nodes proof.Nodes ) ([][]byte , error ) {
230- hashes := make ([][]byte , 0 , len (nodes .IDs ))
231- // TODO(al) parallelise this.
232- for _ , id := range nodes .IDs {
233- h , err := pb .nodeCache .GetNode (ctx , id )
234- if err != nil {
235- return nil , fmt .Errorf ("failed to get node (%v): %v" , id , err )
236- }
237- hashes = append (hashes , h )
222+ // materialiseProof retrieves the specified proof nodes via pb's nodeCache, recreating ephemeral nodes if necessary.
223+ func (pb * ProofBuilder ) materialiseProof (ctx context.Context , nodes proof.Nodes ) ([][]byte , error ) {
224+ hashes , err := pb .nodeCache .GetNodes (ctx , nodes .IDs )
225+ if err != nil {
226+ return nil , err
238227 }
239- var err error
240228 if hashes , err = nodes .Rehash (hashes , hasher .HashChildren ); err != nil {
241229 return nil , fmt .Errorf ("failed to rehash proof: %v" , err )
242230 }
@@ -344,82 +332,179 @@ func (lst *LogStateTracker) Latest() log.Checkpoint {
344332 return lst .latestConsistent
345333}
346334
347- // tileKey is used as a key in nodeCache's tile map.
348- type tileKey struct {
349- tileLevel uint64
350- tileIndex uint64
351- }
352-
353335// nodeCache hides the tiles abstraction away, and improves
354336// performance by caching tiles it's seen.
355- // Not threadsafe, and intended to be only used throughout the course
356- // of a single request.
337+ // Threadsafe.
357338type nodeCache struct {
358339 logSize uint64
359- ephemeral map [compact.NodeID ][]byte
360- tiles map [tileKey ]api.HashTile
340+ nodes * lru.Cache [compact.NodeID , []byte ]
361341 getTile TileFetcherFunc
342+ tileLocks * shardedMutex [compact.NodeID ]
362343}
363344
364345// newNodeCache creates a new nodeCache instance for a given log size.
365- func newNodeCache (f TileFetcherFunc , logSize uint64 ) nodeCache {
366- return nodeCache {
346+ func newNodeCache (f TileFetcherFunc , logSize uint64 ) * nodeCache {
347+ c , err := lru.New [compact.NodeID , []byte ](64 << 10 )
348+ if err != nil {
349+ panic (fmt .Errorf ("lru.New: %v" , err ))
350+ }
351+ return & nodeCache {
367352 logSize : logSize ,
368- ephemeral : make (map [compact.NodeID ][]byte ),
369- tiles : make (map [tileKey ]api.HashTile ),
353+ nodes : c ,
370354 getTile : f ,
355+ tileLocks : newShardedMutex [compact.NodeID ](),
371356 }
372357}
373358
374- // SetEphemeralNode stored a derived "ephemeral" tree node.
375- func (n * nodeCache ) SetEphemeralNode (id compact.NodeID , h []byte ) {
376- n .ephemeral [id ] = h
377- }
378-
379359// GetNode returns the internal log tree node hash for the specified node ID.
380- // A previously set ephemeral node will be returned if id matches, otherwise
381- // the tile containing the requested node will be fetched and cached, and the
382- // node hash returned.
360+ // If the node is not already cached, the tile containing it will be fetched and all of the nodes derived from that tile will be added to the cache.
383361func (n * nodeCache ) GetNode (ctx context.Context , id compact.NodeID ) ([]byte , error ) {
384362 return otel .Trace (ctx , "tessera.client.nodecache.GetNode" , tracer , func (ctx context.Context , span trace.Span ) ([]byte , error ) {
385363 span .SetAttributes (indexKey .Int64 (otel .Clamp64 (id .Index )), levelKey .Int64 (int64 (id .Level )))
364+ // Fast-path: check to see we have this node in the cache and return it directly if so, otherwise we'll need to fetch it.
365+ if e , ok := n .nodes .Get (id ); ok {
366+ return e , nil
367+ }
386368
387- // First check for ephemeral nodes:
388- if e := n .ephemeral [id ]; len (e ) != 0 {
369+ // No dice, so we need to fetch the tile and use the contents to populate the cache.
370+ // We only want to do this once per tile, so lock keyed by the _tile_ ID here.
371+ tileLevel , tileIndex , _ , _ := layout .NodeCoordsToTileAddress (uint64 (id .Level ), uint64 (id .Index ))
372+ k := compact.NodeID {Level : uint (tileLevel ), Index : tileIndex }
373+ n .tileLocks .Lock (k )
374+ defer n .tileLocks .Unlock (k )
375+ // Re-check whether we have the node cached - now that we hold the tile lock, it's possible that another
376+ // goroutine got into this section before us and has already populated the cache.
377+ if e , ok := n .nodes .Get (id ); ok {
389378 return e , nil
390379 }
391- // Otherwise look in fetched tiles:
392- tileLevel , tileIndex , nodeLevel , nodeIndex := layout .NodeCoordsToTileAddress (uint64 (id .Level ), uint64 (id .Index ))
393- tKey := tileKey {tileLevel , tileIndex }
394- t , ok := n .tiles [tKey ]
395- if ! ok {
396- span .AddEvent ("cache miss" )
397- p := layout .PartialTileSize (tileLevel , tileIndex , n .logSize )
398- tileRaw , err := n .getTile (ctx , tileLevel , tileIndex , p )
380+ span .AddEvent ("cache miss" )
381+
382+ p := layout .PartialTileSize (tileLevel , tileIndex , n .logSize )
383+ nodes , err := n .fetchTileNodes (ctx , tileLevel , tileIndex , p )
384+ if err != nil {
385+ return nil , fmt .Errorf ("failed to fetch and populate node cache: %v" , err )
386+ }
387+ for k , v := range nodes {
388+ n .nodes .Add (k , v )
389+ }
390+ if e , ok := nodes [id ]; ok {
391+ return e , nil
392+ }
393+ return nil , fmt .Errorf ("internal error: missing node %+v" , id )
394+ })
395+ }
396+
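397+ // GetNodes returns the tree hashes at the provided locations, in the same order as nIDs. Nodes are fetched concurrently via GetNode; if any fetch fails, an error is returned instead of the hashes.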
398+ func (n * nodeCache ) GetNodes (ctx context.Context , nIDs []compact.NodeID ) ([][]byte , error ) {
399+ hashes := make ([][]byte , len (nIDs ))
400+ g , ctx := errgroup .WithContext (ctx )
401+ for i , id := range nIDs {
402+ g .Go (func () error {
403+ h , err := n .GetNode (ctx , id )
399404 if err != nil {
400- return nil , fmt .Errorf ("failed to fetch tile: %v" , err )
401- }
402- var tile api.HashTile
403- if err := tile .UnmarshalText (tileRaw ); err != nil {
404- return nil , fmt .Errorf ("failed to parse tile: %v" , err )
405+ return err
405406 }
406- t = tile
407- n .tiles [tKey ] = tile
407+ hashes [i ] = h
408+ return nil
409+ })
410+ }
411+ if err := g .Wait (); err != nil {
412+ return nil , err
413+ }
414+ return hashes , nil
415+ }
416+
417+ // fetchTileNodes retrieves the specified tile, parses it, and returns a map of tree-space-coordinate to node hash.
418+ func (n * nodeCache ) fetchTileNodes (ctx context.Context , tileLevel , tileIndex uint64 , p uint8 ) (map [compact.NodeID ][]byte , error ) {
419+ return otel .Trace (ctx , "tessera.client.nodecache.fetchTileNodes" , tracer , func (ctx context.Context , span trace.Span ) (map [compact.NodeID ][]byte , error ) {
420+ tileRaw , err := n .getTile (ctx , tileLevel , tileIndex , p )
421+ if err != nil {
422+ return nil , fmt .Errorf ("failed to fetch tile: %v" , err )
408423 }
409- // We've got the tile, now we need to look up (or calculate) the node inside of it
410- numLeaves := 1 << nodeLevel
411- firstLeaf := int (nodeIndex ) * numLeaves
412- lastLeaf := firstLeaf + numLeaves
413- if lastLeaf > len (t .Nodes ) {
414- return nil , fmt .Errorf ("require leaf nodes [%d, %d) but only got %d leaves" , firstLeaf , lastLeaf , len (t .Nodes ))
424+
425+ var tile api.HashTile
426+ if err := tile .UnmarshalText (tileRaw ); err != nil {
427+ return nil , fmt .Errorf ("failed to parse tile: %v" , err )
428+ }
429+
430+ ret := make (map [compact.NodeID ][]byte , 256 * 2 - 1 )
431+ // visitFn is a visitor callback which populates ret, the map of nodes returned to the caller.
432+ // Used by the calls to compact range below.
433+ visitFn := func (intID compact.NodeID , h []byte ) {
434+ // Figure out the "global" nodeID for the node intID in the requested tile.
435+ i := compact.NodeID {
436+ Level : uint (tileLevel * layout .TileHeight ) + intID .Level ,
437+ Index : (tileIndex * layout .TileWidth )>> intID .Level + intID .Index ,
438+ }
439+ ret [i ] = h
415440 }
416441 rf := compact.RangeFactory {Hash : hasher .HashChildren }
417442 r := rf .NewEmptyRange (0 )
418- for _ , l := range t .Nodes [ firstLeaf : lastLeaf ] {
419- if err := r .Append (l , nil ); err != nil {
443+ for _ , l := range tile .Nodes {
444+ if err := r .Append (l , visitFn ); err != nil {
420445 return nil , fmt .Errorf ("failed to Append: %v" , err )
421446 }
422447 }
423- return r .GetRootHash (nil )
448+ if _ , err := r .GetRootHash (visitFn ); err != nil {
449+ return nil , fmt .Errorf ("failed to visit all nodes: %v" , err )
450+ }
451+ return ret , nil
424452 })
425453}
454+
455+ // cLock is a mutex paired with a count of the goroutines currently holding or waiting to acquire it.
456+ type cLock struct {
457+ sync.Mutex
458+ n int64 // incremented by shardedMutex.Lock before acquiring, decremented by shardedMutex.Unlock; guarded by shardedMutex.m, not by the embedded Mutex
459+ }
460+
461+ // shardedMutex is a set of mutexes sharded by key.
462+ //
463+ // For a given key, it acts as a regular mutex with the exception that it also tracks the number
464+ // of goroutines which currently hold, or are blocked waiting to acquire, the lock for that key.
465+ //
466+ // If a mutex doesn't exist for a given key at the point that Lock is called, one will be created.
467+ // To help guard against unbounded growth, mutexes with zero pending waiters at the point they're unlocked are deleted.
468+ type shardedMutex [K comparable ] struct {
469+ // m protects the locks map and the waiter counts it contains.
470+ m * sync.Mutex
471+ locks map [K ]* cLock // per-key locks; entries are created on demand by Lock and deleted by Unlock
472+ }
473+
474+ // newShardedMutex creates a new shardedMutex instance with no per-key locks allocated.
475+ func newShardedMutex [K comparable ]() * shardedMutex [K ] {
476+ return & shardedMutex [K ]{
477+ m : & sync.Mutex {},
478+ locks : make (map [K ]* cLock ),
479+ }
480+ }
481+
482+ // Lock locks the given key, blocking until any other holder of the same key has called Unlock. A per-key mutex is created if one does not already exist.
483+ func (sl * shardedMutex [K ]) Lock (k K ) {
484+ sl .m .Lock ()
485+ l , ok := sl .locks [k ]
486+ if ! ok {
487+ l = & cLock {}
488+ sl .locks [k ] = l
489+ }
490+ l .n ++ // register interest while still holding sl.m, so that a concurrent Unlock sees a non-zero count and keeps l in the map
491+ sl .m .Unlock ()
492+
493+ l .Lock () // block on the per-key mutex only after releasing sl.m, so callers using other keys are not held up
494+ }
495+
496+ // Unlock unlocks the given key. It panics if there is no lock entry for k in the map.
497+ func (sl * shardedMutex [K ]) Unlock (k K ) {
498+ sl .m .Lock ()
499+ l , ok := sl .locks [k ]
500+ if ! ok {
501+ panic ("unlock on non-existent key" )
502+ }
503+ l .n --
504+ if l .n == 0 {
505+ delete (sl .locks , k ) // no other holder or waiter remains, so drop the entry to stop the map growing without bound
506+ }
507+ sl .m .Unlock ()
508+
509+ l .Unlock ()
510+ }
0 commit comments