Skip to content

Commit 99e49fa

Browse files
feat: add R3/v1 router replay deserialization support (#450)
* initial commit * refactor: drop num_moe_layers/top_k from r3/v1 deserializer Mirrors the gateway-side r3_serializer change: the per-token matrix shape (num_moe_layers, top_k) is no longer required and is no longer written into the r3/v1 binary header. Per-token matrix byte size is recovered as matrix_byte_length / replayed_token_count. - HEADER_FORMAT: "<4sBBBBIIHHIIQ" (36 bytes) -> "<4sBBBBIIIIQ" (32 bytes). - Drop num_moe_layers/top_k from _parse_header() and the metadata dict returned by decompress_and_parse_r3(). - Compute matrix_elem_size from matrix_byte_length / replayed_token_count with a divisibility check that surfaces malformed payloads early. - Update unit tests to use matrix_elem_size as the parameter and drop assertions on the removed header fields; round-trip test no longer passes num_moe_layers/top_k to RouterReplayData. Co-authored-by: Cursor <cursoragent@cursor.com> * test * fix: drop arbitrary 20x cap on r3/v1 decompressed size ZstdCompressor.compress() (used by the gateway-side r3_serializer) embeds the uncompressed size in the frame header, so passing max_output_size=len(compressed)*20 was both unnecessary and incorrect: highly compressible router-replay payloads (e.g. tokens routing to a small subset of experts) routinely exceed a 20:1 ratio, and would have failed deserialization with ZstdError. Removing the cap lets the library auto-allocate from the embedded content size. Verified locally: a 64 KiB zero-filled matrix payload compresses to ~35 bytes (>1800x ratio) and now deserializes cleanly. Adds a regression test covering the high-compression case. Co-authored-by: Cursor <cursoragent@cursor.com> * fix: do not construct IntEnum for unknown dtype/selector_mode _RoutingDtype(int) and _SelectorMode(int) raise ValueError for any value not in the enum, so the .get() fallback was unreachable: a future routing_dtype=3 in the header would crash metadata construction before str(int) could run. Look up names by raw int instead — IntEnum keys hash-equal their int values, so known modes resolve to their lowercase name and unknown ones fall back to str(int) without ever constructing the enum. Adds a regression test exercising routing_dtype=99. Co-authored-by: Cursor <cursoragent@cursor.com> * chore: drop unused _RoutingDtype.byte_width decompress_and_parse_r3 now derives matrix_elem_size from matrix_byte_length / replayed_token_count, so the dtype's per-element byte width is no longer referenced anywhere. Removing dead code. Co-authored-by: Cursor <cursoragent@cursor.com> * simplify 1 aspect * early return --------- Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 251ed86 commit 99e49fa

8 files changed

Lines changed: 727 additions & 13 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,29 @@ def convert_trace_dict_to_evaluation_row(
100100
):
101101
break # Break early if we've found all the metadata we need
102102

103+
# Extract router replay payloads when present
104+
payloads = trace.get("payloads")
105+
if isinstance(payloads, dict):
106+
router_replay = payloads.get("router_replay")
107+
if isinstance(router_replay, dict) and router_replay.get("data"):
108+
try:
109+
from .r3_deserializer import decompress_and_parse_r3
110+
111+
matrices, r3_meta = decompress_and_parse_r3(router_replay["data"])
112+
if execution_metadata.extra is None:
113+
execution_metadata.extra = {}
114+
execution_metadata.extra["routing_matrices"] = matrices
115+
execution_metadata.extra["routing_metadata"] = r3_meta
116+
except Exception as e:
117+
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)
118+
103119
return EvaluationRow(
104120
messages=messages,
105121
tools=tools,
106122
input_metadata=InputMetadata(
107123
row_id=row_id,
108124
session_data={
109-
"langfuse_trace_id": trace.get("id"), # Store the trace ID here
125+
"langfuse_trace_id": trace.get("id"),
110126
},
111127
),
112128
execution_metadata=execution_metadata,
@@ -426,6 +442,7 @@ def get_evaluation_rows(
426442
max_retries: int = 3,
427443
span_name: Optional[str] = None,
428444
converter: Optional[TraceDictConverter] = None,
445+
include_payloads: bool = False,
429446
) -> List[EvaluationRow]:
430447
"""Pull traces from Langfuse via proxy and convert to EvaluationRow format.
431448
@@ -449,6 +466,8 @@ def get_evaluation_rows(
449466
max_retries: Max retry attempts used by proxy (default: 3)
450467
converter: Optional custom converter implementing TraceDictConverter protocol.
451468
If provided, this will be used instead of the default conversion logic.
469+
include_payloads: If True, request payload data (e.g., router replay)
470+
from the gateway and decompress it into the returned EvaluationRows.
452471
453472
Returns:
454473
List[EvaluationRow]: Converted evaluation rows
@@ -479,6 +498,7 @@ def get_evaluation_rows(
479498
"to_timestamp": to_timestamp.isoformat() if to_timestamp else None,
480499
"sleep_between_gets": sleep_between_gets,
481500
"max_retries": max_retries,
501+
"include_payloads": include_payloads if include_payloads else None,
482502
}
483503

484504
# Remove None values
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
"""R3/v1 binary deserializer for router-replay payloads.
2+
3+
Implements the inverse of the packed binary format produced by the tracing
4+
gateway's ``r3_serializer.serialize_r3``. See that module for the full
5+
header specification.
6+
7+
The main entry point is :func:`decompress_and_parse_r3`, which accepts the
8+
base64-encoded compressed blob returned by the gateway's
9+
``/v1/traces/pointwise?include_payloads=true`` endpoint and produces
10+
per-token routing matrices in the same ``List[Optional[str]]`` format used
11+
by the direct inference path (``DeploymentSampler.sample_with_tokens()``).
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import base64
17+
import struct
18+
from enum import IntEnum
19+
from typing import Any, Dict, List, Optional, Tuple
20+
21+
import zstandard as zstd
22+
23+
MAGIC = b"R3V1"
24+
HEADER_FORMAT = "<4sBBBBIIIIQ"
25+
HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes
26+
BITS_PER_BYTE = 8
27+
28+
29+
class _SelectorMode(IntEnum):
30+
ALL = 0
31+
SUFFIX = 1
32+
BITMAP = 2
33+
34+
35+
class _RoutingDtype(IntEnum):
36+
UINT8 = 1
37+
UINT16 = 2
38+
39+
40+
_SELECTOR_MODE_NAMES = {v: v.name.lower() for v in _SelectorMode}
41+
_ROUTING_DTYPE_NAMES = {v: v.name.lower() for v in _RoutingDtype}
42+
43+
44+
def _parse_header(raw: bytes) -> Dict[str, Any]:
45+
if len(raw) < HEADER_SIZE:
46+
raise ValueError(
47+
f"Payload too short for r3/v1 header: {len(raw)} < {HEADER_SIZE}"
48+
)
49+
50+
(
51+
magic,
52+
version,
53+
selector_mode,
54+
routing_dtype,
55+
flags,
56+
total_token_count,
57+
replayed_token_count,
58+
replay_start_token,
59+
selector_byte_length,
60+
matrix_byte_length,
61+
) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE])
62+
63+
if magic != MAGIC:
64+
raise ValueError(f"Bad R3 magic: {magic!r}")
65+
if version != 1:
66+
raise ValueError(f"Unsupported R3 header version: {version}")
67+
68+
return {
69+
"selector_mode": selector_mode,
70+
"routing_dtype": routing_dtype,
71+
"flags": flags,
72+
"total_token_count": total_token_count,
73+
"replayed_token_count": replayed_token_count,
74+
"replay_start_token": replay_start_token,
75+
"selector_byte_length": selector_byte_length,
76+
"matrix_byte_length": matrix_byte_length,
77+
}
78+
79+
80+
def _read_bitmap_positions(
81+
selector_bytes: bytes, total_token_count: int
82+
) -> List[int]:
83+
"""Return sorted token indices where the bitmap bit is set."""
84+
positions: List[int] = []
85+
for i in range(total_token_count):
86+
byte_idx = i // BITS_PER_BYTE
87+
bit_idx = i % BITS_PER_BYTE
88+
if byte_idx < len(selector_bytes) and (selector_bytes[byte_idx] >> bit_idx) & 1:
89+
positions.append(i)
90+
return positions
91+
92+
93+
def decompress_and_parse_r3(
94+
data_b64: str,
95+
) -> Tuple[List[Optional[str]], Dict[str, Any]]:
96+
"""Decompress and unpack an R3/v1 payload into per-token routing matrices.
97+
98+
Args:
99+
data_b64: Base64-encoded zstd-compressed R3 binary blob, as returned
100+
by the tracing gateway in ``payloads.router_replay.data``.
101+
102+
Returns:
103+
A tuple of ``(routing_matrices, metadata)`` where:
104+
105+
- ``routing_matrices`` is a ``List[Optional[str]]`` of length
106+
``total_token_count``. Each present position contains a
107+
base64-encoded routing matrix (matching the format returned by
108+
the direct inference path); absent positions are ``None``.
109+
- ``metadata`` is a dict with keys ``routing_dtype``,
110+
``selector_mode``, ``total_token_count``, ``replayed_token_count``,
111+
``replay_start_token``.
112+
"""
113+
compressed = base64.b64decode(data_b64)
114+
115+
# ZstdCompressor.compress() embeds the uncompressed size in the frame
116+
# header by default, so the library can auto-allocate the output buffer.
117+
decompressor = zstd.ZstdDecompressor()
118+
raw = decompressor.decompress(compressed)
119+
120+
header = _parse_header(raw)
121+
122+
selector_mode = header["selector_mode"]
123+
routing_dtype = header["routing_dtype"]
124+
total_token_count = header["total_token_count"]
125+
replayed_token_count = header["replayed_token_count"]
126+
replay_start_token = header["replay_start_token"]
127+
selector_byte_length = header["selector_byte_length"]
128+
matrix_byte_length = header["matrix_byte_length"]
129+
130+
metadata: Dict[str, Any] = {
131+
"routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)),
132+
"selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)),
133+
"total_token_count": total_token_count,
134+
"replayed_token_count": replayed_token_count,
135+
"replay_start_token": replay_start_token,
136+
}
137+
138+
if replayed_token_count == 0:
139+
return [None] * total_token_count, metadata
140+
141+
# Per-token matrix byte size is implicit in the payload: all replayed
142+
# tokens share the same matrix length, so we can recover it from the
143+
# matrix section total length divided by the replayed-token count.
144+
if matrix_byte_length % replayed_token_count != 0:
145+
raise ValueError(
146+
f"matrix_byte_length ({matrix_byte_length}) is not a multiple of "
147+
f"replayed_token_count ({replayed_token_count}); cannot split "
148+
"into per-token matrices"
149+
)
150+
matrix_elem_size = matrix_byte_length // replayed_token_count
151+
152+
body = raw[HEADER_SIZE:]
153+
expected_body_length = selector_byte_length + matrix_byte_length
154+
if len(body) < expected_body_length:
155+
raise ValueError(
156+
f"Payload body too short for selector and matrix sections: "
157+
f"{len(body)} < {expected_body_length}"
158+
)
159+
160+
selector_bytes = body[:selector_byte_length]
161+
matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length]
162+
163+
if selector_mode == _SelectorMode.ALL:
164+
replayed_positions = list(range(total_token_count))
165+
elif selector_mode == _SelectorMode.SUFFIX:
166+
replayed_positions = list(
167+
range(replay_start_token, replay_start_token + replayed_token_count)
168+
)
169+
elif selector_mode == _SelectorMode.BITMAP:
170+
replayed_positions = _read_bitmap_positions(selector_bytes, total_token_count)
171+
else:
172+
raise ValueError(f"Unknown selector_mode: {selector_mode}")
173+
174+
if len(replayed_positions) != replayed_token_count:
175+
raise ValueError(
176+
f"Selector produced {len(replayed_positions)} replayed positions, "
177+
f"but header replayed_token_count is {replayed_token_count}"
178+
)
179+
180+
# Split matrix bytes into per-token chunks and base64-encode each one
181+
matrices: List[Optional[str]] = [None] * total_token_count
182+
for idx, pos in enumerate(replayed_positions):
183+
start = idx * matrix_elem_size
184+
end = start + matrix_elem_size
185+
matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii")
186+
187+
return matrices, metadata

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ def __init__(
3535
model_base_url: str = "https://tracing.fireworks.ai",
3636
poll_interval: float = 1.0,
3737
timeout_seconds: float = 120.0,
38+
include_payloads: bool = False,
3839
):
3940
# Prefer constructor-provided configuration. These can be overridden via
4041
# config.kwargs at call time for backward compatibility.
4142
self._remote_base_url = remote_base_url
4243
self._model_base_url = model_base_url
44+
self._include_payloads = include_payloads
4345
if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"):
4446
self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL")
4547
_ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
@@ -194,7 +196,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
194196
row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
195197

196198
def _update_with_trace() -> None:
197-
return update_row_with_remote_trace(row, default_fireworks_output_data_loader, model_base_url)
199+
return update_row_with_remote_trace(
200+
row, default_fireworks_output_data_loader, model_base_url,
201+
include_payloads=self._include_payloads,
202+
)
198203

199204
await asyncio.to_thread(_update_with_trace) # Update row with remote trace in-place
200205
return row

eval_protocol/pytest/tracing_utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ def fetch_traces() -> List[EvaluationRow]:
2222
# Use EP_REMOTE_API_KEY for fetching remote traces, falling back to FIREWORKS_API_KEY
2323
api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY")
2424
adapter = FireworksTracingAdapter(base_url=base_url, api_key=api_key)
25-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
25+
return adapter.get_evaluation_rows(
26+
tags=[f"rollout_id:{config.rollout_id}"],
27+
max_retries=5,
28+
include_payloads=config.include_payloads,
29+
)
2630

2731
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
2832

@@ -129,7 +133,7 @@ def build_init_request(
129133

130134
# Build final model base URL with tracing metadata
131135
final_model_base_url = model_base_url
132-
if model_base_url and ("tracing.fireworks.ai" in model_base_url or model_base_url.startswith("http://localhost")):
136+
if model_base_url and ("tracing.fireworks.ai" in model_base_url or model_base_url.startswith("http://localhost") or "litellm-gateway" in model_base_url):
133137
final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)
134138

135139
# Extract API key from environment or completion_params
@@ -148,13 +152,20 @@ def build_init_request(
148152

149153

150154
def update_row_with_remote_trace(
151-
row: EvaluationRow, output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader], model_base_url: str
155+
row: EvaluationRow,
156+
output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader],
157+
model_base_url: str,
158+
include_payloads: bool = False,
152159
) -> None:
153160
"""Update row with remote trace data using output_data_loader (shared logic)."""
154161
if not row.execution_metadata.rollout_id:
155162
return None
156163

157-
loader_config = DataLoaderConfig(rollout_id=row.execution_metadata.rollout_id, model_base_url=model_base_url)
164+
loader_config = DataLoaderConfig(
165+
rollout_id=row.execution_metadata.rollout_id,
166+
model_base_url=model_base_url,
167+
include_payloads=include_payloads,
168+
)
158169
data_loader = output_data_loader(loader_config)
159170
results = data_loader.load()
160171
output_rows: List[EvaluationRow] = [r for result in results for r in result.rows]

eval_protocol/types/remote_rollout_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class DataLoaderConfig(BaseModel):
3939

4040
rollout_id: str
4141
model_base_url: Optional[str] = None
42+
include_payloads: bool = False
4243

4344

4445
class InitRequest(BaseModel):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"deepdiff>=6.0.0",
4949
"websockets>=15.0.1",
5050
"fastapi>=0.116.1",
51+
"zstandard>=0.19.0",
5152
]
5253

5354
[project.urls]

0 commit comments

Comments
 (0)