Skip to content

Commit 97c112c

Browse files
committed
chore: code cleanups
1 parent 2b58353 commit 97c112c

11 files changed

Lines changed: 97 additions & 77 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ test: clean-testcache rustdeps ## Run tests
7474
go test $(GO_TAGS) ./...
7575

7676
test-new-state: clean-testcache rustdeps ## Run tests with new state
77-
USE_NEW_STATE=true go test $(GO_TAGS) ./...
77+
JUNO_NEW_STATE=true go test $(GO_TAGS) ./...
7878

7979
test-cached: rustdeps ## Run cached tests
8080
go test $(GO_TAGS) ./...

blockchain/blockchain.go

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ type Blockchain struct {
8686
cachedFilters *AggregatedBloomFilterCache
8787
runningFilter *core.RunningEventFilter
8888
transactionLayout core.TransactionLayout
89-
StateFactory *statefactory.StateFactory
89+
stateFactory *statefactory.StateFactory
9090
}
9191

9292
func New(database db.KeyValueStore, network *utils.Network, stateVersion bool) *Blockchain {
@@ -115,7 +115,7 @@ func New(database db.KeyValueStore, network *utils.Network, stateVersion bool) *
115115
cachedFilters: &cachedFilters,
116116
runningFilter: runningFilter,
117117
transactionLayout: core.TransactionLayoutPerTx, // default to per-tx for backward compatibility
118-
StateFactory: stateFactory,
118+
stateFactory: stateFactory,
119119
}
120120
}
121121

@@ -322,7 +322,7 @@ func (b *Blockchain) Store(
322322
) error {
323323
// old state
324324
// TODO(maksymmalick): remove this once we have a new state implementation
325-
if !b.StateFactory.UseNewState() {
325+
if !b.stateFactory.UseNewState() {
326326
return b.deprecatedStore(block, blockCommitments, stateUpdate, newClasses)
327327
}
328328

@@ -406,7 +406,7 @@ func (b *Blockchain) store(
406406
return err
407407
}
408408

409-
st, err := b.StateFactory.NewState(stateUpdate.OldRoot, nil, batch)
409+
st, err := b.stateFactory.NewState(stateUpdate.OldRoot, nil, batch)
410410
if err != nil {
411411
return err
412412
}
@@ -667,9 +667,9 @@ func (b *Blockchain) HeadState() (core.StateReader, StateCloser, error) {
667667
return nil, nil, err
668668
}
669669

670-
state, err := b.StateFactory.NewState(header.GlobalStateRoot, txn, nil)
670+
st, err := b.stateFactory.NewStateReader(header.GlobalStateRoot, txn)
671671

672-
return state, noopStateCloser, err
672+
return st, noopStateCloser, err
673673
}
674674

675675
// StateAtBlockNumber returns a StateReader that provides
@@ -685,12 +685,6 @@ func (b *Blockchain) StateAtBlockNumber(
685685
return nil, nil, err
686686
}
687687

688-
if !b.StateFactory.UseNewState() {
689-
return core.NewDeprecatedStateHistory(
690-
core.NewDeprecatedState(txn), blockNumber,
691-
), noopStateCloser, nil
692-
}
693-
694688
height, err := core.GetChainHeight(txn)
695689
if err != nil {
696690
return nil, nil, err
@@ -701,11 +695,11 @@ func (b *Blockchain) StateAtBlockNumber(
701695
return nil, nil, err
702696
}
703697

704-
history, err := state.NewStateHistory(blockNumber, header.GlobalStateRoot, b.stateDB)
698+
st, err := b.stateFactory.NewStateHistory(header.GlobalStateRoot, txn, blockNumber)
705699
if err != nil {
706700
return nil, nil, err
707701
}
708-
return &history, noopStateCloser, nil
702+
return st, noopStateCloser, nil
709703
}
710704

711705
// StateAtBlockHash returns a StateReader that provides
@@ -715,7 +709,7 @@ func (b *Blockchain) StateAtBlockHash(
715709
) (core.StateReader, StateCloser, error) {
716710
b.listener.OnRead("StateAtBlockHash")
717711
if blockHash.IsZero() {
718-
emptyState, err := b.StateFactory.EmptyState()
712+
emptyState, err := b.stateFactory.EmptyState()
719713
return emptyState, noopStateCloser, err
720714
}
721715

@@ -724,13 +718,6 @@ func (b *Blockchain) StateAtBlockHash(
724718
if err != nil {
725719
return nil, nil, err
726720
}
727-
if !b.StateFactory.UseNewState() {
728-
return core.NewDeprecatedStateHistory(
729-
core.NewDeprecatedState(txn),
730-
header.Number,
731-
), noopStateCloser, nil
732-
}
733-
734721
height, err := core.GetChainHeight(txn)
735722
if err != nil {
736723
return nil, nil, err
@@ -741,11 +728,11 @@ func (b *Blockchain) StateAtBlockHash(
741728
return nil, nil, err
742729
}
743730

744-
history, err := state.NewStateHistory(header.Number, headHeader.GlobalStateRoot, b.stateDB)
731+
st, err := b.stateFactory.NewStateHistory(headHeader.GlobalStateRoot, txn, header.Number)
745732
if err != nil {
746733
return nil, nil, err
747734
}
748-
return &history, noopStateCloser, nil
735+
return st, noopStateCloser, nil
749736
}
750737

751738
// EventFilter returns an EventFilter object that is tied to a snapshot of the blockchain
@@ -775,14 +762,14 @@ func (b *Blockchain) EventFilter(
775762

776763
// RevertHead reverts the head block
777764
func (b *Blockchain) RevertHead() error {
778-
if !b.StateFactory.UseNewState() {
765+
if !b.stateFactory.UseNewState() {
779766
return b.database.Update(b.deprecatedRevertHead)
780767
}
781768
return b.database.Write(b.revertHead)
782769
}
783770

784771
func (b *Blockchain) GetReverseStateDiff() (core.StateDiff, error) {
785-
if !b.StateFactory.UseNewState() {
772+
if !b.stateFactory.UseNewState() {
786773
return b.deprecatedGetReverseStateDiff()
787774
}
788775

@@ -822,12 +809,12 @@ func (b *Blockchain) getReverseStateDiff() (core.StateDiff, error) {
822809
if err != nil {
823810
return ret, err
824811
}
825-
state, err := state.New(stateUpdate.NewRoot, b.stateDB, nil)
812+
st, err := state.NewStateReader(stateUpdate.NewRoot, b.stateDB)
826813
if err != nil {
827814
return ret, err
828815
}
829816

830-
return state.GetReverseStateDiff(blockNum, stateUpdate.StateDiff)
817+
return st.GetReverseStateDiff(blockNum, stateUpdate.StateDiff)
831818
}
832819

833820
func (b *Blockchain) deprecatedRevertHead(txn db.IndexedBatch) error {
@@ -972,7 +959,14 @@ func (b *Blockchain) Simulate(
972959
txn := b.database.NewIndexedBatch()
973960
defer txn.Close()
974961

975-
if err := b.updateStateRoots(txn, nil, block, stateUpdate, newClasses); err != nil {
962+
// For the new state path, create a temporary batch that is intentionally never
963+
// committed — writes accumulate in memory and are discarded after simulation.
964+
var batch db.Batch
965+
if b.stateFactory.UseNewState() {
966+
batch = b.database.NewBatch()
967+
}
968+
969+
if err := b.updateStateRoots(txn, batch, block, stateUpdate, newClasses); err != nil {
976970
return SimulateResult{}, err
977971
}
978972

@@ -1005,7 +999,7 @@ func (b *Blockchain) Finalise(
1005999
newClasses map[felt.Felt]core.ClassDefinition,
10061000
sign utils.BlockSignFunc,
10071001
) error {
1008-
if !b.StateFactory.UseNewState() {
1002+
if !b.stateFactory.UseNewState() {
10091003
err := b.database.Update(func(txn db.IndexedBatch) error {
10101004
if err := b.updateStateRoots(txn, nil, block, stateUpdate, newClasses); err != nil {
10111005
return err
@@ -1101,7 +1095,7 @@ func (b *Blockchain) updateStateRoots(
11011095
stateRoot = &felt.Zero
11021096
}
11031097

1104-
state, err := b.StateFactory.NewState(stateRoot, txn, batch)
1098+
state, err := b.stateFactory.NewState(stateRoot, txn, batch)
11051099
if err != nil {
11061100
return err
11071101
}

blockchain/blockchain_test.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,11 +1070,8 @@ func chainStateCommitment(testDB db.KeyValueStore) (felt.Felt, error) {
10701070
}
10711071
stateDB := state.NewStateDB(testDB, trieDB)
10721072
sf := statefactory.NewStateFactory(statetestutils.UseNewState(), trieDB, stateDB)
1073-
var txn db.IndexedBatch
1074-
if !sf.UseNewState() {
1075-
txn = testDB.NewIndexedBatch()
1076-
}
1077-
st, err := sf.NewState(header.GlobalStateRoot, txn, nil)
1073+
txn := testDB.NewIndexedBatch()
1074+
st, err := sf.NewStateReader(header.GlobalStateRoot, txn)
10781075
if err != nil {
10791076
return felt.Felt{}, err
10801077
}

core/pending_state.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ func (p *PendingState) ContractClassHash(addr *felt.Felt) (felt.Felt, error) {
4242
} else if classHash, ok = p.stateDiff.DeployedContracts[*addr]; ok {
4343
return *classHash, nil
4444
}
45-
classHash, err := p.head.ContractClassHash(addr)
46-
return classHash, err
45+
return p.head.ContractClassHash(addr)
4746
}
4847

4948
func (p *PendingState) ContractNonce(addr *felt.Felt) (felt.Felt, error) {
@@ -52,8 +51,7 @@ func (p *PendingState) ContractNonce(addr *felt.Felt) (felt.Felt, error) {
5251
} else if _, found = p.stateDiff.DeployedContracts[*addr]; found {
5352
return felt.Felt{}, nil
5453
}
55-
nonce, err := p.head.ContractNonce(addr)
56-
return nonce, err
54+
return p.head.ContractNonce(addr)
5755
}
5856

5957
func (p *PendingState) ContractStorage(addr, key *felt.Felt) (felt.Felt, error) {
@@ -65,8 +63,7 @@ func (p *PendingState) ContractStorage(addr, key *felt.Felt) (felt.Felt, error)
6563
if _, found := p.stateDiff.DeployedContracts[*addr]; found {
6664
return felt.Felt{}, nil
6765
}
68-
value, err := p.head.ContractStorage(addr, key)
69-
return value, err
66+
return p.head.ContractStorage(addr, key)
7067
}
7168

7269
// ContractStorageLastUpdatedBlock returns the most recent block number at which a given storage

core/state/errors.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,5 @@ var (
88
ErrContractNotDeployed = errors.New("contract not deployed")
99
ErrContractAlreadyDeployed = errors.New("contract already deployed")
1010
ErrNoHistoryValue = errors.New("no history value found")
11-
ErrCheckHeadState = errors.New("check head state")
1211
ErrHistoricalTrieNotSupported = errors.New("cannot support historical trie")
1312
)

core/state/history.go

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package state
22

33
import (
4-
"errors"
5-
64
"github.com/NethermindEth/juno/core"
75
"github.com/NethermindEth/juno/core/felt"
86
"github.com/NethermindEth/juno/db"
@@ -30,41 +28,32 @@ func NewStateHistory(blockNum uint64, stateRoot *felt.Felt, db *StateDB) (stateH
3028

3129
func (s *stateHistory) ContractClassHash(addr *felt.Felt) (felt.Felt, error) {
3230
if err := s.checkDeployed(addr); err != nil {
33-
return felt.Felt{}, err
31+
return felt.Zero, err
3432
}
3533
ret, err := s.state.ContractClassHashAt(addr, s.blockNum)
3634
if err != nil {
37-
if errors.Is(err, ErrCheckHeadState) {
38-
return s.state.ContractClassHash(addr)
39-
}
4035
return felt.Zero, err
4136
}
4237
return ret, nil
4338
}
4439

4540
func (s *stateHistory) ContractNonce(addr *felt.Felt) (felt.Felt, error) {
4641
if err := s.checkDeployed(addr); err != nil {
47-
return felt.Felt{}, err
42+
return felt.Zero, err
4843
}
4944
ret, err := s.state.ContractNonceAt(addr, s.blockNum)
5045
if err != nil {
51-
if errors.Is(err, ErrCheckHeadState) {
52-
return s.state.ContractNonce(addr)
53-
}
5446
return felt.Zero, err
5547
}
5648
return ret, nil
5749
}
5850

5951
func (s *stateHistory) ContractStorage(addr, key *felt.Felt) (felt.Felt, error) {
6052
if err := s.checkDeployed(addr); err != nil {
61-
return felt.Felt{}, err
53+
return felt.Zero, err
6254
}
6355
ret, err := s.state.ContractStorageAt(addr, key, s.blockNum)
6456
if err != nil {
65-
if errors.Is(err, ErrCheckHeadState) {
66-
return s.state.ContractStorage(addr, key)
67-
}
6857
return felt.Zero, err
6958
}
7059
return ret, nil

core/state/state.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -356,16 +356,11 @@ func (s *State) Revert(header *core.Header, update *core.StateUpdate) error {
356356
if !newComm.Equal(update.OldRoot) {
357357
return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", update.OldRoot, &newComm)
358358
}
359-
if s.batch != nil {
360-
if err := s.flush(blockNum, &stateUpdate, dirtyClasses, false); err != nil {
361-
return err
362-
}
363-
if err := s.deleteHistory(blockNum, update.StateDiff); err != nil {
364-
return err
365-
}
359+
if err := s.flush(blockNum, &stateUpdate, dirtyClasses, false); err != nil {
360+
return err
366361
}
367362

368-
return nil
363+
return s.deleteHistory(blockNum, update.StateDiff)
369364
}
370365

371366
func (s *State) GetReverseStateDiff(blockNum uint64, diff *core.StateDiff) (core.StateDiff, error) {

core/state/statefactory/state_factory.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func NewStateFactory(
3232
}
3333

3434
func (sf *StateFactory) NewState(
35-
// todo: this should be *felt.StateRootHash
35+
// TODO: this should be *felt.StateRootHash
3636
stateRoot *felt.Felt,
3737
txn db.IndexedBatch,
3838
batch db.Batch,
@@ -49,7 +49,7 @@ func (sf *StateFactory) NewState(
4949
return state, nil
5050
}
5151

52-
func (sf *StateFactory) NewStateReader(
52+
func (sf *StateFactory) NewStateHistory(
5353
stateRoot *felt.Felt,
5454
txn db.IndexedBatch,
5555
blockNumber uint64,
@@ -81,6 +81,20 @@ func (sf *StateFactory) EmptyState() (core.StateReader, error) {
8181
return state, nil
8282
}
8383

84+
// NewStateReader returns a read-only view of the state at the given root.
85+
// Use this for operations that only need to read from the current (head) state,
86+
// such as HeadState or computing reverse diffs.
87+
// Returns core.State (not just core.StateReader) so callers can access Commitment.
88+
func (sf *StateFactory) NewStateReader(
89+
stateRoot *felt.Felt,
90+
txn db.IndexedBatch,
91+
) (core.State, error) {
92+
if !sf.useNewState {
93+
return core.NewDeprecatedState(txn), nil
94+
}
95+
return state.NewStateReader(stateRoot, sf.stateDB)
96+
}
97+
8498
func (sf *StateFactory) UseNewState() bool {
8599
return sf.useNewState
86100
}

core/state/statefactory/state_factory_test.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,39 @@ func TestStateFactory_NewState(t *testing.T) {
4646
})
4747
}
4848

49+
func TestStateFactory_NewStateHistory(t *testing.T) {
50+
t.Run("deprecated", func(t *testing.T) {
51+
sf := newTestFactory(t, false)
52+
txn := memory.New().NewIndexedBatch()
53+
54+
reader, err := sf.NewStateHistory(&felt.Zero, txn, 0)
55+
require.NoError(t, err)
56+
assert.NotNil(t, reader)
57+
})
58+
59+
t.Run("new impl", func(t *testing.T) {
60+
sf := newTestFactory(t, true)
61+
62+
reader, err := sf.NewStateHistory(&felt.Zero, nil, 0)
63+
require.NoError(t, err)
64+
assert.NotNil(t, reader)
65+
})
66+
}
67+
4968
func TestStateFactory_NewStateReader(t *testing.T) {
5069
t.Run("deprecated", func(t *testing.T) {
5170
sf := newTestFactory(t, false)
5271
txn := memory.New().NewIndexedBatch()
5372

54-
reader, err := sf.NewStateReader(&felt.Zero, txn, 0)
73+
reader, err := sf.NewStateReader(&felt.Zero, txn)
5574
require.NoError(t, err)
5675
assert.NotNil(t, reader)
5776
})
5877

5978
t.Run("new impl", func(t *testing.T) {
6079
sf := newTestFactory(t, true)
6180

62-
reader, err := sf.NewStateReader(&felt.Zero, nil, 0)
81+
reader, err := sf.NewStateReader(&felt.Zero, nil)
6382
require.NoError(t, err)
6483
assert.NotNil(t, reader)
6584
})

0 commit comments

Comments
 (0)