@@ -6,8 +6,10 @@ import (
66
77 "github.com/NethermindEth/juno/core"
88 "github.com/NethermindEth/juno/core/felt"
9+ "github.com/NethermindEth/juno/core/state"
10+ "github.com/NethermindEth/juno/core/state/statefactory"
11+ "github.com/NethermindEth/juno/core/trie2/triedb"
912 "github.com/NethermindEth/juno/db"
10- "github.com/NethermindEth/juno/db/memory"
1113 "github.com/NethermindEth/juno/feed"
1214 "github.com/NethermindEth/juno/utils"
1315 "github.com/ethereum/go-ethereum/common"
@@ -19,6 +21,7 @@ type L1HeadSubscription struct {
1921
2022//go:generate mockgen -destination=../mocks/mock_blockchain.go -package=mocks github.com/NethermindEth/juno/blockchain Reader
2123type Reader interface {
24+ StateProvider
2225 Height () (height uint64 , err error )
2326
2427 Head () (head * core.Block , err error )
@@ -51,10 +54,6 @@ type Reader interface {
5154 StateUpdateByHash (hash * felt.Felt ) (update * core.StateUpdate , err error )
5255 L1HandlerTxnHash (msgHash * common.Hash ) (l1HandlerTxnHash felt.Felt , err error )
5356
54- HeadState () (core.StateReader , StateCloser , error )
55- StateAtBlockHash (blockHash * felt.Felt ) (core.StateReader , StateCloser , error )
56- StateAtBlockNumber (blockNumber uint64 ) (core.StateReader , StateCloser , error )
57-
5857 BlockCommitmentsByNumber (blockNumber uint64 ) (* core.BlockCommitments , error )
5958
6059 EventFilter (
@@ -66,6 +65,12 @@ type Reader interface {
6665 Network () * utils.Network
6766}
6867
68+ type StateProvider interface {
69+ HeadState () (core.StateReader , StateCloser , error )
70+ StateAtBlockHash (blockHash * felt.Felt ) (core.StateReader , StateCloser , error )
71+ StateAtBlockNumber (blockNumber uint64 ) (core.StateReader , StateCloser , error )
72+ }
73+
6974var ErrParentDoesNotMatchHead = errors .New ("block's parent hash does not match head block hash" )
7075
7176var _ Reader = (* Blockchain )(nil )
@@ -74,14 +79,24 @@ var _ Reader = (*Blockchain)(nil)
7479type Blockchain struct {
7580 network * utils.Network
7681 database db.KeyValueStore
82+ stateDB * state.StateDB
7783 listener EventListener
7884 l1HeadFeed * feed.Feed [* core.L1Head ]
7985 cachedFilters * AggregatedBloomFilterCache
8086 runningFilter * core.RunningEventFilter
8187 transactionLayout core.TransactionLayout
88+ StateFactory * statefactory.StateFactory
8289}
8390
84- func New (database db.KeyValueStore , network * utils.Network ) * Blockchain {
91+ func New (database db.KeyValueStore , network * utils.Network , stateVersion bool ) * Blockchain {
92+ trieDB , err := triedb .New (database , nil ) // TODO: handle hashdb and pathdb
93+ if err != nil {
94+ panic (err )
95+ }
96+ stateDB := state .NewStateDB (database , trieDB )
97+
98+ stateFactory := statefactory .NewStateFactory (stateVersion , trieDB , stateDB )
99+
85100 cachedFilters := NewAggregatedBloomCache (AggregatedBloomFilterCacheSize )
86101 fallback := func (key EventFiltersCacheKey ) (core.AggregatedBloomFilter , error ) {
87102 return core .GetAggregatedBloomFilter (database , key .fromBlock , key .toBlock )
@@ -92,12 +107,14 @@ func New(database db.KeyValueStore, network *utils.Network) *Blockchain {
92107
93108 return & Blockchain {
94109 database : database ,
110+ stateDB : stateDB ,
95111 network : network ,
96112 listener : & SelectiveListener {},
97113 l1HeadFeed : feed .New [* core.L1Head ](),
98114 cachedFilters : & cachedFilters ,
99115 runningFilter : runningFilter ,
100116 transactionLayout : core .TransactionLayoutPerTx , // default to per-tx for backward compatibility
117+ StateFactory : stateFactory ,
101118 }
102119}
103120
@@ -295,8 +312,7 @@ func (b *Blockchain) SubscribeL1Head() L1HeadSubscription {
295312
296313func (b * Blockchain ) L1Head () (core.L1Head , error ) {
297314 b .listener .OnRead ("L1Head" )
298- l1Head , err := core .GetL1Head (b .database )
299- return l1Head , err
315+ return core .GetL1Head (b .database )
300316}
301317
302318func (b * Blockchain ) SetL1Head (update * core.L1Head ) error {
@@ -310,6 +326,21 @@ func (b *Blockchain) Store(
310326 blockCommitments * core.BlockCommitments ,
311327 stateUpdate * core.StateUpdate ,
312328 newClasses map [felt.Felt ]core.ClassDefinition ,
329+ ) error {
330+ // old state
331+ // TODO(maksymmalick): remove this once we have a new state implementation
332+ if ! b .StateFactory .UseNewState () {
333+ return b .deprecatedStore (block , blockCommitments , stateUpdate , newClasses )
334+ }
335+
336+ return b .store (block , blockCommitments , stateUpdate , newClasses )
337+ }
338+
339+ func (b * Blockchain ) deprecatedStore (
340+ block * core.Block ,
341+ blockCommitments * core.BlockCommitments ,
342+ stateUpdate * core.StateUpdate ,
343+ newClasses map [felt.Felt ]core.ClassDefinition ,
313344) error {
314345 err := b .database .Update (func (txn db.IndexedBatch ) error {
315346 if err := verifyBlock (txn , block ); err != nil {
@@ -346,6 +377,7 @@ func (b *Blockchain) Store(
346377 }
347378
348379 err = storeCasmHashMetadata (
380+ b .database ,
349381 txn ,
350382 block .Number ,
351383 block .ProtocolVersion ,
@@ -371,7 +403,8 @@ func (b *Blockchain) Store(
371403// storeCasmHashMetadata stores CASM hash metadata for declared and migrated classes.
372404// See [core.ClassCasmHashMetadata]
373405func storeCasmHashMetadata (
374- txn db.IndexedBatch ,
406+ reader db.KeyValueReader ,
407+ writer db.KeyValueWriter ,
375408 blockNumber uint64 ,
376409 protocolVersion string ,
377410 stateUpdate * core.StateUpdate ,
@@ -385,16 +418,17 @@ func storeCasmHashMetadata(
385418 isV2Protocol := ver .GreaterThanEqual (core .Ver0_14_1 )
386419
387420 if isV2Protocol {
388- return storeCasmHashMetadataV2 (txn , blockNumber , stateUpdate )
421+ return storeCasmHashMetadataV2 (reader , writer , blockNumber , stateUpdate )
389422 }
390423
391- return storeCasmHashMetadataV1 (txn , blockNumber , stateUpdate , newClasses )
424+ return storeCasmHashMetadataV1 (writer , blockNumber , stateUpdate , newClasses )
392425}
393426
394427// storeCasmHashMetadataV2 stores metadata for classes declared with casm hash v2 or
395428// migrated from v1. casm hash v2 is after protocol version >= 0.14.1.
396429func storeCasmHashMetadataV2 (
397- txn db.IndexedBatch ,
430+ reader db.KeyValueReader ,
431+ writer db.KeyValueWriter ,
398432 blockNumber uint64 ,
399433 stateUpdate * core.StateUpdate ,
400434) error {
@@ -404,7 +438,7 @@ func storeCasmHashMetadataV2(
404438 (* felt .CasmClassHash )(casmHash ),
405439 )
406440 err := core .WriteClassCasmHashMetadata (
407- txn ,
441+ writer ,
408442 (* felt .SierraClassHash )(& sierraClassHash ),
409443 & metadata ,
410444 )
@@ -414,7 +448,7 @@ func storeCasmHashMetadataV2(
414448 }
415449
416450 for sierraClassHash := range stateUpdate .StateDiff .MigratedClasses {
417- metadata , err := core .GetClassCasmHashMetadata (txn , & sierraClassHash )
451+ metadata , err := core .GetClassCasmHashMetadata (reader , & sierraClassHash )
418452 if err != nil {
419453 return fmt .Errorf ("cannot migrate class %s: metadata not found" ,
420454 sierraClassHash .String (),
@@ -429,7 +463,7 @@ func storeCasmHashMetadataV2(
429463 )
430464 }
431465
432- err = core .WriteClassCasmHashMetadata (txn , & sierraClassHash , & metadata )
466+ err = core .WriteClassCasmHashMetadata (writer , & sierraClassHash , & metadata )
433467 if err != nil {
434468 return err
435469 }
@@ -440,7 +474,7 @@ func storeCasmHashMetadataV2(
440474// storeDeclaredV1Classes stores metadata for classes declared with V1 hash (protocol < 0.14.1).
441475// It computes the V2 hash from the class definition.
442476func storeCasmHashMetadataV1 (
443- txn db.IndexedBatch ,
477+ writer db.KeyValueWriter ,
444478 blockNumber uint64 ,
445479 stateUpdate * core.StateUpdate ,
446480 newClasses map [felt.Felt ]core.ClassDefinition ,
@@ -469,7 +503,7 @@ func storeCasmHashMetadataV1(
469503
470504 metadata := core .NewCasmHashMetadataDeclaredV1 (blockNumber , casmHashV1 , & casmHashV2 )
471505 err := core .WriteClassCasmHashMetadata (
472- txn ,
506+ writer ,
473507 (* felt .SierraClassHash )(& sierraClassHash ),
474508 & metadata ,
475509 )
@@ -543,7 +577,7 @@ func (b *Blockchain) HeadState() (core.StateReader, StateCloser, error) {
543577 b .listener .OnRead ("HeadState" )
544578 txn := b .database .NewIndexedBatch ()
545579
546- _ , err := core .GetChainHeight (txn )
580+ height , err := core .GetChainHeight (txn )
547581 if err != nil {
548582 return nil , nil , err
549583 }
@@ -564,7 +598,25 @@ func (b *Blockchain) StateAtBlockNumber(
564598 return nil , nil , err
565599 }
566600
567- return core .NewDeprecatedStateHistory (core .NewState (txn ), blockNumber ), noopStateCloser , nil
601+ if ! b .StateFactory .UseNewState () {
602+ return core .NewDeprecatedStateHistory (core .NewState (txn ), blockNumber ), noopStateCloser , nil
603+ }
604+
605+ height , err := core .GetChainHeight (txn )
606+ if err != nil {
607+ return nil , nil , err
608+ }
609+
610+ header , err := core .GetBlockHeaderByNumber (txn , height )
611+ if err != nil {
612+ return nil , nil , err
613+ }
614+
615+ history , err := state .NewStateHistory (blockNumber , header .GlobalStateRoot , b .stateDB )
616+ if err != nil {
617+ return nil , nil , err
618+ }
619+ return & history , noopStateCloser , nil
568620}
569621
570622// StateAtBlockHash returns a StateReader that provides
@@ -574,19 +626,34 @@ func (b *Blockchain) StateAtBlockHash(
574626) (core.StateReader , StateCloser , error ) {
575627 b .listener .OnRead ("StateAtBlockHash" )
576628 if blockHash .IsZero () {
577- memDB := memory .New ()
578- txn := memDB .NewIndexedBatch ()
579- emptyState := core .NewState (txn )
580- return emptyState , noopStateCloser , nil
629+ emptyState , err := b .StateFactory .EmptyState ()
630+ return emptyState , noopStateCloser , err
581631 }
582632
583633 txn := b .database .NewIndexedBatch ()
584634 header , err := core .GetBlockHeaderByHash (txn , blockHash )
585635 if err != nil {
586636 return nil , nil , err
587637 }
638+ if ! b .StateFactory .UseNewState () {
639+ return core .NewDeprecatedStateHistory (core .NewState (txn ), header .Number ), noopStateCloser , nil
640+ }
641+
642+ height , err := core .GetChainHeight (txn )
643+ if err != nil {
644+ return nil , nil , err
645+ }
646+
647+ headHeader , err := core .GetBlockHeaderByNumber (txn , height )
648+ if err != nil {
649+ return nil , nil , err
650+ }
588651
589- return core .NewDeprecatedStateHistory (core .NewState (txn ), header .Number ), noopStateCloser , nil
652+ history , err := state .NewStateHistory (header .Number , headHeader .GlobalStateRoot , b .stateDB )
653+ if err != nil {
654+ return nil , nil , err
655+ }
656+ return & history , noopStateCloser , nil
590657}
591658
592659// EventFilter returns an EventFilter object that is tied to a snapshot of the blockchain
@@ -616,10 +683,22 @@ func (b *Blockchain) EventFilter(
616683
617684// RevertHead reverts the head block
618685func (b * Blockchain ) RevertHead () error {
619- return b .database .Update (b .revertHead )
686+ if ! b .StateFactory .UseNewState () {
687+ return b .database .Update (b .deprecatedRevertHead )
688+ }
689+ return b .database .Write (b .revertHead )
620690}
621691
622692func (b * Blockchain ) GetReverseStateDiff () (core.StateDiff , error ) {
693+ if ! b .StateFactory .UseNewState () {
694+ return b .deprecatedGetReverseStateDiff ()
695+ }
696+
697+ return b .getReverseStateDiff ()
698+ }
699+
700+ // TODO(maksymmalick): remove this once we have a new state integrated
701+ func (b * Blockchain ) deprecatedGetReverseStateDiff () (core.StateDiff , error ) {
623702 txn := b .database .NewIndexedBatch ()
624703 blockNum , err := core .GetChainHeight (txn )
625704 if err != nil {
@@ -675,7 +754,7 @@ func (b *Blockchain) revertHead(txn db.IndexedBatch) error {
675754 }
676755 }
677756
678- if err := b .transactionLayout .DeleteTxsAndReceipts (txn , blockNumber ); err != nil {
757+ if err := b .transactionLayout .DeleteTxsAndReceipts (b . database , txn , blockNumber ); err != nil {
679758 return err
680759 }
681760
@@ -788,6 +867,7 @@ func (b *Blockchain) updateStateRoots(
788867 block * core.Block ,
789868 stateUpdate * core.StateUpdate ,
790869 newClasses map [felt.Felt ]core.ClassDefinition ,
870+ flushChanges bool ,
791871) error {
792872 state := core .NewState (txn )
793873
@@ -799,7 +879,7 @@ func (b *Blockchain) updateStateRoots(
799879 stateUpdate .OldRoot = & oldStateRoot
800880
801881 // Apply state update
802- if err = state .Update (block .Number , stateUpdate , newClasses , true ); err != nil {
882+ if err = state .Update (block .Number , stateUpdate , newClasses , true , flushChanges ); err != nil {
803883 return err
804884 }
805885
@@ -853,7 +933,7 @@ func (b *Blockchain) signBlock(
853933
854934// storeBlockData persists all block-related data to the database
855935func (b * Blockchain ) storeBlockData (
856- txn db.IndexedBatch ,
936+ txn db.KeyValueWriter ,
857937 block * core.Block ,
858938 stateUpdate * core.StateUpdate ,
859939 commitments * core.BlockCommitments ,
0 commit comments