@@ -28,17 +28,20 @@ def _make_raw_r3(
2828 routing_dtype : int = _RoutingDtype .UINT8 ,
2929 total_token_count : int = 4 ,
3030 replayed_token_count : int = 4 ,
31- num_moe_layers : int = 2 ,
32- top_k : int = 2 ,
31+ matrix_elem_size : Optional [int ] = None ,
3332 replay_start_token : int = 0 ,
3433 selector_bytes : bytes = b"" ,
3534 matrix_data : Optional [bytes ] = None ,
3635) -> bytes :
37- """Build a raw (uncompressed) R3/v1 payload for testing."""
38- dtype_byte_width = _RoutingDtype (routing_dtype ).byte_width
39- matrix_elem_size = num_moe_layers * top_k * dtype_byte_width
36+ """Build a raw (uncompressed) R3/v1 payload for testing.
4037
38+ ``matrix_elem_size`` is the per-token matrix byte length; when not given
39+ and no explicit ``matrix_data`` is supplied, defaults to 4 bytes/token
40+ (a minimal placeholder for tests that don't care about shape).
41+ """
4142 if matrix_data is None :
43+ if matrix_elem_size is None :
44+ matrix_elem_size = 4
4245 matrix_data = bytes (range (matrix_elem_size )) * replayed_token_count
4346
4447 header = struct .pack (
@@ -50,8 +53,6 @@ def _make_raw_r3(
5053 0x01 , # flags: little-endian
5154 total_token_count ,
5255 replayed_token_count ,
53- num_moe_layers ,
54- top_k ,
5556 replay_start_token ,
5657 len (selector_bytes ),
5758 len (matrix_data ),
@@ -86,7 +87,7 @@ def test_too_short(self):
8687 def test_unsupported_version (self ):
8788 raw = struct .pack (
8889 HEADER_FORMAT ,
89- MAGIC , 99 , 0 , 1 , 0 , 4 , 4 , 2 , 2 , 0 , 0 , 16 ,
90+ MAGIC , 99 , 0 , 1 , 0 , 4 , 4 , 0 , 0 , 16 ,
9091 )
9192 with pytest .raises (ValueError , match = "Unsupported R3 header version" ):
9293 _parse_header (raw )
@@ -118,10 +119,8 @@ def test_multi_byte(self):
118119
119120class TestDecompressAndParseR3 :
120121 def test_all_mode_uint8 (self ):
121- num_moe_layers = 2
122- top_k = 2
122+ matrix_elem_size = 4 # e.g. 2 MoE layers * 2 top-k * 1 byte (uint8)
123123 total_tokens = 4
124- matrix_elem_size = num_moe_layers * top_k # 4 bytes per token
125124
126125 matrices_raw = []
127126 for i in range (total_tokens ):
@@ -131,17 +130,13 @@ def test_all_mode_uint8(self):
131130 raw = _make_raw_r3 (
132131 total_token_count = total_tokens ,
133132 replayed_token_count = total_tokens ,
134- num_moe_layers = num_moe_layers ,
135- top_k = top_k ,
136133 matrix_data = matrix_data ,
137134 )
138135 blob = _compress_and_b64 (raw )
139136
140137 matrices , metadata = decompress_and_parse_r3 (blob )
141138
142139 assert len (matrices ) == total_tokens
143- assert metadata ["num_moe_layers" ] == num_moe_layers
144- assert metadata ["top_k" ] == top_k
145140 assert metadata ["routing_dtype" ] == "uint8"
146141 assert metadata ["selector_mode" ] == "all"
147142 assert metadata ["total_token_count" ] == total_tokens
@@ -153,12 +148,10 @@ def test_all_mode_uint8(self):
153148 assert decoded == matrices_raw [i ]
154149
155150 def test_suffix_mode (self ):
156- num_moe_layers = 2
157- top_k = 2
151+ matrix_elem_size = 4
158152 total_tokens = 8
159153 replayed = 3
160154 start_token = 5
161- matrix_elem_size = num_moe_layers * top_k
162155
163156 matrices_raw = []
164157 for i in range (replayed ):
@@ -169,8 +162,6 @@ def test_suffix_mode(self):
169162 selector_mode = _SelectorMode .SUFFIX ,
170163 total_token_count = total_tokens ,
171164 replayed_token_count = replayed ,
172- num_moe_layers = num_moe_layers ,
173- top_k = top_k ,
174165 replay_start_token = start_token ,
175166 matrix_data = matrix_data ,
176167 )
@@ -194,10 +185,8 @@ def test_suffix_mode(self):
194185 assert decoded == matrices_raw [i ]
195186
196187 def test_bitmap_mode (self ):
197- num_moe_layers = 2
198- top_k = 2
188+ matrix_elem_size = 4
199189 total_tokens = 8
200- matrix_elem_size = num_moe_layers * top_k
201190
202191 # Replay tokens at positions 1, 3, 6
203192 replayed_positions = [1 , 3 , 6 ]
@@ -218,8 +207,6 @@ def test_bitmap_mode(self):
218207 selector_mode = _SelectorMode .BITMAP ,
219208 total_token_count = total_tokens ,
220209 replayed_token_count = replayed ,
221- num_moe_layers = num_moe_layers ,
222- top_k = top_k ,
223210 selector_bytes = selector_bytes ,
224211 matrix_data = matrix_data ,
225212 )
@@ -241,10 +228,8 @@ def test_bitmap_mode(self):
241228 assert matrices [i ] is None
242229
243230 def test_uint16_dtype (self ):
244- num_moe_layers = 2
245- top_k = 2
231+ matrix_elem_size = 8 # e.g. 2 MoE layers * 2 top-k * 2 bytes (uint16)
246232 total_tokens = 2
247- matrix_elem_size = num_moe_layers * top_k * 2 # 2 bytes per element for uint16
248233
249234 matrices_raw = []
250235 for i in range (total_tokens ):
@@ -255,8 +240,6 @@ def test_uint16_dtype(self):
255240 routing_dtype = _RoutingDtype .UINT16 ,
256241 total_token_count = total_tokens ,
257242 replayed_token_count = total_tokens ,
258- num_moe_layers = num_moe_layers ,
259- top_k = top_k ,
260243 matrix_data = matrix_data ,
261244 )
262245 blob = _compress_and_b64 (raw )
@@ -336,8 +319,6 @@ def test_round_trip_with_serializer(self):
336319 data = RouterReplayData (
337320 routing_matrices = original_matrices ,
338321 total_token_count = total_tokens ,
339- num_moe_layers = num_moe_layers ,
340- top_k = top_k ,
341322 routing_dtype = "uint8" ,
342323 )
343324
@@ -350,8 +331,7 @@ def test_round_trip_with_serializer(self):
350331 matrices , metadata = decompress_and_parse_r3 (blob_b64 )
351332
352333 assert len (matrices ) == total_tokens
353- assert metadata ["num_moe_layers" ] == num_moe_layers
354- assert metadata ["top_k" ] == top_k
334+ assert metadata ["total_token_count" ] == total_tokens
355335
356336 for i in range (total_tokens ):
357337 if original_b64 [i ] is None :
@@ -367,10 +347,8 @@ class TestConvertTraceDictWithPayloads:
367347 def test_trace_with_router_replay_payload (self ):
368348 from eval_protocol .adapters .fireworks_tracing import convert_trace_dict_to_evaluation_row
369349
370- num_moe_layers = 2
371- top_k = 2
350+ matrix_elem_size = 4
372351 total_tokens = 4
373- matrix_elem_size = num_moe_layers * top_k
374352
375353 matrices_raw = []
376354 for i in range (total_tokens ):
@@ -380,8 +358,6 @@ def test_trace_with_router_replay_payload(self):
380358 raw = _make_raw_r3 (
381359 total_token_count = total_tokens ,
382360 replayed_token_count = total_tokens ,
383- num_moe_layers = num_moe_layers ,
384- top_k = top_k ,
385361 matrix_data = matrix_data ,
386362 )
387363 blob = _compress_and_b64 (raw )
@@ -424,8 +400,8 @@ def test_trace_with_router_replay_payload(self):
424400 assert decoded == matrices_raw [i ]
425401
426402 meta = row .execution_metadata .extra ["routing_metadata" ]
427- assert meta ["num_moe_layers " ] == num_moe_layers
428- assert meta ["top_k " ] == top_k
403+ assert meta ["routing_dtype " ] == "uint8"
404+ assert meta ["total_token_count " ] == total_tokens
429405
430406 def test_trace_without_payloads (self ):
431407 from eval_protocol .adapters .fireworks_tracing import convert_trace_dict_to_evaluation_row
0 commit comments