22# Licensed under the Apache 2.0 License.
33
44from hashlib import sha256
5- import math
5+ import struct
66
77
88class MerkleTree :
@@ -16,6 +16,7 @@ def __init__(self):
1616 def reset_tree (self ):
1717 self ._levels = [[]]
1818 self ._root = None
19+ self ._num_flushed = 0
1920
2021 def add_leaf (self , values : bytes , do_hash = True ):
2122 digest = values
@@ -45,46 +46,141 @@ def get_merkle_root(self) -> bytes:
4546
4647 return self ._root
4748
48- def _recalculate_level (self , level ):
49- assert len (self ._levels ) > level - 1
50- prev_level = self ._levels [level - 1 ]
51- number_of_leaves_on_prev_level = len (prev_level )
52-
53- assert (
54- number_of_leaves_on_prev_level > 1
55- ), "Merkle Tree should have more than one leaf at every level"
49+ def _recalculate_level (self , prev_level , current_level ):
50+ """
51+ Compute the next level of hashes from the previous level.
52+ Reuses already-computed hashes where possible.
5653
57- solo_leaf = None
54+ Args:
55+ prev_level: List of hashes from the previous (lower) level
56+ current_level: List of already-computed hashes at this level
5857
59- if (
60- number_of_leaves_on_prev_level % 2 == 1
61- ): # if odd number of leaves on the level
62- # Get the solo leaf (last leaf in-case the leaves are odd numbered)
63- solo_leaf = prev_level [- 1 ]
64- number_of_leaves_on_prev_level -= 1
65-
66- if not len (self ._levels ) > level :
67- self ._levels .append ([])
68-
69- # Reuse existing level as much as possible
70- current_level = self ._levels [level ]
71-
72- # Since we may have copied a solo-leaf to the rightmost node last time, pop and re-calculate it
73- if len (current_level ):
58+ Returns:
59+ Updated list of computed hashes for this level
60+ """
61+ # Handle solo leaf: if last entry was a promoted solo, pop it for recalc
62+ if current_level :
7463 current_level .pop (- 1 )
7564
65+ # Determine how many pairs are already computed
7666 done = len (current_level )
7767
68+ # Handle odd count on input level
69+ number_of_leaves_on_prev_level = len (prev_level )
70+ solo_leaf = None
71+ if number_of_leaves_on_prev_level % 2 == 1 :
72+ solo_leaf = prev_level [- 1 ]
73+ number_of_leaves_on_prev_level -= 1
74+
75+ # Compute new pairs starting after 'done' existing pairs
7876 for left_node , right_node in zip (
7977 prev_level [done * 2 : number_of_leaves_on_prev_level : 2 ],
8078 prev_level [done * 2 + 1 : number_of_leaves_on_prev_level : 2 ],
8179 ):
8280 current_level .append (sha256 (left_node + right_node ).digest ())
81+
8382 if solo_leaf is not None :
8483 current_level .append (solo_leaf )
8584
85+ return current_level
86+
8687 def _make_tree (self ):
87- if self .get_leaf_count () > 0 :
88- num_levels = 1 + math .ceil (math .log (self .get_leaf_count (), 2 ))
89- for level in range (1 , num_levels ):
90- self ._recalculate_level (level )
88+ if self .get_leaf_count () == 0 :
89+ return
90+
91+ # Build tree from leaves. After deserialize, _levels[i] contains:
92+ # - Flushed hash at [0] if bit i of num_flushed is set
93+ # - Followed by any previously computed hashes
94+ # We read from _levels[level_idx] and write computed hashes to _levels[level_idx+1].
95+ it = self ._num_flushed
96+ level_idx = 0
97+
98+ while len (self ._levels [level_idx ]) > 1 or it != 0 :
99+ prev_level = self ._levels [level_idx ]
100+
101+ # Ensure next level exists
102+ if level_idx + 1 >= len (self ._levels ):
103+ self ._levels .append ([])
104+
105+ # Check if next level has a flushed hash at [0] that we must preserve
106+ next_level = self ._levels [level_idx + 1 ]
107+ next_has_flushed = (it >> 1 ) & 0x01 and next_level
108+
109+ # Compute next level, reusing hashes after the flushed one
110+ skip = 1 if next_has_flushed else 0
111+ computed = self ._recalculate_level (prev_level , next_level [skip :])
112+
113+ # Store result, preserving flushed hash at [0] if present
114+ if next_has_flushed :
115+ self ._levels [level_idx + 1 ] = [next_level [0 ]] + computed
116+ else :
117+ self ._levels [level_idx + 1 ] = computed
118+
119+ it >>= 1
120+ level_idx += 1
121+
122+ def deserialise (self , buffer : bytes , position : int = 0 ) -> int :
123+ """
124+ Deserialise a compact merkle tree representation.
125+
126+ Format (big-endian):
127+ [uint64_t] num_leaf_nodes - Count of leaf nodes in this serialisation
128+ [uint64_t] num_flushed - Count of flushed (pruned) leaves
129+ [hash...] leaf_hashes - Hash data for leaf nodes (32 bytes each)
130+ [hash...] flushed_hashes - Roots of flushed subtrees on the left edge
131+
132+ Args:
133+ buffer: The byte buffer containing the serialised tree
134+ position: Starting position in the buffer (default: 0)
135+
136+ Returns:
137+ The new position in the buffer after deserialisation
138+ """
139+ HASH_SIZE = 32 # SHA-256 hash size
140+
141+ # Helper function to read bytes and advance position
142+ def read_bytes (pos : int , size : int ) -> tuple [bytes , int ]:
143+ """Read size bytes from buffer at pos, return (data, new_pos)"""
144+ if len (buffer ) < pos + size :
145+ raise ValueError (
146+ f"Buffer too small: need { pos + size } bytes, have { len (buffer )} "
147+ )
148+ return buffer [pos : pos + size ], pos + size
149+
150+ # Reset the tree
151+ self .reset_tree ()
152+
153+ # Parse header - big-endian uint64_t values
154+ uint64_data , position = read_bytes (position , 8 )
155+ num_leaf_nodes = struct .unpack (">Q" , uint64_data )[0 ]
156+
157+ uint64_data , position = read_bytes (position , 8 )
158+ self ._num_flushed = struct .unpack (">Q" , uint64_data )[0 ]
159+
160+ # Read leaf hashes into _levels[0]
161+ for _ in range (num_leaf_nodes ):
162+ leaf_hash , position = read_bytes (position , HASH_SIZE )
163+ self ._levels [0 ].append (leaf_hash )
164+
165+ # Read flushed subtree hashes into their conceptual levels.
166+ # Bit i of num_flushed indicates a flushed subtree of size 2^i,
167+ # whose root is at level i (for i>0) or a single leaf at level 0 (i=0).
168+ it = self ._num_flushed
169+ level = 0
170+
171+ while it != 0 :
172+ if it & 0x01 :
173+ flushed_hash , position = read_bytes (position , HASH_SIZE )
174+ if level == 0 :
175+ # Flushed leaf - insert at beginning of level 0
176+ self ._levels [0 ].insert (0 , flushed_hash )
177+ else :
178+ # Ensure level exists
179+ while len (self ._levels ) <= level :
180+ self ._levels .append ([])
181+ # Store flushed hash at its conceptual level
182+ self ._levels [level ] = [flushed_hash ]
183+ level += 1
184+ it >>= 1
185+
186+ return position
0 commit comments