Skip to content

Commit d532ab4

Browse files
CopiloteddyashtonCopilot
authored
Add Python deserialisation for Merkle tree mini-trees (#7676)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: eddyashton <6000239+eddyashton@users.noreply.github.com> Co-authored-by: Eddy Ashton <edashton@microsoft.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent c4f6692 commit d532ab4

2 files changed

Lines changed: 169 additions & 29 deletions

File tree

python/src/ccf/merkletree.py

Lines changed: 125 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the Apache 2.0 License.
33

44
from hashlib import sha256
5-
import math
5+
import struct
66

77

88
class MerkleTree:
@@ -16,6 +16,7 @@ def __init__(self):
1616
def reset_tree(self):
    """Return the tree to its empty state.

    After this call there is a single, empty leaf level, no stored root,
    and the flushed-leaf counter is zero.
    """
    self._num_flushed = 0
    self._root = None
    self._levels = [[]]
1920

2021
def add_leaf(self, values: bytes, do_hash=True):
2122
digest = values
@@ -45,46 +46,141 @@ def get_merkle_root(self) -> bytes:
4546

4647
return self._root
4748

48-
def _recalculate_level(self, level):
49-
assert len(self._levels) > level - 1
50-
prev_level = self._levels[level - 1]
51-
number_of_leaves_on_prev_level = len(prev_level)
52-
53-
assert (
54-
number_of_leaves_on_prev_level > 1
55-
), "Merkle Tree should have more than one leaf at every level"
49+
def _recalculate_level(self, prev_level, current_level):
    """
    Derive one level of the tree from the level directly beneath it.

    Hashes that are already present in ``current_level`` are kept, except
    for the final entry, which is always dropped and rebuilt because it may
    be an unpaired node that was carried up unchanged on an earlier pass.

    Args:
        prev_level: Hashes of the lower level, in order
        current_level: Hashes previously computed for this level; mutated
            in place

    Returns:
        ``current_level`` itself, now holding one hash per pair of
        ``prev_level`` entries plus, when ``prev_level`` has odd length,
        its last entry carried up as-is
    """
    # The trailing entry cannot be trusted (see docstring), so discard it.
    if current_level:
        del current_level[-1]

    # Every surviving entry stands for one already-hashed pair below.
    reusable_pairs = len(current_level)

    total = len(prev_level)
    has_unpaired = total % 2 == 1
    paired_end = total - 1 if has_unpaired else total
    unpaired = prev_level[-1] if has_unpaired else None

    # Hash the remaining (left, right) pairs, skipping those already done.
    current_level.extend(
        sha256(prev_level[i] + prev_level[i + 1]).digest()
        for i in range(reusable_pairs * 2, paired_end, 2)
    )

    # An unpaired last node is carried up to this level without hashing.
    if unpaired is not None:
        current_level.append(unpaired)

    return current_level
86+
8687
def _make_tree(self):
    """
    Build (or refresh) every level above the leaves.

    Walks upward from level 0, deriving level ``i + 1`` from level ``i``
    via ``_recalculate_level``. After ``deserialise``, a level ``i`` whose
    bit is set in ``_num_flushed`` starts with the root of a flushed
    subtree; that entry is kept at index 0 and the hashes computed from
    below are placed after it. Does nothing when there are no leaves.
    """
    if self.get_leaf_count() == 0:
        return

    flushed_bits = self._num_flushed
    lower = 0

    while True:
        source = self._levels[lower]

        # Stop once a single node remains and no flushed subtree roots
        # are left to fold in at higher levels.
        if len(source) <= 1 and flushed_bits == 0:
            break

        upper = lower + 1
        if len(self._levels) <= upper:
            self._levels.append([])
        target = self._levels[upper]

        # Bit 1 of the remaining mask belongs to the upper level: if it is
        # set and the level is populated, slot 0 is a flushed hash that
        # must not be popped or recomputed.
        keep = 1 if (flushed_bits >> 1) & 0x01 and target else 0

        rebuilt = self._recalculate_level(source, target[keep:])
        self._levels[upper] = target[:keep] + rebuilt

        flushed_bits >>= 1
        lower = upper
122+
def deserialise(self, buffer: bytes, position: int = 0) -> int:
    """
    Deserialise a compact merkle tree representation.

    The whole input is parsed and validated before the tree is touched, so
    if a ValueError is raised the tree keeps the state it had before the
    call. On success the previous contents are discarded and replaced.

    Format (big-endian):
      [uint64_t] num_leaf_nodes - Count of leaf nodes in this serialisation
      [uint64_t] num_flushed    - Count of flushed (pruned) leaves
      [hash...]  leaf_hashes    - Hash data for leaf nodes (32 bytes each)
      [hash...]  flushed_hashes - Roots of flushed subtrees on the left edge

    Args:
        buffer: The byte buffer containing the serialised tree
        position: Starting position in the buffer (default: 0)

    Returns:
        The new position in the buffer after deserialisation

    Raises:
        ValueError: If position is negative, or the buffer ends before the
            serialised tree is complete
    """
    HASH_SIZE = 32  # SHA-256 hash size

    # A negative offset would make the slices below wrap around to the end
    # of the buffer instead of failing, so reject it up front.
    if position < 0:
        raise ValueError(f"Invalid position: {position}")

    def read_bytes(pos: int, size: int) -> tuple[bytes, int]:
        """Read size bytes from buffer at pos, return (data, new_pos)"""
        if len(buffer) < pos + size:
            raise ValueError(
                f"Buffer too small: need {pos + size} bytes, have {len(buffer)}"
            )
        return buffer[pos : pos + size], pos + size

    # Parse header - big-endian uint64_t values
    uint64_data, position = read_bytes(position, 8)
    num_leaf_nodes = struct.unpack(">Q", uint64_data)[0]

    uint64_data, position = read_bytes(position, 8)
    num_flushed = struct.unpack(">Q", uint64_data)[0]

    # Parse leaf hashes into a local list; nothing on self is modified yet.
    leaf_hashes = []
    for _ in range(num_leaf_nodes):
        leaf_hash, position = read_bytes(position, HASH_SIZE)
        leaf_hashes.append(leaf_hash)

    # Parse flushed subtree hashes. Bit i of num_flushed indicates a flushed
    # subtree of size 2^i, whose root is at level i (for i>0) or a single
    # leaf at level 0 (i=0). Hashes are stored lowest level first.
    flushed_hashes = []  # (level, hash) pairs in ascending level order
    it = num_flushed
    level = 0
    while it != 0:
        if it & 0x01:
            flushed_hash, position = read_bytes(position, HASH_SIZE)
            flushed_hashes.append((level, flushed_hash))
        level += 1
        it >>= 1

    # Everything parsed successfully: commit the new state.
    self.reset_tree()
    self._num_flushed = num_flushed
    self._levels[0].extend(leaf_hashes)

    for level, flushed_hash in flushed_hashes:
        if level == 0:
            # Flushed leaf - sits before the serialised leaves on level 0
            self._levels[0].insert(0, flushed_hash)
        else:
            # Ensure level exists
            while len(self._levels) <= level:
                self._levels.append([])
            # Store flushed hash at its conceptual level
            self._levels[level] = [flushed_hash]

    return position

tests/e2e_operations.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import ccf.read_ledger
3636
import re
3737
import hashlib
38+
from ccf.merkletree import MerkleTree
3839

3940
from loguru import logger as LOG
4041

@@ -2414,11 +2415,54 @@ def run_read_ledger_on_testdata(args):
24142415
committed_only=False,
24152416
read_recovery_files=False,
24162417
)
2418+
2419+
# Validate merkle tree deserialization
2420+
# Maintain a merkle tree by adding leaves, and compare with deserialized trees
2421+
accumulated_tree = MerkleTree()
2422+
trees_validated = 0
2423+
2424+
# Start with a zero-filled, digest-sized byte array. CCF MerkleTree uses this all-zero value as the first leaf of its merkle tree.
2425+
empty_bytes_array = bytearray(ccf.ledger.SHA256_DIGEST_SIZE)
2426+
accumulated_tree.add_leaf(empty_bytes_array, do_hash=False)
2427+
24172428
for chunk in ledger:
24182429
for tx in chunk:
24192430
tables = tx.get_public_domain().get_tables()
24202431
tx_count += 1
2432+
2433+
# Check if this transaction has a serialized merkle tree
2434+
if "public:ccf.internal.tree" in tables:
2435+
tree_table = tables["public:ccf.internal.tree"]
2436+
if ccf.ledger.WELL_KNOWN_SINGLETON_TABLE_KEY in tree_table:
2437+
tree_data = tree_table[
2438+
ccf.ledger.WELL_KNOWN_SINGLETON_TABLE_KEY
2439+
]
2440+
2441+
# Deserialize the tree from the transaction
2442+
deserialized_tree = MerkleTree()
2443+
deserialized_tree.deserialise(tree_data)
2444+
2445+
# Compare roots: the accumulated tree should match the deserialized tree
2446+
accumulated_root = accumulated_tree.get_merkle_root()
2447+
deserialized_root = deserialized_tree.get_merkle_root()
2448+
2449+
if accumulated_root != deserialized_root:
2450+
raise ValueError(
2451+
f"Merkle tree mismatch in {testdata_path} at tx {tx_count}: "
2452+
f"accumulated={accumulated_root.hex() if accumulated_root else 'None'}, "
2453+
f"deserialized={deserialized_root.hex() if deserialized_root else 'None'}"
2454+
)
2455+
2456+
trees_validated += 1
2457+
2458+
# Add transaction to accumulated tree
2459+
# Transaction leaves are the transaction digest
2460+
accumulated_tree.add_leaf(tx.get_tx_digest(), do_hash=False)
2461+
24212462
LOG.info(f"Read {tx_count} transactions from {testdata_path}")
2463+
if trees_validated > 0:
2464+
LOG.info(f"Validated {trees_validated} merkle tree deserializations")
2465+
24222466
snapshot_path = os.path.join(
24232467
args.historical_testdata, testdata_dir.name, "snapshots"
24242468
)

0 commit comments

Comments
 (0)