diff --git a/consensus/consensus.go b/consensus/consensus.go index 6aabd49126..95d3bcc67a 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -5,7 +5,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/builder" - consensusDB "github.com/NethermindEth/juno/consensus/db" "github.com/NethermindEth/juno/consensus/driver" "github.com/NethermindEth/juno/consensus/p2p" "github.com/NethermindEth/juno/consensus/p2p/config" @@ -16,6 +15,7 @@ import ( "github.com/NethermindEth/juno/consensus/tendermint" "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/consensus/votecounter" + "github.com/NethermindEth/juno/consensus/walstore" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/p2p/sync" @@ -33,6 +33,15 @@ type ConsensusServices struct { CommitListener driver.CommitListener[starknet.Value, starknet.Hash] } +type initOptions struct { + wrapWALStore func( + walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address], + ) walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address] + wrapBroadcasters func( + p2p.Broadcasters[starknet.Value, starknet.Hash, starknet.Address], + ) p2p.Broadcasters[starknet.Value, starknet.Hash, starknet.Address] +} + func Init( host host.Host, logger *log.ZapLogger, @@ -45,16 +54,53 @@ func Init( timeoutFn driver.TimeoutFn, bootstrapPeersFn func() []peer.AddrInfo, compiler compiler.Compiler, +) (ConsensusServices, error) { + return initWithOptions( + host, + logger, + database, + blockchain, + vm, + blockFetcher, + nodeAddress, + validators, + timeoutFn, + bootstrapPeersFn, + compiler, + initOptions{}, + ) +} + +func initWithOptions( + host host.Host, + logger *log.ZapLogger, + database db.KeyValueStore, + blockchain *blockchain.Blockchain, + vm vm.VM, + blockFetcher *sync.BlockFetcher, + nodeAddress *starknet.Address, + validators votecounter.Validators[starknet.Address], + 
timeoutFn driver.TimeoutFn, + bootstrapPeersFn func() []peer.AddrInfo, + compiler compiler.Compiler, + options initOptions, ) (ConsensusServices, error) { chainHeight, err := blockchain.Height() if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return ConsensusServices{}, err } currentHeight := types.Height(chainHeight + 1) - - tendermintDB := consensusDB.NewTendermintDB[ - starknet.Value, starknet.Hash, starknet.Address, + tendermintWALStore, err := walstore.NewTendermintWALStore[ + starknet.Value, + starknet.Hash, + starknet.Address, ](database) + if err != nil { + return ConsensusServices{}, err + } + if options.wrapWALStore != nil { + tendermintWALStore = options.wrapWALStore(tendermintWALStore) + } executor := builder.NewExecutor(blockchain, vm, logger, false, false) builder := builder.New(blockchain, executor) @@ -77,13 +123,17 @@ func Init( commitListener := driver.NewCommitListener(logger, &proposalStore, proposer, p2p) messageExtractor := consensusSync.New(validators, toValue, &proposalStore) + broadcasters := p2p.Broadcasters() + if options.wrapBroadcasters != nil { + broadcasters = options.wrapBroadcasters(broadcasters) + } driver := driver.New( logger, - tendermintDB, + tendermintWALStore, stateMachine, commitListener, - p2p.Broadcasters(), + broadcasters, p2p.Listeners(), blockFetcher, &messageExtractor, diff --git a/consensus/consensus_test.go b/consensus/consensus_test.go index 82fd48048a..9060a438b1 100644 --- a/consensus/consensus_test.go +++ b/consensus/consensus_test.go @@ -206,7 +206,7 @@ func writeBlock( require.NoError(t, err) require.NoError(t, bc.Store(committedBlock.Block, commitments, committedBlock.StateUpdate, committedBlock.NewClasses)) - close(committedBlock.Persisted) + committedBlock.Persisted <- nil commit := commit{ nodeIndex: index, diff --git a/consensus/db/buckets_consensus.go b/consensus/db/buckets_consensus.go deleted file mode 100644 index 2b0c3019bc..0000000000 --- a/consensus/db/buckets_consensus.go +++ /dev/null @@ 
-1,20 +0,0 @@ -package db - -import "slices" - -// The consensus service uses a separate DB with its own buckets -// -//go:generate go run github.com/dmarkham/enumer -type=BucketConsensus -output=buckets_consensus_enumer.go -type BucketConsensus byte - -// Pebble does not support buckets to differentiate between groups of -// keys like Bolt or MDBX does. We use a global prefix list as a poor -// man's bucket alternative. -const ( - WALEntryBucket BucketConsensus = iota // key: WAL_prefix + Height + MsgIndex. Val: Encoded Tendermint consensus message. -) - -// Key flattens a prefix and series of byte arrays into a single []byte. -func (b BucketConsensus) Key(key ...[]byte) []byte { - return append([]byte{byte(b)}, slices.Concat(key...)...) -} diff --git a/consensus/db/buckets_consensus_enumer.go b/consensus/db/buckets_consensus_enumer.go deleted file mode 100644 index e21b84094c..0000000000 --- a/consensus/db/buckets_consensus_enumer.go +++ /dev/null @@ -1,74 +0,0 @@ -// Code generated by "enumer -type=BucketConsensus -output=buckets_consensus_enumer.go"; DO NOT EDIT. - -package db - -import ( - "fmt" - "strings" -) - -const _BucketConsensusName = "WALEntryBucket" - -var _BucketConsensusIndex = [...]uint8{0, 14} - -const _BucketConsensusLowerName = "walentrybucket" - -func (i BucketConsensus) String() string { - if i >= BucketConsensus(len(_BucketConsensusIndex)-1) { - return fmt.Sprintf("BucketConsensus(%d)", i) - } - return _BucketConsensusName[_BucketConsensusIndex[i]:_BucketConsensusIndex[i+1]] -} - -// An "invalid array index" compiler error signifies that the constant values have changed. -// Re-run the stringer command to generate them again. 
-func _BucketConsensusNoOp() { - var x [1]struct{} - _ = x[WALEntryBucket-(0)] -} - -var _BucketConsensusValues = []BucketConsensus{WALEntryBucket} - -var _BucketConsensusNameToValueMap = map[string]BucketConsensus{ - _BucketConsensusName[0:14]: WALEntryBucket, - _BucketConsensusLowerName[0:14]: WALEntryBucket, -} - -var _BucketConsensusNames = []string{ - _BucketConsensusName[0:14], -} - -// BucketConsensusString retrieves an enum value from the enum constants string name. -// Throws an error if the param is not part of the enum. -func BucketConsensusString(s string) (BucketConsensus, error) { - if val, ok := _BucketConsensusNameToValueMap[s]; ok { - return val, nil - } - - if val, ok := _BucketConsensusNameToValueMap[strings.ToLower(s)]; ok { - return val, nil - } - return 0, fmt.Errorf("%s does not belong to BucketConsensus values", s) -} - -// BucketConsensusValues returns all values of the enum -func BucketConsensusValues() []BucketConsensus { - return _BucketConsensusValues -} - -// BucketConsensusStrings returns a slice of all String values of the enum -func BucketConsensusStrings() []string { - strs := make([]string, len(_BucketConsensusNames)) - copy(strs, _BucketConsensusNames) - return strs -} - -// IsABucketConsensus returns "true" if the value is listed in the enum definition. 
"false" otherwise -func (i BucketConsensus) IsABucketConsensus() bool { - for _, v := range _BucketConsensusValues { - if i == v { - return true - } - } - return false -} diff --git a/consensus/db/db.go b/consensus/db/db.go deleted file mode 100644 index 8ef091d1f5..0000000000 --- a/consensus/db/db.go +++ /dev/null @@ -1,162 +0,0 @@ -package db - -import ( - "bytes" - "encoding/binary" - "fmt" - "iter" - - "github.com/NethermindEth/juno/consensus/types" - "github.com/NethermindEth/juno/consensus/types/wal" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/dbutils" - "github.com/NethermindEth/juno/encoder" -) - -const NumBytesForHeight = 4 - -// walMsgCount tracks the number of wal entries at the current height -type walMsgCount uint32 - -// TendermintDB defines the methods for interacting with the Tendermint WAL database. -// -// The purpose of the WAL is to record any event that may result in a state change. -// These events fall into the following categories: -// 1. Incoming messages. We do not need to store outgoing messages. -// 2. When we propose a value. -// 3. When a timeout is triggered (not when it is scheduled). -// The purpose of the WAL is to allow the node to recover the state it was in before the crash. -// No new messages should be broadcast during replay. -// -// We commit the WAL to disk when: -// 1. We start a new round -// 2. Right before we broadcast a message -// -// We call Delete when we start a new height and commit a block -type TendermintDB[V types.Hashable[H], H types.Hash, A types.Addr] interface { - // Flush writes the accumulated batch operations to the underlying database. - Flush() error - // GetWALEntries retrieves all WAL messages (consensus messages and timeouts) stored for a given height from the database. - LoadAllEntries() iter.Seq2[wal.Entry[V, H, A], error] - // SetWALEntry schedules the storage of a WAL message in the batch. 
- SetWALEntry(entry wal.Entry[V, H, A]) error - // DeleteWALEntries schedules the deletion of all WAL messages for a specific height in the batch. - DeleteWALEntries(height types.Height) error -} - -// tendermintDB provides database access for Tendermint consensus state. -// We use a Batch to accumulate writes before committing them to the DB. -// This reduces expensive disk I/O. Reads check the batch first. -// WARNING: If the process crashes before Commit(), buffered messages are lost, -// risking consensus issues like equivocation (missing votes, double-signing). -type tendermintDB[V types.Hashable[H], H types.Hash, A types.Addr] struct { - db db.KeyValueStore - batch db.Batch - walCount map[types.Height]walMsgCount -} - -// NewTendermintDB creates a new TMDB instance implementing the TMDBInterface. -func NewTendermintDB[V types.Hashable[H], H types.Hash, A types.Addr](db db.KeyValueStore) TendermintDB[V, H, A] { - return &tendermintDB[V, H, A]{ - db: db, - batch: db.NewBatch(), - walCount: make(map[types.Height]walMsgCount), - } -} - -// Flush implements TMDBInterface. -func (s *tendermintDB[V, H, A]) Flush() error { - if err := s.batch.Write(); err != nil { - return err - } - // A batch must not be used after it has been committed. Reusing a batch after commit will result in a panic. - s.batch = s.db.NewBatch() - return nil -} - -// DeleteWALEntries iterates through the expected message keys based on the stored count. -// Note: This operates on the batch. Changes are only persisted after Flush() is called. -func (s *tendermintDB[V, H, A]) DeleteWALEntries(height types.Height) error { - startKey := WALEntryBucket.Key(encodeHeight(height)) - endKey := dbutils.UpperBound(startKey) - - if err := s.batch.DeleteRange(startKey, endKey); err != nil { - return fmt.Errorf("DeleteWALEntries: failed to add delete range [%x, %x) to batch: %w", startKey, endKey, err) - } - - delete(s.walCount, height) - return nil -} - -// SetWALEntry implements TMDBInterface. 
-func (s *tendermintDB[V, H, A]) SetWALEntry(entry wal.Entry[V, H, A]) error { - marshaledEntry, err := encoder.Marshal(entry) - if err != nil { - return fmt.Errorf("SetWALEntry: marshal entry failed: %w", err) - } - - msgKey := s.nextKey(entry.GetHeight()) - if err := s.batch.Put(msgKey, marshaledEntry); err != nil { - return fmt.Errorf("writeWALEntryToBatch: failed to set MsgsAtHeight: %w", err) - } - - return nil -} - -// LoadAllEntries implements TMDBInterface. -func (s *tendermintDB[V, H, A]) LoadAllEntries() iter.Seq2[wal.Entry[V, H, A], error] { - return func(yield func(wal.Entry[V, H, A], error) bool) { - err := s.db.View(func(snap db.Snapshot) error { - defer snap.Close() - - iter, err := snap.NewIterator(WALEntryBucket.Key(), true) - if err != nil { - return fmt.Errorf("failed to create iter: %w", err) - } - defer iter.Close() - - for iter.First(); iter.Valid(); iter.Next() { - v, err := iter.Value() - if err != nil { - return fmt.Errorf("failed to get iter value: %w", err) - } - - var walEntry wal.Entry[V, H, A] - if err := encoder.Unmarshal(v, &walEntry); err != nil { - return fmt.Errorf("failed to unmarshal walEntry: %w", err) - } - - expectedKey := s.nextKey(walEntry.GetHeight()) - if !bytes.Equal(iter.Key(), expectedKey) { - return fmt.Errorf("unexpected key %x, expected %x", iter.Key(), expectedKey) - } - - if !yield(walEntry, nil) { - return nil - } - } - return nil - }) - if err != nil { - yield(nil, err) - } - } -} - -func (s *tendermintDB[V, H, A]) nextKey(height types.Height) []byte { - nextIndex := s.walCount[height] - s.walCount[height] = nextIndex + 1 - return WALEntryBucket.Key(encodeHeight(height), encodeNumMsgsAtHeight(nextIndex)) -} - -func encodeHeight(height types.Height) []byte { - heightBytes := make([]byte, NumBytesForHeight) - binary.BigEndian.PutUint32(heightBytes, uint32(height)) - return heightBytes -} - -func encodeNumMsgsAtHeight(numMsgsAtHeight walMsgCount) []byte { - numMsgsAtHeightBytes := make([]byte, NumBytesForHeight) - 
binary.BigEndian.PutUint32(numMsgsAtHeightBytes, uint32(numMsgsAtHeight)) - return numMsgsAtHeightBytes -} diff --git a/consensus/db/db_test.go b/consensus/db/db_test.go deleted file mode 100644 index 89a35c8e3b..0000000000 --- a/consensus/db/db_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package db - -import ( - "math/rand/v2" - "slices" - "testing" - - "github.com/NethermindEth/juno/consensus/starknet" - "github.com/NethermindEth/juno/consensus/types" - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/pebblev2" - "github.com/stretchr/testify/require" -) - -type testTendermintDB = TendermintDB[starknet.Value, starknet.Hash, starknet.Address] - -func newTestTMDB(t *testing.T) (testTendermintDB, db.KeyValueStore, string) { - t.Helper() - dbPath := t.TempDir() - testDB, err := pebblev2.New(dbPath) - require.NoError(t, err) - - tmState := NewTendermintDB[starknet.Value, starknet.Hash, starknet.Address](testDB) - require.NotNil(t, tmState) - - return tmState, testDB, dbPath -} - -func reopenTestTMDB(t *testing.T, oldDB db.KeyValueStore, dbPath string) (testTendermintDB, db.KeyValueStore) { - t.Helper() - require.NoError(t, oldDB.Close()) - - newDB, err := pebblev2.New(dbPath) - require.NoError(t, err) - - tmState := NewTendermintDB[starknet.Value, starknet.Hash, starknet.Address](newDB) - return tmState, newDB -} - -func buildExpectedEntries(testHeight types.Height) []starknet.WALEntry { - testRound := types.Round(1) - testStep := types.StepPrevote - - sender1 := felt.FromUint64[starknet.Address](1) - sender2 := felt.FromUint64[starknet.Address](2) - sender3 := felt.FromUint64[starknet.Address](3) - val1 := felt.FromUint64[starknet.Value](10) - valHash1 := val1.Hash() - - proposal := starknet.WALProposal{ - MessageHeader: starknet.MessageHeader{Height: testHeight, Round: testRound, Sender: sender1}, - ValidRound: testRound, - Value: &val1, - } - prevote := starknet.WALPrevote{ - MessageHeader: 
starknet.MessageHeader{Height: testHeight, Round: testRound, Sender: sender2}, - ID: &valHash1, - } - precommit := starknet.WALPrecommit{ - MessageHeader: starknet.MessageHeader{Height: testHeight, Round: testRound, Sender: sender3}, - ID: &valHash1, - } - timeoutMsg := starknet.WALTimeout{Height: testHeight, Round: testRound, Step: testStep} - - return []starknet.WALEntry{ - &proposal, - &prevote, - &precommit, - &timeoutMsg, - } -} - -// This utility function mixes the entries from the different batches into a single slice. -// The entries from the same batch are kept in order. -// The entries from the different batches are interleaved in a random order. -func mixEntries(entryList [][]starknet.WALEntry) []starknet.WALEntry { - totalCount := 0 - for _, entries := range entryList { - totalCount += len(entries) - } - - mixed := make([]starknet.WALEntry, totalCount) - for i := range mixed { - selected := rand.IntN(len(entryList)) - mixed[i] = entryList[selected][0] - entryList[selected] = entryList[selected][1:] - if len(entryList[selected]) == 0 { - entryList[selected] = entryList[len(entryList)-1] - entryList = entryList[:len(entryList)-1] - } - } - return mixed -} - -func testWrite(t *testing.T, tmState testTendermintDB, entryList ...[]starknet.WALEntry) { - t.Helper() - for _, entry := range mixEntries(entryList) { - require.NoError(t, tmState.SetWALEntry(entry)) - } - require.NoError(t, tmState.Flush()) -} - -func testRead(t *testing.T, tmState testTendermintDB, entryList ...[]starknet.WALEntry) { - t.Helper() - index := 0 - entries := slices.Concat(entryList...) 
- for entry, err := range tmState.LoadAllEntries() { - require.NoError(t, err) - require.Equal(t, entries[index], entry) - index++ - } -} - -func TestWALLifecycle(t *testing.T) { - firstHeight := types.Height(1000) - secondHeight := firstHeight + 1 - thirdHeight := secondHeight + 1 - firstHeightFirstBatch := buildExpectedEntries(firstHeight) - firstHeightSecondBatch := buildExpectedEntries(firstHeight) - secondHeightFirstBatch := buildExpectedEntries(secondHeight) - secondHeightSecondBatch := buildExpectedEntries(secondHeight) - thirdHeightFirstBatch := buildExpectedEntries(thirdHeight) - thirdHeightSecondBatch := buildExpectedEntries(thirdHeight) - tmState, db, dbPath := newTestTMDB(t) - - t.Run("Write entries from height 1 batch 1 and height 2 batch 1", func(t *testing.T) { - testWrite(t, tmState, firstHeightFirstBatch, secondHeightFirstBatch) - }) - - t.Run("Reload the db and get entries from the first 2 heights", func(t *testing.T) { - tmState, db = reopenTestTMDB(t, db, dbPath) - testRead(t, tmState, firstHeightFirstBatch, secondHeightFirstBatch) - }) - - t.Run("Write entries from height 1 batch 2 and height 3 batch 1", func(t *testing.T) { - testWrite(t, tmState, firstHeightSecondBatch, thirdHeightFirstBatch) - }) - - t.Run("Reload the db and get entries from 3 heights", func(t *testing.T) { - tmState, db = reopenTestTMDB(t, db, dbPath) - testRead( - t, - tmState, - firstHeightFirstBatch, - firstHeightSecondBatch, - secondHeightFirstBatch, - thirdHeightFirstBatch, - ) - }) - - t.Run("Delete entries", func(t *testing.T) { - require.NoError(t, tmState.DeleteWALEntries(firstHeight)) - }) - - t.Run("Write entries from height 2 batch 2 and height 3 batch 2", func(t *testing.T) { - testWrite(t, tmState, secondHeightSecondBatch, thirdHeightSecondBatch) - }) - - t.Run("Reload the db and get entries from the last 2 heights", func(t *testing.T) { - tmState, db = reopenTestTMDB(t, db, dbPath) - testRead( - t, - tmState, - secondHeightFirstBatch, - 
secondHeightSecondBatch, - thirdHeightFirstBatch, - thirdHeightSecondBatch, - ) - }) -} diff --git a/consensus/driver/commit_listener.go b/consensus/driver/commit_listener.go index 868d2c84ab..8538e625ff 100644 --- a/consensus/driver/commit_listener.go +++ b/consensus/driver/commit_listener.go @@ -17,7 +17,7 @@ type CommitHook[V types.Hashable[H], H types.Hash] interface { // CommitListener is a component that is used to notify different components that a new committed block is available. type CommitListener[V types.Hashable[H], H types.Hash] interface { - CommitHook[V, H] + OnCommit(context.Context, types.Height, V) bool // Listen returns a channel that will receive committed blocks. // This is supposed to be used by the component that writes the committed blocks to the database. Listen() <-chan sync.CommittedBlock @@ -44,31 +44,39 @@ func NewCommitListener[V types.Hashable[H], H types.Hash]( } } -func (b *commitListener[V, H]) OnCommit(ctx context.Context, height types.Height, value V) { +func (b *commitListener[V, H]) OnCommit(ctx context.Context, height types.Height, value V) bool { buildResult := b.proposalStore.Get(value.Hash()) if buildResult == nil { // todo(rdr): we can avoid using the ANY by writing some representation into Hash b.logger.Error("failed to get build result", zap.Any("hash", value.Hash())) - return + return false } committedBlock := sync.CommittedBlock{ Block: buildResult.PreConfirmed.Block, StateUpdate: buildResult.PreConfirmed.StateUpdate, NewClasses: buildResult.PreConfirmed.NewClasses, - Persisted: make(chan struct{}), + Persisted: make(chan error, 1), } select { case <-ctx.Done(): - return + return false case b.commits <- committedBlock: } select { case <-ctx.Done(): - return - case <-committedBlock.Persisted: + return false + case err := <-committedBlock.Persisted: + if err != nil { + b.logger.Warn( + "failed to persist committed block", + zap.Uint("height", uint(height)), + zap.Error(err), + ) + return false + } } wg := 
gosync.WaitGroup{} @@ -78,8 +86,8 @@ func (b *commitListener[V, H]) OnCommit(ctx context.Context, height types.Height }) } wg.Wait() - b.proposalStore.FinalizeHeight(height) + return true } func (b *commitListener[V, H]) Listen() <-chan sync.CommittedBlock { diff --git a/consensus/driver/commit_listener_test.go b/consensus/driver/commit_listener_test.go index 0d29b1c8b7..0cd6c8b467 100644 --- a/consensus/driver/commit_listener_test.go +++ b/consensus/driver/commit_listener_test.go @@ -2,6 +2,7 @@ package driver_test import ( "context" + "errors" "testing" "time" @@ -20,7 +21,7 @@ import ( func TestOnCommit_DropsProposalsAtCommittedHeight(t *testing.T) { proposalStore := &proposal.ProposalStore[starknet.Hash]{} - listener := driver.NewCommitListener[starknet.Value, starknet.Hash]( + listener := driver.NewCommitListener[starknet.Value]( log.NewNopZapLogger(), proposalStore, ) @@ -45,7 +46,7 @@ func TestOnCommit_DropsProposalsAtCommittedHeight(t *testing.T) { }() block := <-listener.Listen() - close(block.Persisted) + block.Persisted <- nil <-done require.Nil(t, proposalStore.Get(committedValue.Hash())) @@ -53,6 +54,44 @@ func TestOnCommit_DropsProposalsAtCommittedHeight(t *testing.T) { require.NotNil(t, proposalStore.Get(nextHeightValue.Hash())) } +type testCommitHook struct { + called chan struct{} +} + +func (h *testCommitHook) OnCommit(context.Context, types.Height, starknet.Value) { + h.called <- struct{}{} +} + +func TestCommitListenerOnCommitReturnsFalseWhenPersistenceFails(t *testing.T) { + committedHeight := types.Height(1) + value := felt.FromUint64[starknet.Value](1) + store := &proposal.ProposalStore[starknet.Hash]{} + store.Store(value.Hash(), buildResultAtHeight(uint64(committedHeight))) + + hook := &testCommitHook{called: make(chan struct{}, 1)} + listener := driver.NewCommitListener( + log.NewNopZapLogger(), + store, + hook, + ) + + resultCh := make(chan bool, 1) + go func() { + resultCh <- listener.OnCommit(context.Background(), committedHeight, 
value) + }() + + committedBlock := <-listener.Listen() + committedBlock.Persisted <- errors.New("store failed") + + require.False(t, <-resultCh) + + select { + case <-hook.called: + t.Fatal("commit hook should not run when persistence fails") + default: + } +} + func buildResultAtHeight(blockNumber uint64) *builder.BuildResult { return &builder.BuildResult{ PreConfirmed: &pending.PreConfirmed{ diff --git a/consensus/driver/driver.go b/consensus/driver/driver.go index d5f65a4260..57ed561c82 100644 --- a/consensus/driver/driver.go +++ b/consensus/driver/driver.go @@ -2,16 +2,17 @@ package driver import ( "context" + "errors" "fmt" gosync "sync" "time" - "github.com/NethermindEth/juno/consensus/db" "github.com/NethermindEth/juno/consensus/p2p" consensusSync "github.com/NethermindEth/juno/consensus/sync" "github.com/NethermindEth/juno/consensus/tendermint" "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/consensus/types/actions" + "github.com/NethermindEth/juno/consensus/walstore" "github.com/NethermindEth/juno/p2p/sync" "github.com/NethermindEth/juno/utils/log" "go.uber.org/zap" @@ -21,7 +22,7 @@ type TimeoutFn func(step types.Step, round types.Round) time.Duration type Driver[V types.Hashable[H], H types.Hash, A types.Addr] struct { logger log.Logger - db db.TendermintDB[V, H, A] + db walstore.TendermintWALStore[V, H, A] stateMachine tendermint.StateMachine[V, H, A] commitListener CommitListener[V, H] broadcasters p2p.Broadcasters[V, H, A] @@ -39,7 +40,7 @@ type Driver[V types.Hashable[H], H types.Hash, A types.Addr] struct { func New[V types.Hashable[H], H types.Hash, A types.Addr]( logger log.Logger, - db db.TendermintDB[V, H, A], + db walstore.TendermintWALStore[V, H, A], stateMachine tendermint.StateMachine[V, H, A], commitListener CommitListener[V, H], broadcasters p2p.Broadcasters[V, H, A], @@ -70,12 +71,15 @@ func New[V types.Hashable[H], H types.Hash, A types.Addr]( // these messages and returns a set of actions to be executed by 
the Driver. // The Driver executes these actions (namely broadcasting messages // and triggering scheduled timeouts). -func (d *Driver[V, H, A]) Run(ctx context.Context) error { +func (d *Driver[V, H, A]) Run(ctx context.Context) (err error) { defer func() { for _, tm := range d.scheduledTms { tm.Stop() } }() + defer func() { + err = errors.Join(err, d.db.Close()) + }() if err := d.replay(ctx); err != nil { return err @@ -87,7 +91,10 @@ func (d *Driver[V, H, A]) Run(ctx context.Context) error { func (d *Driver[V, H, A]) replay(ctx context.Context) error { for walEntry, err := range d.db.LoadAllEntries() { if err != nil { - return fmt.Errorf("failed to load WAL entries: %w", err) + return fmt.Errorf("loading WAL entries: %w", err) + } + if walEntry.GetHeight() < d.stateMachine.Height() { + continue } if _, err := d.execute(ctx, true, d.stateMachine.ProcessWAL(walEntry)); err != nil { @@ -174,7 +181,7 @@ func (d *Driver[V, H, A]) execute( for _, action := range resultActions { if !isReplaying && action.RequiresWALFlush() { if err := d.db.Flush(); err != nil { - return false, fmt.Errorf("failed to flush WAL: %w", err) + return false, fmt.Errorf("flushing WAL: %w", err) } } @@ -182,7 +189,7 @@ func (d *Driver[V, H, A]) execute( case *actions.WriteWAL[V, H, A]: if !isReplaying { if err := d.db.SetWALEntry(action.Entry); err != nil { - return false, fmt.Errorf("failed to write WAL: %w", err) + return false, fmt.Errorf("writing WAL: %w", err) } } @@ -226,13 +233,18 @@ func (d *Driver[V, H, A]) commit(ctx context.Context, commit *actions.Commit[V, zap.Uint("height", uint(commit.Height)), zap.Int("round", int(commit.Round)), ) - d.commitListener.OnCommit(ctx, commit.Height, *commit.Value) + if !d.commitListener.OnCommit(ctx, commit.Height, *commit.Value) { + if err := ctx.Err(); err != nil { + return err + } + return errors.New("commit listener failed") + } if err := d.db.DeleteWALEntries(commit.Height); err != nil { - return fmt.Errorf("failed to delete WAL messages during 
commit: %w", err) + return fmt.Errorf("deleting WAL messages during commit: %w", err) } - return nil + return d.db.Flush() } func (d *Driver[V, H, A]) triggerSync(ctx context.Context, triggerSync actions.TriggerSync) { diff --git a/consensus/driver/driver_test.go b/consensus/driver/driver_test.go index 077dcd7fad..0d324c0144 100644 --- a/consensus/driver/driver_test.go +++ b/consensus/driver/driver_test.go @@ -6,13 +6,14 @@ import ( "testing" "time" - "github.com/NethermindEth/juno/consensus/db" "github.com/NethermindEth/juno/consensus/driver" "github.com/NethermindEth/juno/consensus/mocks" "github.com/NethermindEth/juno/consensus/p2p" "github.com/NethermindEth/juno/consensus/starknet" "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/consensus/types/actions" + consensuswal "github.com/NethermindEth/juno/consensus/types/wal" + "github.com/NethermindEth/juno/consensus/walstore" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db/pebblev2" "github.com/NethermindEth/juno/utils/log" @@ -23,10 +24,10 @@ import ( ) type ( - listeners = p2p.Listeners[starknet.Value, starknet.Hash, starknet.Address] - broadcasters = p2p.Broadcasters[starknet.Value, starknet.Hash, starknet.Address] - tendermintDB = db.TendermintDB[starknet.Value, starknet.Hash, starknet.Address] - commitListener = driver.CommitListener[starknet.Value, starknet.Hash] + listeners = p2p.Listeners[starknet.Value, starknet.Hash, starknet.Address] + broadcasters = p2p.Broadcasters[starknet.Value, starknet.Hash, starknet.Address] + tendermintWALStore = walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address] + commitListener = driver.CommitListener[starknet.Value, starknet.Hash] ) const ( @@ -95,22 +96,22 @@ func generateAndRegisterRandomActions( case 0: proposal := getRandProposal(random) expectedBroadcast.proposals = append(expectedBroadcast.proposals, &proposal) - actions[i] = new(starknet.BroadcastProposal(proposal)) + actions[i] = 
(*starknet.BroadcastProposal)(&proposal) case 1: prevote := getRandPrevote(random) expectedBroadcast.prevotes = append(expectedBroadcast.prevotes, &prevote) - actions[i] = new(starknet.BroadcastPrevote(prevote)) + actions[i] = (*starknet.BroadcastPrevote)(&prevote) case 2: precommit := getRandPrecommit(random) expectedBroadcast.precommits = append(expectedBroadcast.precommits, &precommit) - actions[i] = new(starknet.BroadcastPrecommit(precommit)) + actions[i] = (*starknet.BroadcastPrecommit)(&precommit) } } return actions } func toAction(timeout types.Timeout) starknet.Action { - return new(actions.ScheduleTimeout(timeout)) + return (*actions.ScheduleTimeout)(&timeout) } func increaseBroadcasterWaitGroup[M any]( @@ -134,13 +135,24 @@ func waitAndAssertBroadcaster[M any]( assert.ElementsMatch(t, expectedBroadcast, mockBroadcaster.broadcastedMessages) } -func newTendermintDB(t *testing.T) tendermintDB { +func newTendermintWALStore(t *testing.T) tendermintWALStore { t.Helper() - dbPath := t.TempDir() - pebbleDB, err := pebblev2.New(dbPath) + pebbleDB, err := pebblev2.New(t.TempDir()) require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, pebbleDB.Close()) + }) - return db.NewTendermintDB[starknet.Value, starknet.Hash, starknet.Address](pebbleDB) + walStore, err := walstore.NewTendermintWALStore[ + starknet.Value, + starknet.Hash, + starknet.Address, + ](pebbleDB) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, walStore.Close()) + }) + return walStore } func TestDriver(t *testing.T) { @@ -173,7 +185,7 @@ func TestDriver(t *testing.T) { driver := driver.New( log.NewNopZapLogger(), - newTendermintDB(t), + newTendermintWALStore(t), stateMachine, newMockCommitListener(t, &commitAction), broadcasters, @@ -241,3 +253,66 @@ func TestDriver(t *testing.T) { cancel() } + +func TestDriverReturnsErrorWhenCommitListenerFails(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stateMachine := 
mocks.NewMockStateMachine[starknet.Value, starknet.Hash, starknet.Address](ctrl) + commitAction := starknet.Commit(getRandProposal(rand.New(rand.NewSource(seed + 1)))) + + // Emit the commit straight from ProcessStart to exercise the commit path + // without simulating a full proposal/prevote/precommit round. + stateMachine.EXPECT().ProcessStart(types.Round(0)).Return([]starknet.Action{&commitAction}) + + commitListener := &mockCommitListener{ + t: t, + expectedCommit: &commitAction, + persisted: false, + } + + driver := driver.New( + log.NewNopZapLogger(), + newTendermintWALStore(t), + stateMachine, + commitListener, + broadcasters{}, + listeners{}, + nil, + nil, + mockTimeoutFn, + ) + + require.Error(t, driver.Run(t.Context())) +} + +func TestDriverReplaySkipsEntriesBelowCurrentHeight(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + walStore := newTendermintWALStore(t) + currentHeight := types.Height(2) + staleWALEntry := consensuswal.Start(currentHeight - 1) + require.NoError(t, walStore.SetWALEntry(&staleWALEntry)) + require.NoError(t, walStore.Flush()) + + stateMachine := mocks.NewMockStateMachine[starknet.Value, starknet.Hash, starknet.Address](ctrl) + stateMachine.EXPECT().Height().Return(currentHeight) + stateMachine.EXPECT().ProcessWAL(gomock.Any()).Times(0) + + driver := driver.New( + log.NewNopZapLogger(), + walStore, + stateMachine, + &mockCommitListener{t: t}, + broadcasters{}, + listeners{}, + nil, + nil, + mockTimeoutFn, + ) + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + require.NoError(t, driver.Run(ctx)) +} diff --git a/consensus/driver/mock_blockchain_test.go b/consensus/driver/mock_blockchain_test.go index dc8cb39558..585828d470 100644 --- a/consensus/driver/mock_blockchain_test.go +++ b/consensus/driver/mock_blockchain_test.go @@ -13,17 +13,24 @@ import ( type mockCommitListener struct { t *testing.T expectedCommit *starknet.Commit + persisted bool } -func (m *mockCommitListener) OnCommit(ctx 
context.Context, height types.Height, value starknet.Value) { +func (m *mockCommitListener) OnCommit( + ctx context.Context, + height types.Height, + value starknet.Value, +) bool { require.Equal(m.t, m.expectedCommit.Value, &value) require.Equal(m.t, m.expectedCommit.Height, height) + return m.persisted } func newMockCommitListener(t *testing.T, expectedCommit *starknet.Commit) commitListener { return &mockCommitListener{ t: t, expectedCommit: expectedCommit, + persisted: true, } } diff --git a/consensus/export_test.go b/consensus/export_test.go new file mode 100644 index 0000000000..4511e53ac8 --- /dev/null +++ b/consensus/export_test.go @@ -0,0 +1,59 @@ +package consensus + +import ( + "github.com/NethermindEth/juno/blockchain" + "github.com/NethermindEth/juno/consensus/driver" + "github.com/NethermindEth/juno/consensus/p2p" + "github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/votecounter" + "github.com/NethermindEth/juno/consensus/walstore" + kvdb "github.com/NethermindEth/juno/db" + p2psync "github.com/NethermindEth/juno/p2p/sync" + "github.com/NethermindEth/juno/starknet/compiler" + "github.com/NethermindEth/juno/utils/log" + "github.com/NethermindEth/juno/vm" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" +) + +type InitOptionsForTest struct { + WrapWALStore func( + walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address], + ) walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address] + WrapBroadcasters func( + p2p.Broadcasters[starknet.Value, starknet.Hash, starknet.Address], + ) p2p.Broadcasters[starknet.Value, starknet.Hash, starknet.Address] +} + +func InitWithOptionsForTest( + host host.Host, + logger *log.ZapLogger, + database kvdb.KeyValueStore, + blockchain *blockchain.Blockchain, + vm vm.VM, + blockFetcher *p2psync.BlockFetcher, + nodeAddress *starknet.Address, + validators votecounter.Validators[starknet.Address], + timeoutFn 
driver.TimeoutFn, + bootstrapPeersFn func() []peer.AddrInfo, + compiler compiler.Compiler, + options InitOptionsForTest, +) (ConsensusServices, error) { + return initWithOptions( + host, + logger, + database, + blockchain, + vm, + blockFetcher, + nodeAddress, + validators, + timeoutFn, + bootstrapPeersFn, + compiler, + initOptions{ + wrapWALStore: options.WrapWALStore, + wrapBroadcasters: options.WrapBroadcasters, + }, + ) +} diff --git a/consensus/restart_test.go b/consensus/restart_test.go new file mode 100644 index 0000000000..e7574f8d30 --- /dev/null +++ b/consensus/restart_test.go @@ -0,0 +1,530 @@ +package consensus_test + +import ( + "context" + "errors" + "path/filepath" + "testing" + "time" + + "github.com/NethermindEth/juno/blockchain" + "github.com/NethermindEth/juno/blockchain/networks" + "github.com/NethermindEth/juno/consensus" + "github.com/NethermindEth/juno/consensus/driver" + consensusP2P "github.com/NethermindEth/juno/consensus/p2p" + "github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/types" + consensuswal "github.com/NethermindEth/juno/consensus/types/wal" + "github.com/NethermindEth/juno/consensus/walstore" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" + kvdb "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/db/pebblev2" + "github.com/NethermindEth/juno/p2p/pubsub/testutils" + p2psync "github.com/NethermindEth/juno/p2p/sync" + "github.com/NethermindEth/juno/starknet/compiler" + "github.com/NethermindEth/juno/utils/log" + "github.com/NethermindEth/juno/vm" + "github.com/sourcegraph/conc" + "github.com/stretchr/testify/require" +) + +type runningConsensusNode struct { + stop func() + wait func() +} + +type restartNodeStores struct { + consensusStore kvdb.KeyValueStore + blockchainStore kvdb.KeyValueStore + blockchain 
*blockchain.Blockchain +} + +type consensusNodeOptions struct { + requiredWALEntriesWritten chan<- struct{} + ackCommits bool + sentPrecommits chan<- *starknet.Precommit +} + +type observingWALStore struct { + walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address] + requiredWALEntriesWritten chan<- struct{} +} + +func (s *observingWALStore) Flush() error { + if err := s.TendermintWALStore.Flush(); err != nil { + return err + } + if !walContainsRequiredReplayEntries(s.TendermintWALStore) { + return nil + } + select { + case s.requiredWALEntriesWritten <- struct{}{}: + default: + } + return nil +} + +type observingPrecommitBroadcaster struct { + consensusP2P.Broadcaster[*starknet.Precommit] + sentPrecommits chan<- *starknet.Precommit +} + +func (b observingPrecommitBroadcaster) Broadcast( + ctx context.Context, + precommit *starknet.Precommit, +) { + b.Broadcaster.Broadcast(ctx, precommit) + select { + case b.sentPrecommits <- precommit: + default: + } +} + +const ( + restartTestNodeCount = 4 + restartTestRequiredPrevotes = 2 * restartTestNodeCount / 3 + restartTestValidatorHighSeed = uint64(0) + restartTestValidatorLowSeed = uint64(0) + restartTestHeight = types.Height(1) + restartTestRound = types.Round(0) +) + +func TestConsensusRestartReplaysPersistentWALAndPrecommits(t *testing.T) { + logger := log.NewNopZapLogger() + genesisDiff, genesisClasses := loadGenesis(t, logger) + + p2pNodes := testutils.BuildNetworks(t, testutils.LineNetworkConfig(restartTestNodeCount)) + // Use a non-proposer so replay must restore proposal/vote state received from peers. 
+ restartNodeIndex := chooseNonProposerNodeIndex(restartTestNodeCount) + restartNodeConsensusAddress := consensus.InitMockServices( + restartTestValidatorHighSeed, + restartTestValidatorLowSeed, + restartNodeIndex, + restartTestNodeCount, + ).NodeAddress + + consensusDBPath := filepath.Join(t.TempDir(), "consensus") + blockchainDBPath := filepath.Join(t.TempDir(), "blockchain") + restartStores := openRestartNodeStores( + t, + consensusDBPath, + blockchainDBPath, + genesisDiff, + genesisClasses, + ) + + requiredWALEntriesWritten := make(chan struct{}, 1) + restartNode, peerNodes := startNetworkBeforeRestart( + t, + t.Context(), + p2pNodes, + restartNodeIndex, + logger, + restartStores, + genesisDiff, + genesisClasses, + requiredWALEntriesWritten, + ) + + // Stop once the restart node has the WAL entries required for replay. + waitForRequiredWALEntriesAndStop(t, requiredWALEntriesWritten, restartNode) + // Stop peers too, so restart cannot use fresh network messages. + stopAndWait(peerNodes) + + // Confirm no block was committed, so recovery must come from WAL replay. + committedHeight, err := restartStores.blockchain.Height() + require.NoError(t, err) + require.Equal(t, uint64(0), committedHeight) + + restartStores.close(t) + + // Reopen stores to simulate a new process, then restart only the selected node. + restartStores = openRestartNodeStores( + t, + consensusDBPath, + blockchainDBPath, + genesisDiff, + genesisClasses, + ) + restartedNodePrecommits := make(chan *starknet.Precommit, 1) + restartedNode := startConsensusNode( + t, + t.Context(), + restartNodeIndex, + &p2pNodes[restartNodeIndex], + logger.Named("restart-node"), + restartStores.consensusStore, + restartStores.blockchain, + consensusNodeOptions{sentPrecommits: restartedNodePrecommits}, + ) + t.Cleanup(func() { + stopAndWait([]runningConsensusNode{restartedNode}) + restartStores.close(t) + }) + + // With peers stopped and no committed block, this precommit must come from WAL replay. 
+ waitForPrecommitFromNode(t, restartedNodePrecommits, restartNodeConsensusAddress) +} + +func chooseNonProposerNodeIndex(nodeCount int) int { + const validatorSetNodeIndex = 0 + validators := consensus.InitMockServices( + restartTestValidatorHighSeed, + restartTestValidatorLowSeed, + validatorSetNodeIndex, + nodeCount, + ).Validators + for i := range nodeCount { + nodeAddress := consensus.InitMockServices( + restartTestValidatorHighSeed, + restartTestValidatorLowSeed, + i, + nodeCount, + ).NodeAddress + if validators.Proposer(restartTestHeight, restartTestRound) != nodeAddress { + return i + } + } + panic("expected to find non-proposer node") +} + +func openRestartNodeStores( + t *testing.T, + consensusDBPath string, + blockchainDBPath string, + genesisDiff core.StateDiff, + genesisClasses map[felt.Felt]core.ClassDefinition, +) restartNodeStores { + t.Helper() + + consensusStore, err := pebblev2.New(consensusDBPath) + require.NoError(t, err) + + blockchainStore, err := pebblev2.New(blockchainDBPath) + require.NoError(t, err) + + chain := blockchain.New( + blockchainStore, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) + + _, err = chain.Height() + switch { + case errors.Is(err, kvdb.ErrKeyNotFound): + require.NoError(t, chain.StoreGenesis(&genesisDiff, genesisClasses)) + case err == nil: + default: + require.NoError(t, err) + } + + return restartNodeStores{ + consensusStore: consensusStore, + blockchainStore: blockchainStore, + blockchain: chain, + } +} + +func (s restartNodeStores) close(t *testing.T) { + t.Helper() + + require.NoError(t, s.consensusStore.Close()) + require.NoError(t, s.blockchainStore.Close()) +} + +func startNetworkBeforeRestart( + t *testing.T, + parent context.Context, + p2pNodes []testutils.Node, + restartNodeIndex int, + logger *log.ZapLogger, + restartStores restartNodeStores, + genesisDiff core.StateDiff, + genesisClasses map[felt.Felt]core.ClassDefinition, + requiredWALEntriesWritten chan<- struct{}, +) 
(runningConsensusNode, []runningConsensusNode) { + t.Helper() + + var restartNode runningConsensusNode + peerNodes := make([]runningConsensusNode, 0, len(p2pNodes)-1) + for i := range p2pNodes { + if i == restartNodeIndex { + restartNode = startConsensusNode( + t, + parent, + i, + &p2pNodes[i], + logger.Named("before-restart"), + restartStores.consensusStore, + restartStores.blockchain, + consensusNodeOptions{requiredWALEntriesWritten: requiredWALEntriesWritten}, + ) + continue + } + + peerNodes = append(peerNodes, startPeerNode( + t, + parent, + i, + &p2pNodes[i], + logger.Named("peer"), + genesisDiff, + genesisClasses, + )) + } + + return restartNode, peerNodes +} + +func startPeerNode( + t *testing.T, + parent context.Context, + nodeIndex int, + p2pNode *testutils.Node, + logger *log.ZapLogger, + genesisDiff core.StateDiff, + genesisClasses map[felt.Felt]core.ClassDefinition, +) runningConsensusNode { + t.Helper() + + return startConsensusNode( + t, + parent, + nodeIndex, + p2pNode, + logger, + memoryDB(t), + getBlockchain(t, genesisDiff, genesisClasses), + // Peers must drain and ack commits, otherwise their drivers block at commit + // and stop gossiping the proposal/prevotes the restart node needs in its WAL. 
+ consensusNodeOptions{ackCommits: true}, + ) +} + +func startConsensusNode( + t *testing.T, + parent context.Context, + nodeIndex int, + p2pNode *testutils.Node, + logger *log.ZapLogger, + consensusStore kvdb.KeyValueStore, + chain *blockchain.Blockchain, + options consensusNodeOptions, +) runningConsensusNode { + t.Helper() + + mockServices := consensus.InitMockServices( + restartTestValidatorHighSeed, + restartTestValidatorLowSeed, + nodeIndex, + restartTestNodeCount, + ) + + network := &networks.Mainnet + vm := vm.New(&vm.ChainInfo{ + ChainID: network.L2ChainID, + FeeTokenAddresses: networks.DefaultFeeTokenAddresses, + }, false, logger) + blockFetcher := p2psync.NewBlockFetcher( + chain, + compiler.NewUnsafe(), + p2pNode.Host, + network, + logger, + ) + + initOpts := consensus.InitOptionsForTest{} + if options.sentPrecommits != nil { + initOpts.WrapBroadcasters = func( + broadcasters consensusP2P.Broadcasters[starknet.Value, starknet.Hash, starknet.Address], + ) consensusP2P.Broadcasters[starknet.Value, starknet.Hash, starknet.Address] { + broadcasters.PrecommitBroadcaster = observingPrecommitBroadcaster{ + Broadcaster: broadcasters.PrecommitBroadcaster, + sentPrecommits: options.sentPrecommits, + } + return broadcasters + } + } + if options.requiredWALEntriesWritten != nil { + initOpts.WrapWALStore = func( + walStore walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address], + ) walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address] { + return &observingWALStore{ + TendermintWALStore: walStore, + requiredWALEntriesWritten: options.requiredWALEntriesWritten, + } + } + } + + services, err := consensus.InitWithOptionsForTest( + p2pNode.Host, + logger, + consensusStore, + chain, + vm, + &blockFetcher, + &mockServices.NodeAddress, + mockServices.Validators, + mockServices.TimeoutFn, + p2pNode.GetBootstrapPeers, + compiler.NewUnsafe(), + initOpts, + ) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(parent) + 
errs := make(chan error, 4) + wg := conc.NewWaitGroup() + + run := func(name string, fn func(context.Context) error) { + wg.Go(func() { + if err := fn(ctx); err != nil && !errors.Is(err, context.Canceled) { + errs <- errors.New(name + ": " + err.Error()) + } + }) + } + + run("proposer", services.Proposer.Run) + run("driver", services.Driver.Run) + run("p2p", services.P2P.Run) + + if options.ackCommits { + run("persist", func(ctx context.Context) error { + return persistCommittedBlocks(ctx, chain, services.CommitListener) + }) + } + + return runningConsensusNode{ + stop: cancel, + wait: func() { + wg.Wait() + close(errs) + for err := range errs { + require.NoError(t, err) + } + }, + } +} + +func waitForRequiredWALEntriesAndStop( + t *testing.T, + requiredWALEntriesWritten <-chan struct{}, + restartNode runningConsensusNode, +) { + t.Helper() + + select { + case <-requiredWALEntriesWritten: + case <-time.After(30 * time.Second): + restartNode.stop() + restartNode.wait() + require.FailNow(t, "timed out waiting for required WAL entries") + } + + restartNode.stop() + restartNode.wait() +} + +func stopAndWait(nodes []runningConsensusNode) { + for _, node := range nodes { + node.stop() + } + for _, node := range nodes { + node.wait() + } +} + +func waitForPrecommitFromNode( + t *testing.T, + precommits <-chan *starknet.Precommit, + nodeAddress starknet.Address, +) { + t.Helper() + + timeout := time.After(30 * time.Second) + for { + select { + case precommit := <-precommits: + if precommit.Sender == nodeAddress && precommit.Height == restartTestHeight { + return + } + case <-timeout: + require.FailNow(t, "timed out waiting for restarted node precommit") + } + } +} + +func walContainsRequiredReplayEntries( + walStore walstore.TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address], +) bool { + hasStartForHeight := false + hasProposalForRound := false + prevoteSenders := make(map[starknet.Address]struct{}, restartTestRequiredPrevotes) + + for entry, err := range 
walStore.LoadAllEntries() { + if err != nil { + return false + } + switch entry := entry.(type) { + case *consensuswal.Start: + hasStartForHeight = hasStartForHeight || entry.GetHeight() == restartTestHeight + case *consensuswal.Proposal[starknet.Value, starknet.Hash, starknet.Address]: + hasProposalForRound = hasProposalForRound || isRestartRound(entry.Height, entry.Round) + case *consensuswal.Prevote[starknet.Hash, starknet.Address]: + if isRestartRound(entry.Height, entry.Round) { + prevoteSenders[entry.Sender] = struct{}{} + } + } + } + return hasStartForHeight && + hasProposalForRound && + len(prevoteSenders) >= restartTestRequiredPrevotes +} + +func isRestartRound(height types.Height, round types.Round) bool { + return height == restartTestHeight && round == restartTestRound +} + +func memoryDB(t *testing.T) kvdb.KeyValueStore { + t.Helper() + db := memory.New() + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + return db +} + +func persistCommittedBlocks( + ctx context.Context, + chain *blockchain.Blockchain, + commitListener driver.CommitListener[starknet.Value, starknet.Hash], +) error { + for { + select { + case <-ctx.Done(): + return nil + case committedBlock := <-commitListener.Listen(): + commitments, err := chain.SanityCheckNewHeight( + committedBlock.Block, + committedBlock.StateUpdate, + committedBlock.NewClasses, + ) + if err != nil { + return err + } + if err := chain.Store( + committedBlock.Block, + commitments, + committedBlock.StateUpdate, + committedBlock.NewClasses, + ); err != nil { + return err + } + + committedBlock.Persisted <- nil + } + } +} diff --git a/consensus/walstore/codec.go b/consensus/walstore/codec.go new file mode 100644 index 0000000000..3c36005ab9 --- /dev/null +++ b/consensus/walstore/codec.go @@ -0,0 +1,458 @@ +package walstore + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "reflect" + "slices" + + "github.com/NethermindEth/juno/consensus/types" + 
"github.com/NethermindEth/juno/consensus/types/wal" + "github.com/cockroachdb/pebble/v2" + "github.com/cockroachdb/pebble/v2/batchrepr" +) + +const valueLenBytes = 5 + +// Encode Pebble's batchrepr format directly so replay can parse WAL records +// with ReadHeader and Read. +func encodeBatch[V types.Hashable[H], H types.Hash, A types.Addr]( + records []walRecordEnvelope[V, H, A], + seqNum uint64, + encodedBatch []byte, +) ([]byte, error) { + if len(records) > math.MaxUint32 { + return nil, errors.New("encodeBatch: too many WAL records in single flush") + } + + const ( + kindBytes = 1 + keyLenBytes = 1 + keySizeBytes = 4 + estimatedPayloadBytes = 128 + keyLenOffset = kindBytes + keyOffset = keyLenOffset + keyLenBytes + valueLenOffset = keyOffset + keySizeBytes + recordHeaderBytes = valueLenOffset + valueLenBytes + estimatedRecordBytes = recordHeaderBytes + estimatedPayloadBytes + ) + estimatedBatchBytes := batchrepr.HeaderLen + len(records)*estimatedRecordBytes + if cap(encodedBatch) < estimatedBatchBytes { + encodedBatch = make([]byte, batchrepr.HeaderLen, estimatedBatchBytes) + } else { + encodedBatch = encodedBatch[:batchrepr.HeaderLen] + } + // batchrepr header: seq num (8 bytes) followed by record count (4 bytes). + binary.LittleEndian.PutUint64(encodedBatch[:8], seqNum) + binary.LittleEndian.PutUint32(encodedBatch[8:batchrepr.HeaderLen], uint32(len(records))) + + for i := range records { + record := &records[i] + // batchrepr entry layout: [kind:1][keyLen:1][key:4][valueLen:5][payload...] + // keyLen stores the size of the following 4-byte record index key. + // valueLen is reserved up front and backfilled after encoding the payload. 
+ encodedBatch = slices.Grow(encodedBatch, recordHeaderBytes) + recordStart := len(encodedBatch) + keyLenStart := recordStart + keyLenOffset + keyStart := recordStart + keyOffset + valueLenStart := recordStart + valueLenOffset + payloadStart := recordStart + recordHeaderBytes + encodedBatch = encodedBatch[:payloadStart] + + encodedBatch[recordStart] = byte(pebble.InternalKeyKindSet) + encodedBatch[keyLenStart] = keySizeBytes + binary.BigEndian.PutUint32(encodedBatch[keyStart:valueLenStart], uint32(i)) + + var err error + // Append the custom WAL payload as the batchrepr value bytes. + encodedBatch, err = appendWALRecordPayload(encodedBatch, record) + if err != nil { + return nil, fmt.Errorf("encodeBatch: encode WAL envelope: %w", err) + } + + // batchrepr stores value lengths as uvarints. + valueLen := len(encodedBatch) - payloadStart + putFixedUvarint32(encodedBatch[valueLenStart:payloadStart], uint32(valueLen)) + } + + return encodedBatch, nil +} + +func appendWALRecordPayload[V types.Hashable[H], H types.Hash, A types.Addr]( + payload []byte, + record *walRecordEnvelope[V, H, A], +) ([]byte, error) { + payload = append(payload, byte(record.Kind)) + switch record.Kind { + case walRecordEntry: + payload = append(payload, byte(record.EntryKind)) + switch record.EntryKind { + case walEntryStart: + payload = appendUint64(payload, uint64(record.StartHeight)) + case walEntryProposal: + proposalMessage := (*types.Proposal[V, H, A])(record.ProposalEntry) + payload = appendMessageHeader(payload, proposalMessage.MessageHeader) + payload = appendInt64(payload, int64(proposalMessage.ValidRound)) + if proposalMessage.Value == nil { + payload = append(payload, 0) + } else { + payload = append(payload, 1) + var err error + payload, err = appendValue(payload, proposalMessage.Value) + if err != nil { + return nil, err + } + } + case walEntryPrevote: + payload = appendVotePayload(payload, (*types.Vote[H, A])(record.PrevoteEntry)) + case walEntryPrecommit: + payload = 
appendVotePayload(payload, (*types.Vote[H, A])(record.PrecommitEntry)) + case walEntryTimeout: + timeoutMessage := (*types.Timeout)(record.TimeoutEntry) + payload = append(payload, byte(timeoutMessage.Step)) + payload = appendUint64(payload, uint64(timeoutMessage.Height)) + payload = appendInt64(payload, int64(timeoutMessage.Round)) + default: + return nil, fmt.Errorf("unknown WAL entry kind %d", record.EntryKind) + } + case walRecordPruneUpToHeight: + payload = appendUint64(payload, uint64(record.Height)) + default: + return nil, fmt.Errorf("unknown WAL record kind %d", record.Kind) + } + return payload, nil +} + +func putFixedUvarint32(buf []byte, value uint32) { + const fixedUvarintShift = 7 + + buf[0] = byte(value) | 0x80 + for i := 1; i < valueLenBytes-1; i++ { + buf[i] = byte(value>>(fixedUvarintShift*i)) | 0x80 + } + buf[valueLenBytes-1] = byte(value >> (fixedUvarintShift * (valueLenBytes - 1))) +} + +func decodeWALRecord[V types.Hashable[H], H types.Hash, A types.Addr]( + payload []byte, +) (walRecordEnvelope[V, H, A], error) { + decoder := walRecordDecoder{data: payload} + kind, err := decoder.readByte() + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + + record := walRecordEnvelope[V, H, A]{Kind: walRecordKind(kind)} + switch record.Kind { + case walRecordEntry: + record, err = decodeWALEntryRecord(&decoder, record) + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + case walRecordPruneUpToHeight: + height, err := decoder.readHeight() + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + record.Height = height + default: + return walRecordEnvelope[V, H, A]{}, fmt.Errorf("unknown WAL record kind %d", record.Kind) + } + if decoder.remaining() != 0 { + return walRecordEnvelope[V, H, A]{}, fmt.Errorf( + "unexpected remaining WAL record bytes: %d", + decoder.remaining(), + ) + } + return record, nil +} + +func decodeWALEntryRecord[V types.Hashable[H], H types.Hash, A types.Addr]( + decoder *walRecordDecoder, + 
record walRecordEnvelope[V, H, A], +) (walRecordEnvelope[V, H, A], error) { + entryKind, err := decoder.readByte() + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + + record.EntryKind = walEntryKind(entryKind) + switch record.EntryKind { + case walEntryStart: + height, err := decoder.readHeight() + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + record.StartHeight = height + case walEntryProposal: + proposal, err := decodeProposalRecord[V, H, A](decoder) + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + record.ProposalEntry = proposal + case walEntryPrevote: + vote, err := decodeVotePayload[H, A](decoder) + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + prevote := wal.Prevote[H, A](vote) + record.PrevoteEntry = &prevote + case walEntryPrecommit: + vote, err := decodeVotePayload[H, A](decoder) + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + precommit := wal.Precommit[H, A](vote) + record.PrecommitEntry = &precommit + case walEntryTimeout: + timeout, err := decodeTimeoutRecord(decoder) + if err != nil { + return walRecordEnvelope[V, H, A]{}, err + } + record.TimeoutEntry = timeout + default: + return walRecordEnvelope[V, H, A]{}, fmt.Errorf("unknown WAL entry kind %d", record.EntryKind) + } + + return record, nil +} + +// Proposal payload: message header, valid round, optional value. 
+func decodeProposalRecord[V types.Hashable[H], H types.Hash, A types.Addr]( + decoder *walRecordDecoder, +) (*wal.Proposal[V, H, A], error) { + header, err := readMessageHeader[A](decoder) + if err != nil { + return nil, err + } + validRound, err := decoder.readRound() + if err != nil { + return nil, err + } + hasValue, err := decoder.readPresenceByte() + if err != nil { + return nil, err + } + proposal := wal.Proposal[V, H, A]{ + MessageHeader: header, + ValidRound: validRound, + } + if hasValue { + value, err := readValue[V](decoder) + if err != nil { + return nil, err + } + proposal.Value = &value + } + return &proposal, nil +} + +// Vote payload: message header, optional value ID. Used by prevote and precommit. +func decodeVotePayload[H types.Hash, A types.Addr](d *walRecordDecoder) (types.Vote[H, A], error) { + header, err := readMessageHeader[A](d) + if err != nil { + return types.Vote[H, A]{}, err + } + hasID, err := d.readPresenceByte() + if err != nil { + return types.Vote[H, A]{}, err + } + vote := types.Vote[H, A]{MessageHeader: header} + if hasID { + id, err := d.readUint64Array() + if err != nil { + return types.Vote[H, A]{}, err + } + hash := H(id) + vote.ID = &hash + } + return vote, nil +} + +// Timeout payload: step, height, round. 
+func decodeTimeoutRecord(decoder *walRecordDecoder) (*wal.Timeout, error) { + step, err := decoder.readByte() + if err != nil { + return nil, err + } + height, err := decoder.readHeight() + if err != nil { + return nil, err + } + round, err := decoder.readRound() + if err != nil { + return nil, err + } + timeout := wal.Timeout(types.Timeout{ + Step: types.Step(step), + Height: height, + Round: round, + }) + return &timeout, nil +} + +func appendMessageHeader[A types.Addr](payload []byte, header types.MessageHeader[A]) []byte { + payload = appendUint64(payload, uint64(header.Height)) + payload = appendInt64(payload, int64(header.Round)) + return appendUint64Array(payload, &header.Sender) +} + +func appendVotePayload[H types.Hash, A types.Addr](payload []byte, vote *types.Vote[H, A]) []byte { + payload = appendMessageHeader(payload, vote.MessageHeader) + if vote.ID == nil { + return append(payload, 0) + } + payload = append(payload, 1) + return appendUint64Array(payload, vote.ID) +} + +func appendUint64(payload []byte, value uint64) []byte { + // Match Pebble batchrepr's little-endian header encoding within this WAL record. 
+ return binary.LittleEndian.AppendUint64(payload, value) +} + +func appendInt64(payload []byte, value int64) []byte { + return appendUint64(payload, uint64(value)) +} + +func appendUint64Array[T ~[4]uint64](payload []byte, value *T) []byte { + for i := range 4 { + payload = appendUint64(payload, (*value)[i]) + } + return payload +} + +func appendValue[V any](payload []byte, value *V) ([]byte, error) { + array, err := valueToUint64Array(value) + if err != nil { + return nil, err + } + for _, limb := range array { + payload = appendUint64(payload, limb) + } + return payload, nil +} + +type walRecordDecoder struct { + data []byte + pos int +} + +func (d *walRecordDecoder) remaining() int { + return len(d.data) - d.pos +} + +func (d *walRecordDecoder) readByte() (byte, error) { + if d.remaining() < 1 { + return 0, io.ErrUnexpectedEOF + } + value := d.data[d.pos] + d.pos++ + return value, nil +} + +func (d *walRecordDecoder) readPresenceByte() (bool, error) { + value, err := d.readByte() + if err != nil { + return false, err + } + switch value { + case 0: + return false, nil + case 1: + return true, nil + default: + return false, fmt.Errorf("invalid presence byte %d", value) + } +} + +func (d *walRecordDecoder) readUint64() (uint64, error) { + if d.remaining() < 8 { + return 0, io.ErrUnexpectedEOF + } + value := binary.LittleEndian.Uint64(d.data[d.pos : d.pos+8]) + d.pos += 8 + return value, nil +} + +func (d *walRecordDecoder) readHeight() (types.Height, error) { + value, err := d.readUint64() + return types.Height(value), err +} + +func (d *walRecordDecoder) readRound() (types.Round, error) { + value, err := d.readUint64() + return types.Round(int64(value)), err +} + +func (d *walRecordDecoder) readUint64Array() ([4]uint64, error) { + var value [4]uint64 + for i := range value { + limb, err := d.readUint64() + if err != nil { + return [4]uint64{}, err + } + value[i] = limb + } + return value, nil +} + +func readMessageHeader[A types.Addr](d *walRecordDecoder) 
(types.MessageHeader[A], error) { + height, err := d.readHeight() + if err != nil { + return types.MessageHeader[A]{}, err + } + round, err := d.readRound() + if err != nil { + return types.MessageHeader[A]{}, err + } + sender, err := d.readUint64Array() + if err != nil { + return types.MessageHeader[A]{}, err + } + return types.MessageHeader[A]{ + Height: height, + Round: round, + Sender: A(sender), + }, nil +} + +func readValue[V any](d *walRecordDecoder) (V, error) { + array, err := d.readUint64Array() + if err != nil { + var zero V + return zero, err + } + return uint64ArrayToValue[V](array) +} + +func valueToUint64Array[V any](value *V) ([4]uint64, error) { + reflectValue := reflect.ValueOf(value).Elem() + if reflectValue.Kind() != reflect.Array || + reflectValue.Len() != 4 || + reflectValue.Type().Elem().Kind() != reflect.Uint64 { + return [4]uint64{}, fmt.Errorf("unsupported WAL value type %T", value) + } + var array [4]uint64 + for i := range array { + array[i] = reflectValue.Index(i).Uint() + } + return array, nil +} + +func uint64ArrayToValue[V any](array [4]uint64) (V, error) { + var value V + reflectValue := reflect.ValueOf(&value).Elem() + if reflectValue.Kind() != reflect.Array || + reflectValue.Len() != 4 || + reflectValue.Type().Elem().Kind() != reflect.Uint64 { + return value, fmt.Errorf("unsupported WAL value type %T", value) + } + for i, limb := range array { + reflectValue.Index(i).SetUint(limb) + } + return value, nil +} diff --git a/consensus/walstore/codec_test.go b/consensus/walstore/codec_test.go new file mode 100644 index 0000000000..c6b5ef108e --- /dev/null +++ b/consensus/walstore/codec_test.go @@ -0,0 +1,95 @@ +package walstore_test + +import ( + "testing" + + "github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/types" + "github.com/NethermindEth/juno/consensus/types/wal" + "github.com/NethermindEth/juno/consensus/walstore" + "github.com/NethermindEth/juno/core/felt" + 
"github.com/stretchr/testify/require" +) + +// TestWALRecordsRoundTripThroughReopen verifies that persisted WAL records +// survive a store reopen unchanged across all entry kinds and edge branches. +func TestWALRecordsRoundTripThroughReopen(t *testing.T) { + height := types.Height(1) + round := types.Round(1) + value := felt.FromUint64[starknet.Value](1) + valueHash := value.Hash() + + start := wal.Start(height) + proposalWithValue := &starknet.WALProposal{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: walAddress(1)}, + ValidRound: 1, + Value: &value, + } + proposalWithoutValue := &starknet.WALProposal{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: walAddress(2)}, + ValidRound: -1, + Value: nil, + } + prevoteWithID := &starknet.WALPrevote{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: walAddress(3)}, + ID: &valueHash, + } + prevoteNilID := &starknet.WALPrevote{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: walAddress(4)}, + ID: nil, + } + precommitWithID := &starknet.WALPrecommit{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: walAddress(5)}, + ID: &valueHash, + } + precommitNilID := &starknet.WALPrecommit{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: walAddress(6)}, + ID: nil, + } + timeout := &starknet.WALTimeout{Height: height, Round: round, Step: types.StepPrecommit} + + entries := []starknet.WALEntry{ + &start, + proposalWithValue, + proposalWithoutValue, + prevoteWithID, + prevoteNilID, + precommitWithID, + precommitNilID, + timeout, + } + + walStore, testDB, dbPath := newTestTendermintWALStore(t) + writeAndFlushEntries(t, walStore, entries) + + walStore, testDB = reopenTestTendermintWALStore(t, walStore, testDB, dbPath) + t.Cleanup(func() { + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + }) + + assertLoadedEntries(t, walStore, entries) +} 
+ +func TestWALRecordDecodeRejectsMalformedPayloads(t *testing.T) { + start := wal.Start(types.Height(1)) + valid, err := walstore.EncodeWALRecordPayload(&start) + require.NoError(t, err) + + tests := map[string][]byte{ + "empty": {}, + "unknown kind": {0xFF}, + "truncated": valid[:len(valid)-1], + "trailing byte": append(append([]byte(nil), valid...), 0x00), + } + + for name, payload := range tests { + t.Run(name, func(t *testing.T) { + require.Error(t, walstore.DecodeWALRecordPayload(payload)) + }) + } +} + +func walAddress(v uint64) starknet.Address { + return felt.FromUint64[starknet.Address](v) +} diff --git a/consensus/walstore/export_test.go b/consensus/walstore/export_test.go new file mode 100644 index 0000000000..f2d1df260d --- /dev/null +++ b/consensus/walstore/export_test.go @@ -0,0 +1,37 @@ +package walstore + +import ( + "github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/types" +) + +func LoadPruneWatermark(walDir string) (types.Height, error) { + return loadPruneWatermark(walDir) +} + +func WritePruneWatermark(walDir string, height types.Height) error { + return writePruneWatermark(walDir, height) +} + +// ForcePendingCleanup primes the store so the next prune flush triggers cleanup. +func ForcePendingCleanup( + walStore TendermintWALStore[starknet.Value, starknet.Hash, starknet.Address], +) { + walStore.(*tendermintWALStore[starknet.Value, starknet.Hash, starknet.Address]). 
+ pruneRecordsSinceCleanup = cleanupPruneRecordInterval - 1 +} + +func EncodeWALRecordPayload(entry starknet.WALEntry) ([]byte, error) { + envelope := walRecordEnvelope[starknet.Value, starknet.Hash, starknet.Address]{ + Kind: walRecordEntry, + } + if err := envelope.setEntry(entry); err != nil { + return nil, err + } + return appendWALRecordPayload(nil, &envelope) +} + +func DecodeWALRecordPayload(payload []byte) error { + _, err := decodeWALRecord[starknet.Value, starknet.Hash, starknet.Address](payload) + return err +} diff --git a/consensus/db/init_test.go b/consensus/walstore/init_test.go similarity index 74% rename from consensus/db/init_test.go rename to consensus/walstore/init_test.go index e104b1db6f..e5049bfc61 100644 --- a/consensus/db/init_test.go +++ b/consensus/walstore/init_test.go @@ -1,4 +1,4 @@ -package db +package walstore_test import ( _ "github.com/NethermindEth/juno/encoder/registry" diff --git a/consensus/walstore/prune_watermark.go b/consensus/walstore/prune_watermark.go new file mode 100644 index 0000000000..fada23e6f2 --- /dev/null +++ b/consensus/walstore/prune_watermark.go @@ -0,0 +1,90 @@ +package walstore + +import ( + "encoding/binary" + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/NethermindEth/juno/consensus/types" +) + +const ( + pruneWatermarkHeader = "juno-wal-prune-watermark-v1" + pruneWatermarkSize = len(pruneWatermarkHeader) + 8 +) + +func loadPruneWatermark(walDir string) (types.Height, error) { + data, err := os.ReadFile(pruneWatermarkPath(walDir)) + if errors.Is(err, os.ErrNotExist) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("loadPruneWatermark: read watermark: %w", err) + } + if len(data) != pruneWatermarkSize { + return 0, fmt.Errorf("loadPruneWatermark: invalid watermark size %d", len(data)) + } + if string(data[:len(pruneWatermarkHeader)]) != pruneWatermarkHeader { + return 0, errors.New("loadPruneWatermark: invalid watermark header") + } + return 
types.Height(binary.BigEndian.Uint64(data[len(pruneWatermarkHeader):])), nil +} + +func writePruneWatermark(walDir string, height types.Height) error { + const pruneWatermarkFilePerm = 0o644 + + var data [pruneWatermarkSize]byte + copy(data[:], pruneWatermarkHeader) + binary.BigEndian.PutUint64(data[len(pruneWatermarkHeader):], uint64(height)) + + path := pruneWatermarkPath(walDir) + tmpPath := path + ".tmp" + file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, pruneWatermarkFilePerm) + if err != nil { + return fmt.Errorf("writePruneWatermark: create temp watermark: %w", err) + } + + _, writeErr := file.Write(data[:]) + if writeErr == nil { + writeErr = file.Sync() + } + closeErr := file.Close() + if writeErr != nil || closeErr != nil { + _ = os.Remove(tmpPath) + if writeErr != nil { + writeErr = fmt.Errorf("writePruneWatermark: write temp watermark: %w", writeErr) + } + if closeErr != nil { + closeErr = fmt.Errorf("writePruneWatermark: close temp watermark: %w", closeErr) + } + return errors.Join(writeErr, closeErr) + } + + if err := os.Rename(tmpPath, path); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("writePruneWatermark: replace watermark: %w", err) + } + if err := syncDir(walDir); err != nil { + return fmt.Errorf("writePruneWatermark: sync watermark directory: %w", err) + } + return nil +} + +func pruneWatermarkPath(walDir string) string { + return filepath.Join(walDir, "prune-watermark") +} + +func syncDir(path string) error { + dir, err := os.Open(path) + if err != nil { + return err + } + defer dir.Close() + + if err := dir.Sync(); err != nil { + return err + } + return nil +} diff --git a/consensus/walstore/prune_watermark_test.go b/consensus/walstore/prune_watermark_test.go new file mode 100644 index 0000000000..e2361dcdae --- /dev/null +++ b/consensus/walstore/prune_watermark_test.go @@ -0,0 +1,52 @@ +package walstore_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/NethermindEth/juno/consensus/types" 
+ "github.com/NethermindEth/juno/consensus/walstore" + "github.com/stretchr/testify/require" +) + +const testPruneWatermarkHeight = types.Height(1) + +func TestPruneWatermarkRoundTrips(t *testing.T) { + walDir := t.TempDir() + + height, err := walstore.LoadPruneWatermark(walDir) + require.NoError(t, err) + require.Zero(t, height) + + require.NoError(t, walstore.WritePruneWatermark(walDir, testPruneWatermarkHeight)) + height, err = walstore.LoadPruneWatermark(walDir) + require.NoError(t, err) + require.Equal(t, testPruneWatermarkHeight, height) +} + +func TestLoadPruneWatermarkRejectsCorruptFile(t *testing.T) { + walDir := t.TempDir() + require.NoError(t, walstore.WritePruneWatermark(walDir, testPruneWatermarkHeight)) + validWatermark, err := os.ReadFile(filepath.Join(walDir, "prune-watermark")) + require.NoError(t, err) + + wrongHeader := make([]byte, len(validWatermark)) + copy(wrongHeader, "wrong-header") + + tests := map[string][]byte{ + "wrong size": []byte("short"), + "bad header": wrongHeader, + } + + for name, contents := range tests { + t.Run(name, func(t *testing.T) { + walDir := t.TempDir() + path := filepath.Join(walDir, "prune-watermark") + require.NoError(t, os.WriteFile(path, contents, 0o644)) + + _, err := walstore.LoadPruneWatermark(walDir) + require.Error(t, err) + }) + } +} diff --git a/consensus/walstore/record.go b/consensus/walstore/record.go new file mode 100644 index 0000000000..dc031811b4 --- /dev/null +++ b/consensus/walstore/record.go @@ -0,0 +1,98 @@ +package walstore + +import ( + "errors" + "fmt" + + "github.com/NethermindEth/juno/consensus/types" + "github.com/NethermindEth/juno/consensus/types/wal" +) + +type walRecordKind uint8 + +const ( + walRecordEntry walRecordKind = iota + 1 + walRecordPruneUpToHeight +) + +type walEntryKind uint8 + +const ( + walEntryStart walEntryKind = iota + 1 + walEntryProposal + walEntryPrevote + walEntryPrecommit + walEntryTimeout +) + +type walRecordEnvelope[V types.Hashable[H], H types.Hash, A 
types.Addr] struct { + Kind walRecordKind + EntryKind walEntryKind + StartHeight types.Height + ProposalEntry *wal.Proposal[V, H, A] + PrevoteEntry *wal.Prevote[H, A] + PrecommitEntry *wal.Precommit[H, A] + TimeoutEntry *wal.Timeout + Height types.Height +} + +func (e *walRecordEnvelope[V, H, A]) setEntry(entry wal.Entry[V, H, A]) error { + switch entry := entry.(type) { + case *wal.Start: + if entry == nil { + return errors.New("nil start WAL entry") + } + e.EntryKind = walEntryStart + e.StartHeight = types.Height(*entry) + case *wal.Proposal[V, H, A]: + if entry == nil { + return errors.New("nil proposal WAL entry") + } + e.EntryKind = walEntryProposal + proposal := *entry + e.ProposalEntry = &proposal + case *wal.Prevote[H, A]: + if entry == nil { + return errors.New("nil prevote WAL entry") + } + e.EntryKind = walEntryPrevote + prevote := *entry + e.PrevoteEntry = &prevote + case *wal.Precommit[H, A]: + if entry == nil { + return errors.New("nil precommit WAL entry") + } + e.EntryKind = walEntryPrecommit + precommit := *entry + e.PrecommitEntry = &precommit + case *wal.Timeout: + if entry == nil { + return errors.New("nil timeout WAL entry") + } + e.EntryKind = walEntryTimeout + timeout := *entry + e.TimeoutEntry = &timeout + default: + return fmt.Errorf("unsupported WAL entry type %T", entry) + } + + return nil +} + +func (e walRecordEnvelope[V, H, A]) entry() wal.Entry[V, H, A] { + switch e.EntryKind { + case walEntryStart: + start := wal.Start(e.StartHeight) + return &start + case walEntryProposal: + return e.ProposalEntry + case walEntryPrevote: + return e.PrevoteEntry + case walEntryPrecommit: + return e.PrecommitEntry + case walEntryTimeout: + return e.TimeoutEntry + default: + panic(fmt.Sprintf("unknown WAL entry kind %d", e.EntryKind)) + } +} diff --git a/consensus/walstore/replay.go b/consensus/walstore/replay.go new file mode 100644 index 0000000000..0e6f6c5713 --- /dev/null +++ b/consensus/walstore/replay.go @@ -0,0 +1,117 @@ +package walstore + 
+import ( + "errors" + "fmt" + "io" + + "github.com/cockroachdb/pebble/v2" + "github.com/cockroachdb/pebble/v2/batchrepr" + "github.com/cockroachdb/pebble/v2/record" + pebblewal "github.com/cockroachdb/pebble/v2/wal" +) + +func (s *tendermintWALStore[V, H, A]) loadExistingEntries(logs pebblewal.Logs) error { + for i, log := range logs { + if err := s.loadLogicalLog(log, i == len(logs)-1); err != nil { + return err + } + } + + return nil +} + +func (s *tendermintWALStore[V, H, A]) loadLogicalLog( + log pebblewal.LogicalLog, + tolerateTail bool, +) error { + reader := log.OpenForRead() + defer reader.Close() + + for { + recordReader, _, err := reader.NextRecord() + switch { + case err == nil: + case errors.Is(err, io.EOF): + return nil + case tolerateTail && record.IsInvalidRecord(err): + return nil + default: + return fmt.Errorf("loadLogicalLog: read WAL %s: %w", log.Num, err) + } + + encodedBatch, err := io.ReadAll(recordReader) + if err != nil { + return fmt.Errorf("loadLogicalLog: read WAL %s record: %w", log.Num, err) + } + + if err := s.applyEncodedBatch(log.Num, encodedBatch); err != nil { + return fmt.Errorf("loadLogicalLog: apply WAL %s record: %w", log.Num, err) + } + } +} + +func (s *tendermintWALStore[V, H, A]) applyEncodedBatch( + walNum pebblewal.NumWAL, + encodedBatch []byte, +) error { + header, ok := batchrepr.ReadHeader(encodedBatch) + if !ok { + return errors.New("applyEncodedBatch: missing batch header") + } + if nextSeq := uint64(header.SeqNum) + uint64(header.Count); nextSeq > s.nextBatchSeqNum { + s.nextBatchSeqNum = nextSeq + } + + reader := batchrepr.Read(encodedBatch) + seen := uint32(0) + for { + kind, _, value, ok, err := reader.Next() + if err != nil { + return fmt.Errorf("applyEncodedBatch: iterate batch record: %w", err) + } + if !ok { + if seen != header.Count { + return fmt.Errorf( + "applyEncodedBatch: batch header count %d does not match record count %d", + header.Count, + seen, + ) + } + return nil + } + if kind != 
pebble.InternalKeyKindSet { + return fmt.Errorf("applyEncodedBatch: unexpected batch key kind %v", kind) + } + + if err := s.applyEncodedRecord(walNum, value); err != nil { + return err + } + seen++ + } +} + +func (s *tendermintWALStore[V, H, A]) applyEncodedRecord( + walNum pebblewal.NumWAL, + value []byte, +) error { + envelope, err := decodeWALRecord[V, H, A](value) + if err != nil { + return fmt.Errorf("applyEncodedRecord: decode WAL envelope: %w", err) + } + + switch envelope.Kind { + case walRecordEntry: + entry := envelope.entry() + if entry.GetHeight() <= s.prunedUpToHeight { + return nil + } + s.addLiveEntry(walNum, entry) + case walRecordPruneUpToHeight: + s.pruneLiveEntriesUpTo(envelope.Height) + default: + return fmt.Errorf("applyEncodedRecord: unknown WAL envelope kind %d", envelope.Kind) + } + + return nil +} diff --git a/consensus/walstore/replay_test.go b/consensus/walstore/replay_test.go new file mode 100644 index 0000000000..5a3cc63dc5 --- /dev/null +++ b/consensus/walstore/replay_test.go @@ -0,0 +1,55 @@ +package walstore_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/NethermindEth/juno/consensus/walstore" + "github.com/stretchr/testify/require" +) + +func TestWALReplayTruncatesInvalidLatestWALTail(t *testing.T) { + expectedEntries := makeWALEntriesForHeight(1) + walStore, testDB, dbPath := newTestTendermintWALStore(t) + writeAndFlushEntries(t, walStore, expectedEntries) + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + + walDir := walstore.DefaultWALDir(dbPath) + walPath := filepath.Join(walDir, "000001.log") + + info, err := os.Stat(walPath) + require.NoError(t, err) + sizeBeforeInvalidTail := info.Size() + + appendInvalidWALTail(t, walPath, []byte("partial-record-tail")) + info, err = os.Stat(walPath) + require.NoError(t, err) + corruptedSize := info.Size() + require.Greater(t, corruptedSize, sizeBeforeInvalidTail) + + walStore, testDB = openTestTendermintWALStore(t, dbPath) + defer func() { + 
require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + }() + + assertLoadedEntries(t, walStore, expectedEntries) + + info, err = os.Stat(walPath) + require.NoError(t, err) + require.LessOrEqual(t, info.Size(), sizeBeforeInvalidTail) +} + +func appendInvalidWALTail(t *testing.T, path string, payload []byte) { + t.Helper() + + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + require.NoError(t, err) + defer file.Close() + + _, err = file.Write(payload) + require.NoError(t, err) + require.NoError(t, file.Sync()) +} diff --git a/consensus/walstore/wal_index.go b/consensus/walstore/wal_index.go new file mode 100644 index 0000000000..795d329151 --- /dev/null +++ b/consensus/walstore/wal_index.go @@ -0,0 +1,95 @@ +package walstore + +import ( + "slices" + + "github.com/NethermindEth/juno/consensus/types" + "github.com/NethermindEth/juno/consensus/types/wal" + pebblewal "github.com/cockroachdb/pebble/v2/wal" +) + +// walNumSet stores the first WAL inline and subsequent WALs in rest. +// It avoids map overhead because most heights reference only one WAL. 
+type walNumSet struct { + first pebblewal.NumWAL + rest []pebblewal.NumWAL +} + +func (s *walNumSet) addIfMissing(walNum pebblewal.NumWAL) bool { + if s.first == 0 { + s.first = walNum + return true + } + if s.first == walNum || slices.Contains(s.rest, walNum) { + return false + } + s.rest = append(s.rest, walNum) + return true +} + +func (s walNumSet) rangeOver(fn func(pebblewal.NumWAL)) { + if s.first == 0 { + return + } + fn(s.first) + for _, walNum := range s.rest { + fn(walNum) + } +} + +func (s *tendermintWALStore[V, H, A]) updateIndexesFromCommittedRecords( + walNum pebblewal.NumWAL, + records []walRecordEnvelope[V, H, A], +) { + for _, record := range records { + switch record.Kind { + case walRecordEntry: + entry := record.entry() + if entry.GetHeight() <= s.prunedUpToHeight { + continue + } + s.addLiveEntry(walNum, entry) + case walRecordPruneUpToHeight: + s.pruneLiveEntriesUpTo(record.Height) + } + } +} + +func (s *tendermintWALStore[V, H, A]) addLiveEntry( + walNum pebblewal.NumWAL, + entry wal.Entry[V, H, A], +) { + height := entry.GetHeight() + s.entriesByHeight[height] = append(s.entriesByHeight[height], entry) + + referencedWALs := s.walFilesByHeight[height] + if referencedWALs.addIfMissing(walNum) { + s.walFilesByHeight[height] = referencedWALs + s.walHeightRefs[walNum]++ + } +} + +func (s *tendermintWALStore[V, H, A]) deleteLiveHeight(height types.Height) { + delete(s.entriesByHeight, height) + + referencedWALs := s.walFilesByHeight[height] + delete(s.walFilesByHeight, height) + referencedWALs.rangeOver(func(walNum pebblewal.NumWAL) { + s.walHeightRefs[walNum]-- + if s.walHeightRefs[walNum] == 0 { + delete(s.walHeightRefs, walNum) + } + }) +} + +func (s *tendermintWALStore[V, H, A]) pruneLiveEntriesUpTo(height types.Height) { + if height <= s.prunedUpToHeight { + return + } + s.prunedUpToHeight = height + for liveHeight := range s.entriesByHeight { + if liveHeight <= height { + s.deleteLiveHeight(liveHeight) + } + } +} diff --git 
a/consensus/walstore/wal_store.go b/consensus/walstore/wal_store.go new file mode 100644 index 0000000000..624ad11a9a --- /dev/null +++ b/consensus/walstore/wal_store.go @@ -0,0 +1,326 @@ +package walstore + +import ( + "errors" + "fmt" + "iter" + "os" + "path/filepath" + "slices" + "sync" + "time" + + "github.com/NethermindEth/juno/consensus/types" + "github.com/NethermindEth/juno/consensus/types/wal" + kvdb "github.com/NethermindEth/juno/db" + "github.com/cockroachdb/pebble/v2" + "github.com/cockroachdb/pebble/v2/vfs" + pebblewal "github.com/cockroachdb/pebble/v2/wal" +) + +const ( + initialWALNum = 1 + initialSeqNum = 1 + cleanupPruneRecordInterval = 256 // Amortizes crash-safe cleanup fsyncs. +) + +// TendermintWALStore persists consensus WAL entries so the node can recover after a crash. +type TendermintWALStore[V types.Hashable[H], H types.Hash, A types.Addr] interface { + // Flush writes pending records to the WAL. + Flush() error + // LoadAllEntries returns stored WAL entries ordered by height. + LoadAllEntries() iter.Seq2[wal.Entry[V, H, A], error] + // SetWALEntry buffers an entry until Flush or Close. + SetWALEntry(entry wal.Entry[V, H, A]) error + // DeleteWALEntries buffers a prune for all entries up to and including height. + DeleteWALEntries(height types.Height) error + // Close flushes pending records and closes WAL resources. + Close() error +} + +type tendermintWALStore[V types.Hashable[H], H types.Hash, A types.Addr] struct { + mu sync.Mutex + closed bool + wal *walWriter + + nextBatchSeqNum uint64 + entriesByHeight map[types.Height][]wal.Entry[V, H, A] + // Height deletes need to find all WAL files that contain the deleted height. + walFilesByHeight map[types.Height]walNumSet + // WAL cleanup needs to know whether any remaining height still references the file. + walHeightRefs map[pebblewal.NumWAL]int + + // prunedUpToHeight is the inclusive watermark for discarded WAL recovery data. 
+ prunedUpToHeight types.Height + + // pendingRecords are only visible through LoadAllEntries after a successful flush. + pendingRecords []walRecordEnvelope[V, H, A] + encodedBatch []byte + pruneRecordsSinceCleanup uint64 +} + +func DefaultWALDir(dbPath string) string { + const walDirName = "consensus-wal" + + return filepath.Join(dbPath, walDirName) +} + +// NewTendermintWALStore creates a new TendermintWALStore +// backed by Pebble's standalone WAL implementation. +func NewTendermintWALStore[V types.Hashable[H], H types.Hash, A types.Addr]( + database kvdb.KeyValueStore, +) (TendermintWALStore[V, H, A], error) { + const walDirPerm = 0o755 + + // Store the standalone consensus WAL next to the backing DB to keep all local + // consensus persistence under the same data path. + dbPath := database.Path() + if dbPath == "" { + return nil, errors.New( + "NewTendermintWALStore: consensus WAL requires a local path; remote DBs are unsupported", + ) + } + walDir := DefaultWALDir(dbPath) + if err := vfs.Default.MkdirAll(walDir, walDirPerm); err != nil { + return nil, fmt.Errorf("NewTendermintWALStore: create WAL directory: %w", err) + } + + dir := pebblewal.Dir{ + FS: vfs.Default, + Dirname: walDir, + } + + logs, err := pebblewal.Scan(dir) + if err != nil { + return nil, fmt.Errorf("NewTendermintWALStore: scan WAL directory: %w", err) + } + if err := recoverLatestWALTail(logs); err != nil { + return nil, fmt.Errorf("NewTendermintWALStore: recover latest WAL tail: %w", err) + } + prunedUpToHeight, err := loadPruneWatermark(walDir) + if err != nil { + return nil, fmt.Errorf("NewTendermintWALStore: load prune watermark: %w", err) + } + + manager, err := pebblewal.Init(pebblewal.Options{ + Primary: dir, + MinUnflushedWALNum: initialWALNum, + Logger: pebble.DefaultLogger, + EventListener: noopEventListener{}, + // No prealloc: batches are small and short-lived. + PreallocateSize: func() int { return 0 }, + // Consensus requires durability per batch. 
+ MinSyncInterval: func() time.Duration { return 0 }, + // We do not use Pebble's flush pipeline. + WriteWALSyncOffsets: func() bool { return false }, + }, logs) + if err != nil { + return nil, fmt.Errorf("NewTendermintWALStore: initialize WAL manager: %w", err) + } + + walStore := &tendermintWALStore[V, H, A]{ + wal: newWALWriter(manager, walDir, nextWALNum(logs)), + nextBatchSeqNum: initialSeqNum, + entriesByHeight: make(map[types.Height][]wal.Entry[V, H, A]), + walFilesByHeight: make(map[types.Height]walNumSet), + walHeightRefs: make(map[pebblewal.NumWAL]int), + prunedUpToHeight: prunedUpToHeight, + } + + if err := walStore.loadExistingEntries(logs); err != nil { + return nil, errors.Join(err, manager.Close()) + } + + return walStore, nil +} + +func (s *tendermintWALStore[V, H, A]) Flush() error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.flushLocked() +} + +func (s *tendermintWALStore[V, H, A]) flushLocked() error { + if s.closed { + return errors.New("Flush: WAL is closed") + } + + if len(s.pendingRecords) == 0 { + return nil + } + + encodedBatch, err := encodeBatch(s.pendingRecords, s.nextBatchSeqNum, s.encodedBatch) + if err != nil { + return err + } + recordCount := uint32(len(s.pendingRecords)) + appendResult, err := s.wal.appendSync(encodedBatch) + const maxReusableEncodedBatchCap = 512 << 10 // Sized for current mainnet encoded-batch workloads. 
+ if cap(encodedBatch) <= maxReusableEncodedBatchCap { + s.encodedBatch = encodedBatch[:0] + } else { + s.encodedBatch = nil + } + if err != nil { + if !appendResult.committed { + return err + } + } + walNum := appendResult.walNum + + committedErr := err + records := s.pendingRecords + + s.nextBatchSeqNum += uint64(recordCount) + s.updateIndexesFromCommittedRecords(walNum, records) + + pruneRecordCount := countPruneRecords(records) + committedErr = errors.Join(committedErr, s.removeObsoleteWALFiles(pruneRecordCount)) + + clear(records) + s.pendingRecords = records[:0] + return committedErr +} + +func countPruneRecords[V types.Hashable[H], H types.Hash, A types.Addr]( + records []walRecordEnvelope[V, H, A], +) uint64 { + var pruneRecordCount uint64 + for _, record := range records { + if record.Kind == walRecordPruneUpToHeight { + pruneRecordCount++ + } + } + return pruneRecordCount +} + +func (s *tendermintWALStore[V, H, A]) removeObsoleteWALFiles( + pruneRecordCount uint64, +) error { + if pruneRecordCount == 0 { + return nil + } + + s.pruneRecordsSinceCleanup += pruneRecordCount + if s.pruneRecordsSinceCleanup < cleanupPruneRecordInterval { + return nil + } + + if err := writePruneWatermark(s.wal.dir, s.prunedUpToHeight); err != nil { + return err + } + + // Future optimisation: run cleanup in a background worker, piggyback prune + // durability on the next WAL flush instead of the driver's per-height Flush. 
+ rotateErr := s.wal.rotateAfterSynced() + cleanupErr := s.cleanupObsoleteWALs() + if rotateErr == nil && cleanupErr == nil { + s.pruneRecordsSinceCleanup = 0 + } + return errors.Join(rotateErr, cleanupErr) +} + +func (s *tendermintWALStore[V, H, A]) cleanupObsoleteWALs() error { + minLiveWALNum := s.wal.minLiveWALNum() + updateMinLiveWALNum := func(walRefs map[pebblewal.NumWAL]int) { + for walNum := range walRefs { + if walNum < minLiveWALNum { + minLiveWALNum = walNum + } + } + } + updateMinLiveWALNum(s.walHeightRefs) + + toDelete, err := s.wal.obsolete(minLiveWALNum) + if err != nil { + return fmt.Errorf("cleanupObsoleteWALs: mark obsolete WALs: %w", err) + } + + for _, log := range toDelete { + if err := log.FS.Remove(log.Path); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("cleanupObsoleteWALs: remove obsolete WAL %s: %w", log.Path, err) + } + } + return nil +} + +func (s *tendermintWALStore[V, H, A]) DeleteWALEntries(height types.Height) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return errors.New("DeleteWALEntries: WAL is closed") + } + if height <= s.prunedUpToHeight { + return nil + } + + for i := range s.pendingRecords { + record := &s.pendingRecords[i] + if record.Kind == walRecordPruneUpToHeight { + record.Height = max(record.Height, height) + return nil + } + } + + s.pendingRecords = append(s.pendingRecords, walRecordEnvelope[V, H, A]{ + Kind: walRecordPruneUpToHeight, + Height: height, + }) + return nil +} + +func (s *tendermintWALStore[V, H, A]) SetWALEntry(entry wal.Entry[V, H, A]) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return errors.New("SetWALEntry: WAL is closed") + } + + record := walRecordEnvelope[V, H, A]{ + Kind: walRecordEntry, + } + if err := record.setEntry(entry); err != nil { + return err + } + if entry.GetHeight() <= s.prunedUpToHeight { + return nil + } + s.pendingRecords = append(s.pendingRecords, record) + return nil +} + +func (s *tendermintWALStore[V, H, A]) 
LoadAllEntries() iter.Seq2[wal.Entry[V, H, A], error] { + s.mu.Lock() + entrySnapshot := make(map[types.Height][]wal.Entry[V, H, A], len(s.entriesByHeight)) + heights := make([]types.Height, 0, len(s.entriesByHeight)) + for height, entries := range s.entriesByHeight { + heights = append(heights, height) + entrySnapshot[height] = append([]wal.Entry[V, H, A](nil), entries...) + } + s.mu.Unlock() + slices.Sort(heights) + + return func(yield func(wal.Entry[V, H, A], error) bool) { + for _, height := range heights { + for _, entry := range entrySnapshot[height] { + if !yield(entry, nil) { + return + } + } + } + } +} + +func (s *tendermintWALStore[V, H, A]) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return nil + } + + flushErr := s.flushLocked() + s.closed = true + closeErr := s.wal.close() + return errors.Join(flushErr, closeErr, s.wal.closeManager()) +} diff --git a/consensus/walstore/wal_store_benchmark_test.go b/consensus/walstore/wal_store_benchmark_test.go new file mode 100644 index 0000000000..7802b68a5b --- /dev/null +++ b/consensus/walstore/wal_store_benchmark_test.go @@ -0,0 +1,216 @@ +package walstore_test + +import ( + "sort" + "testing" + "time" + + "github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/types" + consensuswal "github.com/NethermindEth/juno/consensus/types/wal" + "github.com/NethermindEth/juno/consensus/walstore" + "github.com/NethermindEth/juno/core/felt" + "github.com/stretchr/testify/require" +) + +const benchmarkRound = types.Round(1) + +func BenchmarkTendermintWALStoreConsensusEntryDurability(b *testing.B) { + allEntryTypes := makeWALEntriesForHeight(1) + stages := []struct { + name string + entry starknet.WALEntry + }{ + {"Proposal", allEntryTypes[0]}, + {"Prevote", allEntryTypes[1]}, + {"Precommit", allEntryTypes[2]}, + {"Timeout", allEntryTypes[3]}, + } + + for _, stage := range stages { + b.Run(stage.name, func(b *testing.B) { + walStore := newBenchmarkWALStore(b) + 
b.ReportAllocs() + b.ResetTimer() + for range b.N { + require.NoError(b, walStore.SetWALEntry(stage.entry)) + require.NoError(b, walStore.Flush()) + } + }) + } + + b.Run("Height", func(b *testing.B) { + walStore := newBenchmarkWALStore(b) + flushGroups := makeHeightWALFlushGroups(1) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + writeWALFlushGroups(b, walStore, flushGroups) + } + }) + + b.Run("HeightWithCleanup", func(b *testing.B) { + walStore := newBenchmarkWALStore(b) + b.ReportAllocs() + b.ResetTimer() + for i := range b.N { + height := types.Height(i + 1) + writeWALFlushGroups(b, walStore, makeHeightWALFlushGroups(height)) + + b.StopTimer() + walstore.ForcePendingCleanup(walStore) + b.StartTimer() + + require.NoError(b, walStore.DeleteWALEntries(height)) + require.NoError(b, walStore.Flush()) + } + }) +} + +// Measures per-height latency with cleanup at the real prune interval. +func BenchmarkTendermintWALStoreSustainedHeightLatency(b *testing.B) { + const heights = 1000 + + walStore := newBenchmarkWALStore(b) + flushGroupsByHeight := make([][][]starknet.WALEntry, heights) + for i := range flushGroupsByHeight { + flushGroupsByHeight[i] = makeHeightWALFlushGroups(types.Height(i + 1)) + } + + latencies := make([]time.Duration, 0, heights) + b.ReportAllocs() + b.ResetTimer() + for i := range heights { + height := types.Height(i + 1) + start := time.Now() + writeWALFlushGroups(b, walStore, flushGroupsByHeight[i]) + require.NoError(b, walStore.DeleteWALEntries(height)) + require.NoError(b, walStore.Flush()) + latencies = append(latencies, time.Since(start)) + } + b.StopTimer() + + reportHeightLatency(b, latencies) +} + +func reportHeightLatency(b *testing.B, latencies []time.Duration) { + b.Helper() + sort.Slice(latencies, func(i, j int) bool { return latencies[i] < latencies[j] }) + + pct := func(p float64) time.Duration { + idx := int(p * float64(len(latencies)-1)) + return latencies[idx] + } + + b.Logf( + "n=%d p50=%s p95=%s p99=%s max=%s", + 
len(latencies), + pct(0.50), + pct(0.95), + pct(0.99), + latencies[len(latencies)-1], + ) +} + +func newBenchmarkWALStore(b *testing.B) testTendermintWALStore { + b.Helper() + walStore, testDB, _ := newTestTendermintWALStore(b) + b.Cleanup(func() { + require.NoError(b, walStore.Close()) + require.NoError(b, testDB.Close()) + }) + return walStore +} + +func BenchmarkTendermintWALStoreReplay(b *testing.B) { + walStore, testDB, dbPath := newTestTendermintWALStore(b) + writeWALFlushGroups(b, walStore, makeHeightWALFlushGroups(1)) + require.NoError(b, walStore.Close()) + require.NoError(b, testDB.Close()) + + b.ReportAllocs() + b.ResetTimer() + + for range b.N { + ws, db := openTestTendermintWALStore(b, dbPath) + + b.StopTimer() + require.NoError(b, ws.Close()) + require.NoError(b, db.Close()) + b.StartTimer() + } +} + +func writeWALFlushGroups( + tb testing.TB, + walStore testTendermintWALStore, + flushGroups [][]starknet.WALEntry, +) { + for _, group := range flushGroups { + for _, entry := range group { + require.NoError(tb, walStore.SetWALEntry(entry)) + } + require.NoError(tb, walStore.Flush()) + } +} + +// makeHeightWALFlushGroups returns one height of WAL writes for 4 validators. +// Each group is flushed together, matching driver execution. 
+func makeHeightWALFlushGroups(height types.Height) [][]starknet.WALEntry { + start := consensuswal.Start(height) + proposer := felt.FromUint64[starknet.Address](1) + proposalValue := felt.FromUint64[starknet.Value](10) + proposalHash := proposalValue.Hash() + proposal := starknet.WALProposal{ + MessageHeader: starknet.MessageHeader{Height: height, Round: benchmarkRound, Sender: proposer}, + ValidRound: benchmarkRound, + Value: &proposalValue, + } + + prevote1 := makeBenchmarkWALPrevote(height, 1, &proposalHash) + prevote2 := makeBenchmarkWALPrevote(height, 2, &proposalHash) + prevote3 := makeBenchmarkWALPrevote(height, 3, &proposalHash) + precommit1 := makeBenchmarkWALPrecommit(height, 1, &proposalHash) + precommit2 := makeBenchmarkWALPrecommit(height, 2, &proposalHash) + precommit3 := makeBenchmarkWALPrecommit(height, 3, &proposalHash) + + return [][]starknet.WALEntry{ + {&start, &proposal}, + {&prevote1}, + {&prevote2}, + {&prevote3}, + {&precommit1}, + {&precommit2}, + {&precommit3}, + } +} + +func makeBenchmarkWALPrevote( + height types.Height, + sender uint64, + id *starknet.Hash, +) starknet.WALPrevote { + return starknet.WALPrevote{ + MessageHeader: starknet.MessageHeader{ + Height: height, + Round: benchmarkRound, + Sender: felt.FromUint64[starknet.Address](sender), + }, + ID: id, + } +} + +func makeBenchmarkWALPrecommit( + height types.Height, + sender uint64, + id *starknet.Hash, +) starknet.WALPrecommit { + return starknet.WALPrecommit{ + MessageHeader: starknet.MessageHeader{ + Height: height, + Round: benchmarkRound, + Sender: felt.FromUint64[starknet.Address](sender), + }, + ID: id, + } +} diff --git a/consensus/walstore/wal_store_failure_test.go b/consensus/walstore/wal_store_failure_test.go new file mode 100644 index 0000000000..b2094cc358 --- /dev/null +++ b/consensus/walstore/wal_store_failure_test.go @@ -0,0 +1,49 @@ +package walstore_test + +import ( + "io" + "os" + "path/filepath" + "testing" + + 
"github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/walstore" + "github.com/NethermindEth/juno/db/pebblev2" + "github.com/cockroachdb/pebble/v2/record" + pebblewal "github.com/cockroachdb/pebble/v2/wal" + "github.com/stretchr/testify/require" +) + +func TestNewTendermintWALStoreFailsOnCorruptWALBatch(t *testing.T) { + dbPath := t.TempDir() + walDir := walstore.DefaultWALDir(dbPath) + require.NoError(t, os.MkdirAll(walDir, 0o755)) + walPath := filepath.Join(walDir, pebblewal.NumWAL(1).String()+".log") + appendValidWALRecord(t, walPath, []byte("not-a-batch")) + + testDB, err := pebblev2.New(dbPath) + require.NoError(t, err) + defer func() { + require.NoError(t, testDB.Close()) + }() + + _, err = walstore.NewTendermintWALStore[starknet.Value, starknet.Hash, starknet.Address](testDB) + require.Error(t, err) +} + +func appendValidWALRecord(t *testing.T, path string, payload []byte) { + t.Helper() + + file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644) + require.NoError(t, err) + defer file.Close() + + _, err = file.Seek(0, io.SeekEnd) + require.NoError(t, err) + + writer := record.NewWriter(file) + _, err = writer.WriteRecord(payload) + require.NoError(t, err) + require.NoError(t, writer.Close()) + require.NoError(t, file.Sync()) +} diff --git a/consensus/walstore/wal_store_helpers_test.go b/consensus/walstore/wal_store_helpers_test.go new file mode 100644 index 0000000000..541e07b840 --- /dev/null +++ b/consensus/walstore/wal_store_helpers_test.go @@ -0,0 +1,134 @@ +package walstore_test + +import ( + "slices" + "testing" + + "github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/types" + "github.com/NethermindEth/juno/consensus/walstore" + "github.com/NethermindEth/juno/core/felt" + kvdb "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebblev2" + "github.com/stretchr/testify/require" +) + +type testTendermintWALStore = walstore.TendermintWALStore[ + 
starknet.Value, + starknet.Hash, + starknet.Address, +] + +func openTestTendermintWALStore( + tb testing.TB, + dbPath string, +) (testTendermintWALStore, kvdb.KeyValueStore) { + tb.Helper() + + testDB, err := pebblev2.New(dbPath) + require.NoError(tb, err) + walStore, err := walstore.NewTendermintWALStore[ + starknet.Value, + starknet.Hash, + starknet.Address, + ](testDB) + require.NoError(tb, err) + + return walStore, testDB +} + +func newTestTendermintWALStore( + tb testing.TB, +) (testTendermintWALStore, kvdb.KeyValueStore, string) { + tb.Helper() + dbPath := tb.TempDir() + + walStore, testDB := openTestTendermintWALStore(tb, dbPath) + + return walStore, testDB, dbPath +} + +func reopenTestTendermintWALStore( + t *testing.T, + oldWALStore testTendermintWALStore, + oldDB kvdb.KeyValueStore, + dbPath string, +) (testTendermintWALStore, kvdb.KeyValueStore) { + t.Helper() + require.NoError(t, oldWALStore.Close()) + require.NoError(t, oldDB.Close()) + + return openTestTendermintWALStore(t, dbPath) +} + +func makeWALEntriesForHeight(height types.Height) []starknet.WALEntry { + const round = types.Round(1) + const timeoutStep = types.StepPrevote + + proposalSender := felt.FromUint64[starknet.Address](1) + prevoteSender := felt.FromUint64[starknet.Address](2) + precommitSender := felt.FromUint64[starknet.Address](3) + proposalValue := felt.FromUint64[starknet.Value](10) + proposalValueHash := proposalValue.Hash() + + proposal := starknet.WALProposal{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: proposalSender}, + ValidRound: round, + Value: &proposalValue, + } + prevote := starknet.WALPrevote{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: prevoteSender}, + ID: &proposalValueHash, + } + precommit := starknet.WALPrecommit{ + MessageHeader: starknet.MessageHeader{Height: height, Round: round, Sender: precommitSender}, + ID: &proposalValueHash, + } + timeout := starknet.WALTimeout{Height: height, Round: round, 
Step: timeoutStep} + + return []starknet.WALEntry{ + &proposal, + &prevote, + &precommit, + &timeout, + } +} + +func writeAndFlushEntries( + t *testing.T, + walStore testTendermintWALStore, + entryList ...[]starknet.WALEntry, +) { + t.Helper() + for _, entries := range entryList { + for _, entry := range entries { + require.NoError(t, walStore.SetWALEntry(entry)) + } + } + require.NoError(t, walStore.Flush()) +} + +func assertLoadedEntries( + t *testing.T, + walStore testTendermintWALStore, + entryList ...[]starknet.WALEntry, +) { + t.Helper() + index := 0 + entries := slices.Concat(entryList...) + for entry, err := range walStore.LoadAllEntries() { + require.NoError(t, err) + require.Equal(t, entries[index], entry) + index++ + } + require.Equal(t, len(entries), index) +} + +func assertNoEntries(t *testing.T, walStore testTendermintWALStore) { + t.Helper() + + for entry, err := range walStore.LoadAllEntries() { + require.NoError(t, err) + require.FailNowf(t, "unexpected entry", "%v", entry) + } +} diff --git a/consensus/walstore/wal_store_test.go b/consensus/walstore/wal_store_test.go new file mode 100644 index 0000000000..56e272f5b1 --- /dev/null +++ b/consensus/walstore/wal_store_test.go @@ -0,0 +1,212 @@ +package walstore_test + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/types" + "github.com/NethermindEth/juno/consensus/types/wal" + "github.com/NethermindEth/juno/consensus/walstore" + "github.com/NethermindEth/juno/db/remote" + "github.com/NethermindEth/juno/utils/log" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +func TestNewTendermintWALStoreRejectsDatabaseWithoutLocalPath(t *testing.T) { + testDB, err := remote.New( + "localhost:1234", + context.Background(), + log.NewNopZapLogger(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer 
func() { + require.NoError(t, testDB.Close()) + }() + + walStore, err := walstore.NewTendermintWALStore[ + starknet.Value, + starknet.Hash, + starknet.Address, + ](testDB) + require.Error(t, err) + require.Nil(t, walStore) +} + +func TestSetWALEntryRejectsNilEntries(t *testing.T) { + walStore, testDB, _ := newTestTendermintWALStore(t) + defer func() { + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + }() + + tests := map[string]starknet.WALEntry{ + "nil interface": nil, + "nil start": (*wal.Start)(nil), + "nil proposal": (*starknet.WALProposal)(nil), + "nil prevote": (*starknet.WALPrevote)(nil), + "nil precommit": (*starknet.WALPrecommit)(nil), + "nil timeout": (*starknet.WALTimeout)(nil), + } + + for name, entry := range tests { + t.Run(name, func(t *testing.T) { + require.Error(t, walStore.SetWALEntry(entry)) + }) + } +} + +func TestWALReopenLoadsInterleavedEntries(t *testing.T) { + firstHeight := types.Height(1000) + secondHeight := firstHeight + 1 + thirdHeight := secondHeight + 1 + firstHeightFirstBatch := makeWALEntriesForHeight(firstHeight) + firstHeightSecondBatch := makeWALEntriesForHeight(firstHeight) + secondHeightFirstBatch := makeWALEntriesForHeight(secondHeight) + thirdHeightFirstBatch := makeWALEntriesForHeight(thirdHeight) + + walStore, testDB, dbPath := newTestTendermintWALStore(t) + defer func() { + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + }() + + for entryIndex := range firstHeightFirstBatch { + require.NoError(t, walStore.SetWALEntry(firstHeightFirstBatch[entryIndex])) + require.NoError(t, walStore.SetWALEntry(secondHeightFirstBatch[entryIndex])) + } + require.NoError(t, walStore.Flush()) + walStore, testDB = reopenTestTendermintWALStore(t, walStore, testDB, dbPath) + assertLoadedEntries(t, walStore, firstHeightFirstBatch, secondHeightFirstBatch) + + writeAndFlushEntries(t, walStore, firstHeightSecondBatch, thirdHeightFirstBatch) + walStore, testDB = reopenTestTendermintWALStore(t, 
walStore, testDB, dbPath) + assertLoadedEntries( + t, + walStore, + firstHeightFirstBatch, + firstHeightSecondBatch, + secondHeightFirstBatch, + thirdHeightFirstBatch, + ) +} + +func TestWALReopenSkipsPrunedHeight(t *testing.T) { + firstHeight := types.Height(1000) + secondHeight := firstHeight + 1 + thirdHeight := secondHeight + 1 + firstHeightEntries := makeWALEntriesForHeight(firstHeight) + secondHeightEntries := makeWALEntriesForHeight(secondHeight) + thirdHeightEntries := makeWALEntriesForHeight(thirdHeight) + + walStore, testDB, dbPath := newTestTendermintWALStore(t) + defer func() { + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + }() + + writeAndFlushEntries(t, walStore, firstHeightEntries, secondHeightEntries, thirdHeightEntries) + walStore, testDB = reopenTestTendermintWALStore(t, walStore, testDB, dbPath) + + require.NoError(t, walStore.DeleteWALEntries(firstHeight)) + require.NoError(t, walStore.Flush()) + + walStore, testDB = reopenTestTendermintWALStore(t, walStore, testDB, dbPath) + assertLoadedEntries(t, walStore, secondHeightEntries, thirdHeightEntries) +} + +func TestWALCloseFlushesPendingEntries(t *testing.T) { + committedHeight := types.Height(1000) + pendingHeight := committedHeight + 1 + committedEntries := makeWALEntriesForHeight(committedHeight) + pendingEntries := makeWALEntriesForHeight(pendingHeight) + + walStore, testDB, dbPath := newTestTendermintWALStore(t) + writeAndFlushEntries(t, walStore, committedEntries) + for _, entry := range pendingEntries { + require.NoError(t, walStore.SetWALEntry(entry)) + } + + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + + walStore, testDB = openTestTendermintWALStore(t, dbPath) + defer func() { + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + }() + + assertLoadedEntries( + t, + walStore, + committedEntries, + pendingEntries, + ) +} + +func TestPruningAllHeightsRemovesWALFiles(t *testing.T) { + // Must match 
cleanupPruneRecordInterval to trigger WAL file cleanup. + const heightsToDelete = 256 + + firstHeight := types.Height(1000) + entryBatches := make([][]starknet.WALEntry, heightsToDelete) + for i := range entryBatches { + entryBatches[i] = makeWALEntriesForHeight(firstHeight + types.Height(i)) + } + + walStore, testDB, dbPath := newTestTendermintWALStore(t) + defer func() { + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + }() + + writeAndFlushEntries(t, walStore, entryBatches...) + + for i := range entryBatches { + require.NoError(t, walStore.DeleteWALEntries(firstHeight+types.Height(i))) + require.NoError(t, walStore.Flush()) + } + + assertNoEntries(t, walStore) + assertNoWALLogFiles(t, dbPath) + + walStore, testDB = reopenTestTendermintWALStore(t, walStore, testDB, dbPath) + writeAndFlushEntries(t, walStore, makeWALEntriesForHeight(firstHeight)) + assertNoEntries(t, walStore) +} + +func TestWALReopenRejectsWritesAtPrunedHeight(t *testing.T) { + prunedHeight := types.Height(1000) + retainedHeight := prunedHeight + 1 + retainedEntries := makeWALEntriesForHeight(retainedHeight) + + walStore, testDB, dbPath := newTestTendermintWALStore(t) + defer func() { + require.NoError(t, walStore.Close()) + require.NoError(t, testDB.Close()) + }() + + writeAndFlushEntries(t, walStore, retainedEntries) + require.NoError(t, walStore.DeleteWALEntries(prunedHeight)) + require.NoError(t, walStore.Flush()) + + walStore, testDB = reopenTestTendermintWALStore(t, walStore, testDB, dbPath) + writeAndFlushEntries(t, walStore, makeWALEntriesForHeight(prunedHeight)) + + assertLoadedEntries(t, walStore, retainedEntries) +} + +func assertNoWALLogFiles(t *testing.T, dbPath string) { + t.Helper() + + dirEntries, err := os.ReadDir(walstore.DefaultWALDir(dbPath)) + require.NoError(t, err) + for _, entry := range dirEntries { + require.False(t, strings.HasSuffix(entry.Name(), ".log"), entry.Name()) + } +} diff --git a/consensus/walstore/wal_writer.go 
b/consensus/walstore/wal_writer.go new file mode 100644 index 0000000000..3d4846332c --- /dev/null +++ b/consensus/walstore/wal_writer.go @@ -0,0 +1,255 @@ +package walstore + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "sync" + + "github.com/cockroachdb/pebble/v2/record" + pebblewal "github.com/cockroachdb/pebble/v2/wal" +) + +type walAppendResult struct { + walNum pebblewal.NumWAL + committed bool +} + +type noopEventListener struct{} + +func (noopEventListener) LogCreated(pebblewal.CreateInfo) {} + +type walWriter struct { + manager pebblewal.Manager + dir string + + nextWALNum pebblewal.NumWAL + + writer pebblewal.Writer + currentWALNum pebblewal.NumWAL + + currentWALSyncedOffset int64 + repairRequired bool +} + +func newWALWriter(manager pebblewal.Manager, dir string, nextWALNum pebblewal.NumWAL) *walWriter { + return &walWriter{ + manager: manager, + dir: dir, + nextWALNum: nextWALNum, + } +} + +func nextWALNum(logs pebblewal.Logs) pebblewal.NumWAL { + next := pebblewal.NumWAL(initialWALNum) + for _, log := range logs { + if candidate := log.Num + 1; candidate > next { + next = candidate + } + } + return next +} + +func (w *walWriter) appendSync(encodedBatch []byte) (walAppendResult, error) { + walNum, writer, err := w.ensureWriter() + if err != nil { + return walAppendResult{}, err + } + + var ( + waitGroup sync.WaitGroup + syncErr error + ) + waitGroup.Add(1) + + logicalOffset, err := writer.WriteRecord(encodedBatch, pebblewal.SyncOptions{ + Done: &waitGroup, + Err: &syncErr, + }, nil) + if err != nil { + abortErr := w.abortUncommitted() + return walAppendResult{}, errors.Join(fmt.Errorf("Flush: write WAL record: %w", err), abortErr) + } + + waitGroup.Wait() + if syncErr != nil { + abortErr := w.abortUncommitted() + return walAppendResult{}, errors.Join(fmt.Errorf("Flush: sync WAL record: %w", syncErr), abortErr) + } + + w.currentWALSyncedOffset = logicalOffset + return walAppendResult{walNum: walNum, committed: true}, nil +} + +func (w 
*walWriter) rotateAfterSynced() error { + return w.closeAndRepairCurrent(w.currentWALSyncedOffset, false) +} + +func (w *walWriter) abortUncommitted() error { + // Failed writes/syncs may leave trailing WAL bytes, so always repair. + return w.closeAndRepairCurrent(w.currentWALSyncedOffset, true) +} + +func (w *walWriter) close() error { + return w.closeAndRepairCurrent(w.currentWALSyncedOffset, false) +} + +func (w *walWriter) ensureWriter() (pebblewal.NumWAL, pebblewal.Writer, error) { + if w.repairRequired { + return 0, nil, errors.New( + "Flush: previous WAL tail repair failed; refusing to create a new WAL writer", + ) + } + if w.writer != nil { + return w.currentWALNum, w.writer, nil + } + + writer, err := w.manager.Create(w.nextWALNum, 0) + if err != nil { + return 0, nil, fmt.Errorf("Flush: create WAL writer: %w", err) + } + + w.currentWALNum = w.nextWALNum + w.nextWALNum++ + w.writer = writer + w.currentWALSyncedOffset = 0 + return w.currentWALNum, w.writer, nil +} + +func (w *walWriter) closeAndRepairCurrent(repairOffset int64, forceRepair bool) error { + if w.writer == nil { + return nil + } + + walNum := w.currentWALNum + closeErr := w.closeCurrent() + + if closeErr == nil && !forceRepair { + w.repairRequired = false + return nil + } + + walPath := filepath.Join(w.dir, walNum.String()+".log") + repairErr := repairWALTail(walPath, repairOffset) + if repairErr != nil { + w.repairRequired = true + return errors.Join(closeErr, repairErr) + } + w.repairRequired = false + return closeErr +} + +func (w *walWriter) closeCurrent() error { + if w.writer == nil { + return nil + } + + _, err := w.writer.Close() + w.writer = nil + w.currentWALNum = 0 + w.currentWALSyncedOffset = 0 + if err != nil { + return fmt.Errorf("Flush: close WAL writer: %w", err) + } + return nil +} + +func (w *walWriter) minLiveWALNum() pebblewal.NumWAL { + minLiveWALNum := w.nextWALNum + if w.writer != nil && w.currentWALNum < minLiveWALNum { + minLiveWALNum = w.currentWALNum + } + return 
minLiveWALNum +} + +func (w *walWriter) obsolete(minLiveWALNum pebblewal.NumWAL) ([]pebblewal.DeletableLog, error) { + if w.manager == nil { + return nil, nil + } + return w.manager.Obsolete(minLiveWALNum, true) +} + +func (w *walWriter) closeManager() error { + if w.manager == nil { + return nil + } + return w.manager.Close() +} + +func repairWALTail(walPath string, syncedOffset int64) (err error) { + file, openErr := os.OpenFile(walPath, os.O_RDWR, 0) + if errors.Is(openErr, os.ErrNotExist) { + return nil + } + if openErr != nil { + return fmt.Errorf("Flush: open WAL for repair: %w", openErr) + } + defer func() { + if closeErr := file.Close(); closeErr != nil { + err = errors.Join(err, fmt.Errorf("Flush: close WAL after repair: %w", closeErr)) + } + }() + + info, err := file.Stat() + if err != nil { + return fmt.Errorf("Flush: stat WAL for repair: %w", err) + } + if syncedOffset > info.Size() { + syncedOffset = info.Size() + } + + if err := file.Truncate(syncedOffset); err != nil { + return fmt.Errorf("Flush: truncate WAL to last synced offset: %w", err) + } + if err := file.Sync(); err != nil { + return fmt.Errorf("Flush: sync truncated WAL: %w", err) + } + return nil +} + +// Scan the latest WAL and truncate any bytes after the last valid record. +func recoverLatestWALTail(logs pebblewal.Logs) error { + if len(logs) == 0 { + return nil + } + + latest := logs[len(logs)-1] + reader := latest.OpenForRead() + defer reader.Close() + + for { + _, offset, err := reader.NextRecord() + switch { + case err == nil: + continue + case errors.Is(err, os.ErrNotExist): + return nil + case errors.Is(err, io.EOF): + return repairWALTailIfLonger(offset.PhysicalFile, offset.Physical) + case record.IsInvalidRecord(err): + return repairWALTail(offset.PhysicalFile, offset.Physical) + default: + return fmt.Errorf("recoverLatestWALTail: read WAL %s: %w", latest.Num, err) + } + } +} + +// EOF can still leave trailing garbage beyond the last valid record offset. 
+func repairWALTailIfLonger(walPath string, syncedOffset int64) error { + if walPath == "" { + return nil + } + info, err := os.Stat(walPath) + if errors.Is(err, os.ErrNotExist) { + return nil + } + if err != nil { + return fmt.Errorf("Flush: stat WAL for repair: %w", err) + } + if info.Size() <= syncedOffset { + return nil + } + return repairWALTail(walPath, syncedOffset) +} diff --git a/consensus/walstore/wal_writer_failure_test.go b/consensus/walstore/wal_writer_failure_test.go new file mode 100644 index 0000000000..5e2122f645 --- /dev/null +++ b/consensus/walstore/wal_writer_failure_test.go @@ -0,0 +1,302 @@ +package walstore + +// These tests assert durability/error-handling behavior, but use package walstore +// so they can inject WAL writer failures that are not reachable through the +// public WAL store API. +// +// Invariants covered here: +// - failed append/sync aborts to the previous committed offset and remains retryable +// - a committed prune record stays applied even if later rotate/close fails +// - ambiguous repair state blocks future WAL writers + +import ( + "errors" + "io" + "os" + "path/filepath" + "testing" + + "github.com/NethermindEth/juno/consensus/starknet" + "github.com/NethermindEth/juno/consensus/types" + consensuswal "github.com/NethermindEth/juno/consensus/types/wal" + "github.com/cockroachdb/pebble/v2/record" + pebblewal "github.com/cockroachdb/pebble/v2/wal" + "github.com/stretchr/testify/require" +) + +func TestFlushKeepsExistingEntriesWhenPruneSyncFails(t *testing.T) { + const walNum = pebblewal.NumWAL(1) + + walDir := t.TempDir() + walPath := walPathFor(walDir, walNum) + syncErr := errors.New("fsync failed") + + walStore := newTestWALStore( + walDir, + &fakeWALWriter{path: walPath, syncErr: syncErr}, + walNum, + 0, + ) + existingEntry := seedExistingEntry(walStore, 1) + require.NoError(t, walStore.DeleteWALEntries(1)) + + err := walStore.Flush() + require.ErrorIs(t, err, syncErr) + requireEntries(t, walStore, existingEntry) +} + 
+func TestAbortUncommittedRepairsWALTail(t *testing.T) { + const walNum = pebblewal.NumWAL(2) + + walDir := t.TempDir() + walPath := walPathFor(walDir, walNum) + // Keep a valid committed prefix so repair must truncate only the bad tail. + committedOffset := appendPhysicalRecord(t, walPath, emptyBatchPayload(t)) + appendPhysicalRecord(t, walPath, []byte("garbage-tail")) + closeErr := errors.New("close failed") + + walStore := newTestWALStore( + walDir, + &fakeWALWriter{path: walPath, closeErr: closeErr}, + walNum, + committedOffset, + ) + + err := walStore.wal.abortUncommitted() + require.ErrorIs(t, err, closeErr) + requireWALSize(t, walPath, committedOffset) +} + +func TestFailedRepairBlocksFutureFlushes(t *testing.T) { + const walNum = pebblewal.NumWAL(3) + + walDir := t.TempDir() + walPath := walPathFor(walDir, walNum) + // walPath normally points to a WAL file. Making it a directory forces repair to fail. + require.NoError(t, os.Mkdir(walPath, 0o755)) + + walStore := newTestWALStore(walDir, &fakeWALWriter{path: walPath}, walNum, 0) + + err := walStore.wal.abortUncommitted() + require.Error(t, err) + start := consensuswal.Start(1) + require.NoError(t, walStore.SetWALEntry(&start)) + err = walStore.Flush() + require.Error(t, err) + requireEntries(t, walStore) +} + +func TestFlushFailureDoesNotExposeEntriesAndCanRetry(t *testing.T) { + tests := map[string]func(string, error) *fakeWALWriter{ + "write fails": func(walPath string, err error) *fakeWALWriter { + return &fakeWALWriter{path: walPath, writeErr: err} + }, + "sync fails": func(walPath string, err error) *fakeWALWriter { + return &fakeWALWriter{path: walPath, syncErr: err} + }, + } + + for name, failingWriter := range tests { + t.Run(name, func(t *testing.T) { + const walNum = pebblewal.NumWAL(4) + + walDir := t.TempDir() + walPath := walPathFor(walDir, walNum) + + start := consensuswal.Start(2) + injectedErr := errors.New(name) + walStore := newTestWALStore( + walDir, + failingWriter(walPath, injectedErr), 
+ walNum, + 0, + ) + require.NoError(t, walStore.SetWALEntry(&start)) + + err := walStore.Flush() + require.ErrorIs(t, err, injectedErr) + requireEntries(t, walStore) + + // Failed flushes close the writer, so install a fresh fake writer for retry. + walStore.wal.writer = &fakeWALWriter{path: walPath} + walStore.wal.currentWALNum = walNum + require.NoError(t, walStore.Flush()) + requireEntries(t, walStore, &start) + }) + } +} + +func TestFlushKeepsEntriesPrunedAfterCloseFailure(t *testing.T) { + const walNum = pebblewal.NumWAL(5) + + walDir := t.TempDir() + walPath := walPathFor(walDir, walNum) + closeErr := errors.New("close failed") + walStore := newTestWALStore( + walDir, + &fakeWALWriter{path: walPath, closeErr: closeErr}, + walNum, + 0, + ) + seedExistingEntry(walStore, 1) + // Force cleanup, which closes the writer after the prune record is committed. + walStore.pruneRecordsSinceCleanup = cleanupPruneRecordInterval - 1 + require.NoError(t, walStore.DeleteWALEntries(1)) + + err := walStore.Flush() + require.ErrorIs(t, err, closeErr) + requireEntries(t, walStore) + + require.NoError(t, walStore.Flush()) + requireEntries(t, walStore) +} + +type ( + testWALStore = tendermintWALStore[starknet.Value, starknet.Hash, starknet.Address] + testWALEntry = consensuswal.Entry[starknet.Value, starknet.Hash, starknet.Address] +) + +type fakeWALWriter struct { + path string + writeErr error + syncErr error + closeErr error +} + +var _ pebblewal.Writer = (*fakeWALWriter)(nil) + +func newTestWALStore( + walDir string, + writer pebblewal.Writer, + walNum pebblewal.NumWAL, + syncedOffset int64, +) *testWALStore { + wal := newWALWriter(nil, walDir, walNum+1) + wal.writer = writer + wal.currentWALNum = walNum + wal.currentWALSyncedOffset = syncedOffset + + return &tendermintWALStore[starknet.Value, starknet.Hash, starknet.Address]{ + wal: wal, + nextBatchSeqNum: initialSeqNum, + entriesByHeight: make(map[types.Height][]testWALEntry), + walFilesByHeight: 
make(map[types.Height]walNumSet), + walHeightRefs: make(map[pebblewal.NumWAL]int), + } +} + +func (w *fakeWALWriter) WriteRecord( + p []byte, + opts pebblewal.SyncOptions, + _ pebblewal.RefCount, +) (int64, error) { + if w.writeErr != nil { + return 0, w.writeErr + } + + offset, err := appendPhysicalWALRecord(w.path, p) + if err != nil { + return 0, err + } + + if opts.Err != nil { + *opts.Err = w.syncErr + } + + if opts.Done != nil { + opts.Done.Done() + } + return offset, nil +} + +func (w *fakeWALWriter) Close() (int64, error) { + return 0, w.closeErr +} + +func (w *fakeWALWriter) Metrics() record.LogWriterMetrics { + return record.LogWriterMetrics{} +} + +func walPathFor(walDir string, walNum pebblewal.NumWAL) string { + return filepath.Join(walDir, walNum.String()+".log") +} + +func appendPhysicalWALRecord(path string, payload []byte) (int64, error) { + file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644) + if err != nil { + return 0, err + } + defer file.Close() + + if _, err := file.Seek(0, io.SeekEnd); err != nil { + return 0, err + } + + writer := record.NewWriter(file) + offset, err := writer.WriteRecord(payload) + if closeErr := writer.Close(); closeErr != nil { + err = errors.Join(err, closeErr) + } + if syncErr := file.Sync(); syncErr != nil { + err = errors.Join(err, syncErr) + } + return offset, err +} + +func appendPhysicalRecord(t *testing.T, path string, payload []byte) int64 { + t.Helper() + + offset, err := appendPhysicalWALRecord(path, payload) + require.NoError(t, err) + return offset +} + +func emptyBatchPayload(t *testing.T) []byte { + t.Helper() + + payload, err := encodeBatch[starknet.Value, starknet.Hash, starknet.Address]( + nil, initialSeqNum, nil, + ) + require.NoError(t, err) + return payload +} + +func requireWALSize(t *testing.T, walPath string, size int64) { + t.Helper() + + info, err := os.Stat(walPath) + require.NoError(t, err) + require.Equal(t, size, info.Size()) +} + +func loadAllEntries( + t *testing.T, + db 
*testWALStore, +) []testWALEntry { + t.Helper() + + var entries []testWALEntry + for entry, err := range db.LoadAllEntries() { + require.NoError(t, err) + entries = append(entries, entry) + } + return entries +} + +func requireEntries( + t *testing.T, + db *testWALStore, + want ...testWALEntry, +) { + t.Helper() + require.Equal(t, want, loadAllEntries(t, db)) +} + +func seedExistingEntry( + db *testWALStore, + height types.Height, +) testWALEntry { + start := consensuswal.Start(height) + db.entriesByHeight[height] = []testWALEntry{&start} + return &start +} diff --git a/db/database.go b/db/database.go index 0379f91281..7ce383b2f3 100644 --- a/db/database.go +++ b/db/database.go @@ -46,6 +46,9 @@ type Helper interface { // remove this once the metrics are refactored // Returns the underlying database Impl() any + // Returns a local filesystem path for consensus WAL files. + // Empty means the DB has no local path. + Path() string } // Represents a key-value data store that can handle different operations diff --git a/db/memory/db.go b/db/memory/db.go index 2b87ef93c1..56db7fbeb8 100644 --- a/db/memory/db.go +++ b/db/memory/db.go @@ -2,6 +2,7 @@ package memory import ( "errors" + "os" "slices" "sort" "strings" @@ -22,8 +23,10 @@ var _ db.KeyValueStore = (*Database)(nil) // Represents an in-memory key-value store. // It is thread-safe. 
type Database struct {
-	db   map[string][]byte
-	lock sync.RWMutex
+	db                map[string][]byte
+	path              string // set by the first successful Path call; empty until then
+	removePathOnClose bool   // true once Path has created the directory, so Close removes it
+	lock              sync.RWMutex
 }
 
 func New() *Database {
@@ -89,9 +92,39 @@ func (d *Database) Close() error {
 	defer d.lock.Unlock()
 
 	d.db = nil
+	if d.removePathOnClose && d.path != "" {
+		// Best-effort cleanup of the directory created by Path; the error is
+		// deliberately dropped. NOTE(review): confirm Close should not report it.
+		_ = os.RemoveAll(d.path)
+	}
 	return nil
 }
 
+// Path returns a local directory associated with this in-memory database.
+// The directory is created with os.MkdirTemp on the first call and the same
+// value is returned afterwards; Close removes it. An empty string is
+// returned when the database is closed or the directory cannot be created.
+func (d *Database) Path() string {
+	// Write lock: the first call mutates d.path and d.removePathOnClose.
+	d.lock.Lock()
+	defer d.lock.Unlock()
+
+	if d.db == nil {
+		return ""
+	}
+	if d.path != "" {
+		return d.path
+	}
+
+	path, err := os.MkdirTemp("", "juno-memory-db-*")
+	if err != nil {
+		return ""
+	}
+	d.path = path
+	d.removePathOnClose = true
+	return d.path
+}
+
 func (d *Database) DeleteRange(start, end []byte) error {
 	d.lock.Lock()
 	defer d.lock.Unlock()
diff --git a/db/memory/db_test.go b/db/memory/db_test.go
index 3f6233ab5b..91ba94fadc 100644
--- a/db/memory/db_test.go
+++ b/db/memory/db_test.go
@@ -1,15 +1,46 @@
-package memory
+package memory_test
 
 import (
 	"testing"
 
 	"github.com/NethermindEth/juno/db"
+	"github.com/NethermindEth/juno/db/memory"
+	"github.com/stretchr/testify/require"
 )
 
 func TestMemoryDB(t *testing.T) {
 	t.Run("test suite", func(t *testing.T) {
 		db.TestKeyValueStoreSuite(t, func() db.KeyValueStore {
-			return New()
+			return memory.New()
 		})
 	})
 }
+
+// TestMemoryDBPathLifecycle checks that Path creates a directory, keeps
+// returning the same one, and that Close removes it and disables Path.
+func TestMemoryDBPathLifecycle(t *testing.T) {
+	memoryDB := memory.New()
+
+	path := memoryDB.Path()
+	require.NotEmpty(t, path)
+	require.DirExists(t, path)
+	require.Equal(t, path, memoryDB.Path())
+
+	require.NoError(t, memoryDB.Close())
+	require.NoDirExists(t, path)
+	require.Empty(t, memoryDB.Path())
+}
+
+// TestMemoryDBCopyGetsIndependentPath checks that a copy reports its own
+// non-empty path rather than the path of the database it was copied from.
+func TestMemoryDBCopyGetsIndependentPath(t *testing.T) {
+	memoryDB := memory.New()
+
+	copyDB := memoryDB.Copy()
+	copyPath := copyDB.Path()
+	require.NotEmpty(t, copyPath)
+	require.NotEqual(t, memoryDB.Path(), copyPath)
+
+	require.NoError(t, memoryDB.Close())
+	require.NoError(t, copyDB.Close())
+}
diff --git a/db/pebble/db.go b/db/pebble/db.go
index 74d5244158..ba9a14e202 100644
--- a/db/pebble/db.go
+++ b/db/pebble/db.go
@@ -20,6 +20,7 @@ var _ db.KeyValueStore = (*DB)(nil)
 
 type DB struct {
 	db        *pebble.DB
+	path      string
 	closed    bool
 	writeOpt  *pebble.WriteOptions
 	listener  db.EventListener
@@ -42,12 +43,18 @@ func New(path string, options ...Option) (db.KeyValueStore, error) {
 
 	return &DB{
 		db:        pDB,
+		path:      path,
 		closeLock: new(sync.RWMutex),
 		listener:  &db.SelectiveListener{},
 		writeOpt:  &pebble.WriteOptions{Sync: true}, // TODO: can we use non-sync writes for performance?
 	}, nil
 }
 
+// Path returns the path that was passed to New when this database was opened.
+func (d *DB) Path() string {
+	return d.path
+}
+
 func (d *DB) Close() error {
 	d.closeLock.Lock()
 	defer d.closeLock.Unlock()
diff --git a/db/pebblev2/db.go b/db/pebblev2/db.go
index 9d2c9fb9a8..690b4ad78d 100644
--- a/db/pebblev2/db.go
+++ b/db/pebblev2/db.go
@@ -20,6 +20,7 @@ var _ db.KeyValueStore = (*DB)(nil)
 
 type DB struct {
 	db        *pebble.DB
+	path      string
 	closed    bool
 	writeOpt  *pebble.WriteOptions
 	listener  db.EventListener
@@ -49,12 +50,18 @@ func New(path string, options ...Option) (db.KeyValueStore, error) {
 
 	return &DB{
 		db:        pDB,
+		path:      path,
 		closeLock: new(sync.RWMutex),
 		listener:  &db.SelectiveListener{},
 		writeOpt:  &pebble.WriteOptions{Sync: true}, // TODO: can we use non-sync writes for performance?
 	}, nil
 }
 
+// Path returns the path that was passed to New when this database was opened.
+func (d *DB) Path() string {
+	return d.path
+}
+
 func (d *DB) Close() error {
 	d.closeLock.Lock()
 	defer d.closeLock.Unlock()
diff --git a/db/remote/db.go b/db/remote/db.go
index 0e3052edba..7a4fc1d69e 100644
--- a/db/remote/db.go
+++ b/db/remote/db.go
@@ -46,6 +46,12 @@ func New(
 	}, nil
 }
 
+// Path always returns the empty string.
+func (d *DB) Path() string {
+	// Remote DB has no local filesystem path.
+	return ""
+}
+
 func (d *DB) NewTransaction(write bool) (*transaction, error) {
 	defer d.listener.OnIO(write, time.Now())
 
diff --git a/sync/data_source.go b/sync/data_source.go
index 9f1ae8310f..1ff6131aab 100644
--- a/sync/data_source.go
+++ b/sync/data_source.go
@@ -17,7 +17,7 @@ type CommittedBlock struct {
 	Block       *core.Block
 	StateUpdate *core.StateUpdate
 	NewClasses  map[felt.Felt]core.ClassDefinition
-	Persisted   chan struct{} // This is used to signal that the block has been persisted
+	Persisted   chan error // Receives nil once the block has been persisted, or the error that prevented it
 }
 
 type DataSource interface {
@@ -59,7 +59,7 @@ func (f *feederGatewayDataSource) BlockByNumber(ctx context.Context, blockNumber
 		Block:       block,
 		StateUpdate: stateUpdate,
 		NewClasses:  newClasses,
-		Persisted:   make(chan struct{}),
+		Persisted:   make(chan error, 1),
 	}, nil
 }
 
diff --git a/sync/reorg_test.go b/sync/reorg_test.go
index 8465fcf796..82fbb64cdc 100644
--- a/sync/reorg_test.go
+++ b/sync/reorg_test.go
@@ -86,7 +86,7 @@ func (t *testBlockDataSource) getBlocks() []sync.CommittedBlock {
 
 func getBlock(blocks []sync.CommittedBlock, blockNumber uint64) sync.CommittedBlock {
 	committedBlock := blocks[blockNumber]
-	committedBlock.Persisted = make(chan struct{})
+	committedBlock.Persisted = make(chan error, 1)
 	return committedBlock
 }
 
diff --git a/sync/sync.go b/sync/sync.go
index 5730ada612..8b5f7ad41c 100644
--- a/sync/sync.go
+++ b/sync/sync.go
@@ -321,7 +321,7 @@ func (s *Synchronizer) verifierTask(
 	)
 	if err != nil {
 		return func() {
-			defer close(committedBlock.Persisted)
+			committedBlock.Persisted <- err
 			s.logger.Warn(
 				"Sanity checks failed",
 				zap.Uint64("number", committedBlock.Block.Number),
@@ -344,9 +344,9 @@ func (s *Synchronizer) storeTask(
 	resetStreams context.CancelFunc,
 	commitments *core.BlockCommitments,
 ) {
-	defer close(committedBlock.Persisted)
 	select {
 	case <-ctx.Done():
+		committedBlock.Persisted <- ctx.Err()
 		return
 	default:
 	}
@@ -356,6 +356,7 @@ func (s *Synchronizer) storeTask(
 	stateUpdate := committedBlock.StateUpdate
 	newClasses := committedBlock.NewClasses
 	if err := s.blockchain.Store(block, commitments, stateUpdate, newClasses); err != nil {
+		committedBlock.Persisted <- err
 		if errors.Is(err, blockchain.ErrParentDoesNotMatchHead) {
 			// Block block.Number - 1 is the parent of this block which doesn't match
 			// so we need to revert the head to block.Number - 2
@@ -368,6 +369,7 @@ func (s *Synchronizer) storeTask(
 		resetStreams()
 		return
 	}
+	committedBlock.Persisted <- nil
 
 	s.listener.OnSyncStepDone(OpStore, block.Number, time.Since(storeTimer))