Skip to content

Commit 36890d8

Browse files
refactor: drop num_moe_layers/top_k from r3/v1 deserializer
Mirrors the gateway-side r3_serializer change: the per-token matrix shape (num_moe_layers, top_k) is no longer required and is no longer written into the r3/v1 binary header. Per-token matrix byte size is recovered as matrix_byte_length / replayed_token_count. - HEADER_FORMAT: "<4sBBBBIIHHIIQ" (36 bytes) -> "<4sBBBBIIIIQ" (32 bytes). - Drop num_moe_layers/top_k from _parse_header() and the metadata dict returned by decompress_and_parse_r3(). - Compute matrix_elem_size from matrix_byte_length / replayed_token_count with a divisibility check that surfaces malformed payloads early. - Update unit tests to use matrix_elem_size as the parameter and drop assertions on the removed header fields; round-trip test no longer passes num_moe_layers/top_k to RouterReplayData. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent e769ac1 commit 36890d8

2 files changed

Lines changed: 35 additions & 57 deletions

File tree

eval_protocol/adapters/r3_deserializer.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import base64
1717
import logging
18-
import math
1918
import struct
2019
from enum import IntEnum
2120
from typing import Any, Dict, List, Optional, Tuple
@@ -25,8 +24,8 @@
2524
logger = logging.getLogger(__name__)
2625

2726
MAGIC = b"R3V1"
28-
HEADER_FORMAT = "<4sBBBBIIHHIIQ"
29-
HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 36 bytes
27+
HEADER_FORMAT = "<4sBBBBIIIIQ"
28+
HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes
3029

3130

3231
class _SelectorMode(IntEnum):
@@ -62,8 +61,6 @@ def _parse_header(raw: bytes) -> Dict[str, Any]:
6261
flags,
6362
total_token_count,
6463
replayed_token_count,
65-
num_moe_layers,
66-
top_k,
6764
replay_start_token,
6865
selector_byte_length,
6966
matrix_byte_length,
@@ -80,8 +77,6 @@ def _parse_header(raw: bytes) -> Dict[str, Any]:
8077
"flags": flags,
8178
"total_token_count": total_token_count,
8279
"replayed_token_count": replayed_token_count,
83-
"num_moe_layers": num_moe_layers,
84-
"top_k": top_k,
8580
"replay_start_token": replay_start_token,
8681
"selector_byte_length": selector_byte_length,
8782
"matrix_byte_length": matrix_byte_length,
@@ -117,9 +112,9 @@ def decompress_and_parse_r3(
117112
``total_token_count``. Each present position contains a
118113
base64-encoded routing matrix (matching the format returned by
119114
the direct inference path); absent positions are ``None``.
120-
- ``metadata`` is a dict with keys ``num_moe_layers``, ``top_k``,
121-
``routing_dtype``, ``selector_mode``, ``total_token_count``,
122-
``replayed_token_count``, ``replay_start_token``.
115+
- ``metadata`` is a dict with keys ``routing_dtype``,
116+
``selector_mode``, ``total_token_count``, ``replayed_token_count``,
117+
``replay_start_token``.
123118
"""
124119
compressed = base64.b64decode(data_b64)
125120

@@ -132,14 +127,23 @@ def decompress_and_parse_r3(
132127
routing_dtype = header["routing_dtype"]
133128
total_token_count = header["total_token_count"]
134129
replayed_token_count = header["replayed_token_count"]
135-
num_moe_layers = header["num_moe_layers"]
136-
top_k = header["top_k"]
137130
replay_start_token = header["replay_start_token"]
138131
selector_byte_length = header["selector_byte_length"]
139132
matrix_byte_length = header["matrix_byte_length"]
140133

141-
dtype_byte_width = _RoutingDtype(routing_dtype).byte_width
142-
matrix_elem_size = num_moe_layers * top_k * dtype_byte_width
134+
# Per-token matrix byte size is implicit in the payload: all replayed
135+
# tokens share the same matrix length, so we can recover it from the
136+
# matrix section total length divided by the replayed-token count.
137+
if replayed_token_count > 0:
138+
if matrix_byte_length % replayed_token_count != 0:
139+
raise ValueError(
140+
f"matrix_byte_length ({matrix_byte_length}) is not a multiple of "
141+
f"replayed_token_count ({replayed_token_count}); cannot split "
142+
"into per-token matrices"
143+
)
144+
matrix_elem_size = matrix_byte_length // replayed_token_count
145+
else:
146+
matrix_elem_size = 0
143147

144148
body = raw[HEADER_SIZE:]
145149
selector_bytes = body[:selector_byte_length]
@@ -173,8 +177,6 @@ def decompress_and_parse_r3(
173177
matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii")
174178

175179
metadata: Dict[str, Any] = {
176-
"num_moe_layers": num_moe_layers,
177-
"top_k": top_k,
178180
"routing_dtype": _ROUTING_DTYPE_NAMES.get(
179181
_RoutingDtype(routing_dtype), str(routing_dtype)
180182
),

tests/adapters/test_r3_deserializer.py

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,20 @@ def _make_raw_r3(
2828
routing_dtype: int = _RoutingDtype.UINT8,
2929
total_token_count: int = 4,
3030
replayed_token_count: int = 4,
31-
num_moe_layers: int = 2,
32-
top_k: int = 2,
31+
matrix_elem_size: Optional[int] = None,
3332
replay_start_token: int = 0,
3433
selector_bytes: bytes = b"",
3534
matrix_data: Optional[bytes] = None,
3635
) -> bytes:
37-
"""Build a raw (uncompressed) R3/v1 payload for testing."""
38-
dtype_byte_width = _RoutingDtype(routing_dtype).byte_width
39-
matrix_elem_size = num_moe_layers * top_k * dtype_byte_width
36+
"""Build a raw (uncompressed) R3/v1 payload for testing.
4037
38+
``matrix_elem_size`` is the per-token matrix byte length; when not given
39+
and no explicit ``matrix_data`` is supplied, defaults to 4 bytes/token
40+
(a minimal placeholder for tests that don't care about shape).
41+
"""
4142
if matrix_data is None:
43+
if matrix_elem_size is None:
44+
matrix_elem_size = 4
4245
matrix_data = bytes(range(matrix_elem_size)) * replayed_token_count
4346

4447
header = struct.pack(
@@ -50,8 +53,6 @@ def _make_raw_r3(
5053
0x01, # flags: little-endian
5154
total_token_count,
5255
replayed_token_count,
53-
num_moe_layers,
54-
top_k,
5556
replay_start_token,
5657
len(selector_bytes),
5758
len(matrix_data),
@@ -86,7 +87,7 @@ def test_too_short(self):
8687
def test_unsupported_version(self):
8788
raw = struct.pack(
8889
HEADER_FORMAT,
89-
MAGIC, 99, 0, 1, 0, 4, 4, 2, 2, 0, 0, 16,
90+
MAGIC, 99, 0, 1, 0, 4, 4, 0, 0, 16,
9091
)
9192
with pytest.raises(ValueError, match="Unsupported R3 header version"):
9293
_parse_header(raw)
@@ -118,10 +119,8 @@ def test_multi_byte(self):
118119

119120
class TestDecompressAndParseR3:
120121
def test_all_mode_uint8(self):
121-
num_moe_layers = 2
122-
top_k = 2
122+
matrix_elem_size = 4 # e.g. 2 MoE layers * 2 top-k * 1 byte (uint8)
123123
total_tokens = 4
124-
matrix_elem_size = num_moe_layers * top_k # 4 bytes per token
125124

126125
matrices_raw = []
127126
for i in range(total_tokens):
@@ -131,17 +130,13 @@ def test_all_mode_uint8(self):
131130
raw = _make_raw_r3(
132131
total_token_count=total_tokens,
133132
replayed_token_count=total_tokens,
134-
num_moe_layers=num_moe_layers,
135-
top_k=top_k,
136133
matrix_data=matrix_data,
137134
)
138135
blob = _compress_and_b64(raw)
139136

140137
matrices, metadata = decompress_and_parse_r3(blob)
141138

142139
assert len(matrices) == total_tokens
143-
assert metadata["num_moe_layers"] == num_moe_layers
144-
assert metadata["top_k"] == top_k
145140
assert metadata["routing_dtype"] == "uint8"
146141
assert metadata["selector_mode"] == "all"
147142
assert metadata["total_token_count"] == total_tokens
@@ -153,12 +148,10 @@ def test_all_mode_uint8(self):
153148
assert decoded == matrices_raw[i]
154149

155150
def test_suffix_mode(self):
156-
num_moe_layers = 2
157-
top_k = 2
151+
matrix_elem_size = 4
158152
total_tokens = 8
159153
replayed = 3
160154
start_token = 5
161-
matrix_elem_size = num_moe_layers * top_k
162155

163156
matrices_raw = []
164157
for i in range(replayed):
@@ -169,8 +162,6 @@ def test_suffix_mode(self):
169162
selector_mode=_SelectorMode.SUFFIX,
170163
total_token_count=total_tokens,
171164
replayed_token_count=replayed,
172-
num_moe_layers=num_moe_layers,
173-
top_k=top_k,
174165
replay_start_token=start_token,
175166
matrix_data=matrix_data,
176167
)
@@ -194,10 +185,8 @@ def test_suffix_mode(self):
194185
assert decoded == matrices_raw[i]
195186

196187
def test_bitmap_mode(self):
197-
num_moe_layers = 2
198-
top_k = 2
188+
matrix_elem_size = 4
199189
total_tokens = 8
200-
matrix_elem_size = num_moe_layers * top_k
201190

202191
# Replay tokens at positions 1, 3, 6
203192
replayed_positions = [1, 3, 6]
@@ -218,8 +207,6 @@ def test_bitmap_mode(self):
218207
selector_mode=_SelectorMode.BITMAP,
219208
total_token_count=total_tokens,
220209
replayed_token_count=replayed,
221-
num_moe_layers=num_moe_layers,
222-
top_k=top_k,
223210
selector_bytes=selector_bytes,
224211
matrix_data=matrix_data,
225212
)
@@ -241,10 +228,8 @@ def test_bitmap_mode(self):
241228
assert matrices[i] is None
242229

243230
def test_uint16_dtype(self):
244-
num_moe_layers = 2
245-
top_k = 2
231+
matrix_elem_size = 8 # e.g. 2 MoE layers * 2 top-k * 2 bytes (uint16)
246232
total_tokens = 2
247-
matrix_elem_size = num_moe_layers * top_k * 2 # 2 bytes per element for uint16
248233

249234
matrices_raw = []
250235
for i in range(total_tokens):
@@ -255,8 +240,6 @@ def test_uint16_dtype(self):
255240
routing_dtype=_RoutingDtype.UINT16,
256241
total_token_count=total_tokens,
257242
replayed_token_count=total_tokens,
258-
num_moe_layers=num_moe_layers,
259-
top_k=top_k,
260243
matrix_data=matrix_data,
261244
)
262245
blob = _compress_and_b64(raw)
@@ -336,8 +319,6 @@ def test_round_trip_with_serializer(self):
336319
data = RouterReplayData(
337320
routing_matrices=original_matrices,
338321
total_token_count=total_tokens,
339-
num_moe_layers=num_moe_layers,
340-
top_k=top_k,
341322
routing_dtype="uint8",
342323
)
343324

@@ -350,8 +331,7 @@ def test_round_trip_with_serializer(self):
350331
matrices, metadata = decompress_and_parse_r3(blob_b64)
351332

352333
assert len(matrices) == total_tokens
353-
assert metadata["num_moe_layers"] == num_moe_layers
354-
assert metadata["top_k"] == top_k
334+
assert metadata["total_token_count"] == total_tokens
355335

356336
for i in range(total_tokens):
357337
if original_b64[i] is None:
@@ -367,10 +347,8 @@ class TestConvertTraceDictWithPayloads:
367347
def test_trace_with_router_replay_payload(self):
368348
from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row
369349

370-
num_moe_layers = 2
371-
top_k = 2
350+
matrix_elem_size = 4
372351
total_tokens = 4
373-
matrix_elem_size = num_moe_layers * top_k
374352

375353
matrices_raw = []
376354
for i in range(total_tokens):
@@ -380,8 +358,6 @@ def test_trace_with_router_replay_payload(self):
380358
raw = _make_raw_r3(
381359
total_token_count=total_tokens,
382360
replayed_token_count=total_tokens,
383-
num_moe_layers=num_moe_layers,
384-
top_k=top_k,
385361
matrix_data=matrix_data,
386362
)
387363
blob = _compress_and_b64(raw)
@@ -424,8 +400,8 @@ def test_trace_with_router_replay_payload(self):
424400
assert decoded == matrices_raw[i]
425401

426402
meta = row.execution_metadata.extra["routing_metadata"]
427-
assert meta["num_moe_layers"] == num_moe_layers
428-
assert meta["top_k"] == top_k
403+
assert meta["routing_dtype"] == "uint8"
404+
assert meta["total_token_count"] == total_tokens
429405

430406
def test_trace_without_payloads(self):
431407
from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row

0 commit comments

Comments
 (0)