Skip to content

Commit e769ac1

Browse files
initial commit
1 parent 86a52a4 commit e769ac1

8 files changed

Lines changed: 714 additions & 12 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,29 @@ def convert_trace_dict_to_evaluation_row(
100100
):
101101
break # Break early if we've found all the metadata we need
102102

103+
# Extract router replay payloads when present
104+
payloads = trace.get("payloads")
105+
if isinstance(payloads, dict):
106+
router_replay = payloads.get("router_replay")
107+
if isinstance(router_replay, dict) and router_replay.get("data"):
108+
try:
109+
from .r3_deserializer import decompress_and_parse_r3
110+
111+
matrices, r3_meta = decompress_and_parse_r3(router_replay["data"])
112+
if execution_metadata.extra is None:
113+
execution_metadata.extra = {}
114+
execution_metadata.extra["routing_matrices"] = matrices
115+
execution_metadata.extra["routing_metadata"] = r3_meta
116+
except Exception as e:
117+
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)
118+
103119
return EvaluationRow(
104120
messages=messages,
105121
tools=tools,
106122
input_metadata=InputMetadata(
107123
row_id=row_id,
108124
session_data={
109-
"langfuse_trace_id": trace.get("id"), # Store the trace ID here
125+
"langfuse_trace_id": trace.get("id"),
110126
},
111127
),
112128
execution_metadata=execution_metadata,
@@ -426,6 +442,7 @@ def get_evaluation_rows(
426442
max_retries: int = 3,
427443
span_name: Optional[str] = None,
428444
converter: Optional[TraceDictConverter] = None,
445+
include_payloads: bool = False,
429446
) -> List[EvaluationRow]:
430447
"""Pull traces from Langfuse via proxy and convert to EvaluationRow format.
431448
@@ -449,6 +466,8 @@ def get_evaluation_rows(
449466
max_retries: Max retry attempts used by proxy (default: 3)
450467
converter: Optional custom converter implementing TraceDictConverter protocol.
451468
If provided, this will be used instead of the default conversion logic.
469+
include_payloads: If True, request payload data (e.g., router replay)
470+
from the gateway and decompress it into the returned EvaluationRows.
452471
453472
Returns:
454473
List[EvaluationRow]: Converted evaluation rows
@@ -479,6 +498,7 @@ def get_evaluation_rows(
479498
"to_timestamp": to_timestamp.isoformat() if to_timestamp else None,
480499
"sleep_between_gets": sleep_between_gets,
481500
"max_retries": max_retries,
501+
"include_payloads": include_payloads if include_payloads else None,
482502
}
483503

484504
# Remove None values
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""R3/v1 binary deserializer for router-replay payloads.
2+
3+
Implements the inverse of the packed binary format produced by the tracing
4+
gateway's ``r3_serializer.serialize_r3``. See that module for the full
5+
header specification.
6+
7+
The main entry point is :func:`decompress_and_parse_r3`, which accepts the
8+
base64-encoded compressed blob returned by the gateway's
9+
``/v1/traces/pointwise?include_payloads=true`` endpoint and produces
10+
per-token routing matrices in the same ``List[Optional[str]]`` format used
11+
by the direct inference path (``DeploymentSampler.sample_with_tokens()``).
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import base64
17+
import logging
18+
import math
19+
import struct
20+
from enum import IntEnum
21+
from typing import Any, Dict, List, Optional, Tuple
22+
23+
import zstandard as zstd
24+
25+
logger = logging.getLogger(__name__)
26+
27+
# Four-byte signature that must open every decompressed R3/v1 payload.
MAGIC = b"R3V1"
# struct layout of the fixed header ("<" = little-endian, no padding), in
# unpack order:
#   4s magic, B version, B selector_mode, B routing_dtype, B flags,
#   I total_token_count, I replayed_token_count,
#   H num_moe_layers, H top_k,
#   I replay_start_token, I selector_byte_length, Q matrix_byte_length
HEADER_FORMAT = "<4sBBBBIIHHIIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 36 bytes
30+
31+
32+
class _SelectorMode(IntEnum):
    """Header codes describing which token positions carry a routing matrix."""

    # Every position in range(total_token_count) has a matrix.
    ALL = 0
    # A contiguous run of replayed_token_count positions starting at
    # replay_start_token has matrices.
    SUFFIX = 1
    # Positions are given by the set bits of the selector bytes
    # (least-significant bit first within each byte).
    BITMAP = 2
36+
37+
38+
class _RoutingDtype(IntEnum):
    """Header codes for the integer type of each routing-matrix element.

    The numeric code of each member equals the element size in bytes.
    """

    UINT8 = 1
    UINT16 = 2

    @property
    def byte_width(self) -> int:
        """Return the number of bytes occupied by one element of this dtype."""
        return int(self)
45+
46+
47+
# Lower-cased member names keyed by enum member; used to report the selector
# mode and routing dtype as readable strings in the returned metadata.
_SELECTOR_MODE_NAMES = {mode: mode.name.lower() for mode in _SelectorMode}
_ROUTING_DTYPE_NAMES = {dtype: dtype.name.lower() for dtype in _RoutingDtype}
49+
50+
51+
def _parse_header(raw: bytes) -> Dict[str, Any]:
    """Unpack and validate the fixed-size R3/v1 header at the start of *raw*.

    Args:
        raw: Decompressed payload bytes; only the first ``HEADER_SIZE`` bytes
            are read.

    Returns:
        A dict of the header fields that follow the magic and version, keyed
        by field name, in header order.

    Raises:
        ValueError: If *raw* is shorter than the header, the magic does not
            match ``MAGIC``, or the version is not 1.
    """
    if len(raw) < HEADER_SIZE:
        raise ValueError(
            f"Payload too short for r3/v1 header: {len(raw)} < {HEADER_SIZE}"
        )

    # unpack_from reads exactly HEADER_SIZE bytes from offset 0, so no slice
    # of the payload needs to be copied.
    magic, version, *field_values = struct.unpack_from(HEADER_FORMAT, raw)

    if magic != MAGIC:
        raise ValueError(f"Bad R3 magic: {magic!r}")
    if version != 1:
        raise ValueError(f"Unsupported R3 header version: {version}")

    # Names of the fields after magic/version, in the order they are packed.
    field_names = (
        "selector_mode",
        "routing_dtype",
        "flags",
        "total_token_count",
        "replayed_token_count",
        "num_moe_layers",
        "top_k",
        "replay_start_token",
        "selector_byte_length",
        "matrix_byte_length",
    )
    return dict(zip(field_names, field_values))
89+
90+
91+
def _read_bitmap_positions(
    selector_bytes: bytes, total_token_count: int
) -> List[int]:
    """Return the ascending token indices whose bit is set in the bitmap.

    Bit ``i`` lives in byte ``i // 8`` at bit ``i % 8`` (least-significant
    bit first). Indices at or beyond ``total_token_count``, or beyond the end
    of ``selector_bytes``, are never reported.
    """
    # Only bits backed by an actual byte can be set; clamp the scan to them.
    usable_bits = min(total_token_count, len(selector_bytes) * 8)
    return [
        token
        for token in range(usable_bits)
        if selector_bytes[token // 8] & (1 << (token % 8))
    ]
102+
103+
104+
def decompress_and_parse_r3(
    data_b64: str,
) -> Tuple[List[Optional[str]], Dict[str, Any]]:
    """Decompress and unpack an R3/v1 payload into per-token routing matrices.

    Args:
        data_b64: Base64-encoded zstd-compressed R3 binary blob, as returned
            by the tracing gateway in ``payloads.router_replay.data``.

    Returns:
        A tuple of ``(routing_matrices, metadata)`` where:

        - ``routing_matrices`` is a ``List[Optional[str]]`` of length
          ``total_token_count``. Each present position contains a
          base64-encoded routing matrix (matching the format returned by
          the direct inference path); absent positions are ``None``.
        - ``metadata`` is a dict with keys ``num_moe_layers``, ``top_k``,
          ``routing_dtype``, ``selector_mode``, ``total_token_count``,
          ``replayed_token_count``, ``replay_start_token``.

    Raises:
        ValueError: If the header is too short, has a bad magic or an
            unsupported version, or names an unknown routing dtype or
            selector mode.

    Matrix data that is shorter than the header implies, and SUFFIX ranges
    that run past ``total_token_count``, are handled on a best-effort basis:
    a warning is logged and the affected positions are left as ``None``.
    """
    compressed = base64.b64decode(data_b64)

    # NOTE(review): the 20x cap presumably only applies when the zstd frame
    # does not embed its content size, and a payload compressing better than
    # 20x may be rejected -- confirm against the python-zstandard docs.
    decompressor = zstd.ZstdDecompressor()
    raw = decompressor.decompress(compressed, max_output_size=len(compressed) * 20)

    header = _parse_header(raw)

    total_token_count = header["total_token_count"]
    replayed_token_count = header["replayed_token_count"]
    num_moe_layers = header["num_moe_layers"]
    top_k = header["top_k"]
    replay_start_token = header["replay_start_token"]
    selector_byte_length = header["selector_byte_length"]
    matrix_byte_length = header["matrix_byte_length"]

    # Validate both enum codes once, up front, so an unknown code fails with
    # a clear message regardless of which branch below would be taken.
    raw_dtype = header["routing_dtype"]
    try:
        routing_dtype = _RoutingDtype(raw_dtype)
    except ValueError:
        raise ValueError(f"Unknown routing_dtype: {raw_dtype}") from None

    raw_mode = header["selector_mode"]
    try:
        selector_mode = _SelectorMode(raw_mode)
    except ValueError:
        raise ValueError(f"Unknown selector_mode: {raw_mode}") from None

    # Size in bytes of one token's routing matrix.
    matrix_elem_size = num_moe_layers * top_k * routing_dtype.byte_width

    body = raw[HEADER_SIZE:]
    selector_bytes = body[:selector_byte_length]
    matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length]

    if matrix_elem_size == 0:
        # Zero-sized matrices carry no data, so no position is populated.
        replayed_positions: List[int] = []
    elif selector_mode is _SelectorMode.ALL:
        replayed_positions = list(range(total_token_count))
    elif selector_mode is _SelectorMode.SUFFIX:
        suffix_end = replay_start_token + replayed_token_count
        if suffix_end > total_token_count:
            # Positions past the end would index outside ``matrices``; keep
            # the in-range part instead of failing the whole payload.
            logger.warning(
                "R3 suffix selector [%d, %d) exceeds total_token_count %d; "
                "clamping to the available positions",
                replay_start_token, suffix_end, total_token_count,
            )
            suffix_end = total_token_count
        replayed_positions = list(range(replay_start_token, suffix_end))
    else:  # _SelectorMode.BITMAP
        replayed_positions = _read_bitmap_positions(selector_bytes, total_token_count)

    # Split matrix bytes into per-token chunks and base64-encode each one
    matrices: List[Optional[str]] = [None] * total_token_count
    for idx, pos in enumerate(replayed_positions):
        start = idx * matrix_elem_size
        end = start + matrix_elem_size
        if end > len(matrix_bytes):
            logger.warning(
                "R3 matrix data truncated at token %d (position %d): "
                "expected %d bytes but only %d remaining",
                idx, pos, matrix_elem_size, len(matrix_bytes) - start,
            )
            break
        matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii")

    metadata: Dict[str, Any] = {
        "num_moe_layers": num_moe_layers,
        "top_k": top_k,
        "routing_dtype": _ROUTING_DTYPE_NAMES[routing_dtype],
        "selector_mode": _SELECTOR_MODE_NAMES[selector_mode],
        "total_token_count": total_token_count,
        "replayed_token_count": replayed_token_count,
        "replay_start_token": replay_start_token,
    }

    return matrices, metadata

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ def __init__(
3535
model_base_url: str = "https://tracing.fireworks.ai",
3636
poll_interval: float = 1.0,
3737
timeout_seconds: float = 120.0,
38+
include_payloads: bool = False,
3839
):
3940
# Prefer constructor-provided configuration. These can be overridden via
4041
# config.kwargs at call time for backward compatibility.
4142
self._remote_base_url = remote_base_url
4243
self._model_base_url = model_base_url
44+
self._include_payloads = include_payloads
4345
if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"):
4446
self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL")
4547
_ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
@@ -175,7 +177,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
175177
row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
176178

177179
def _update_with_trace() -> None:
178-
return update_row_with_remote_trace(row, default_fireworks_output_data_loader, model_base_url)
180+
return update_row_with_remote_trace(
181+
row, default_fireworks_output_data_loader, model_base_url,
182+
include_payloads=self._include_payloads,
183+
)
179184

180185
await asyncio.to_thread(_update_with_trace) # Update row with remote trace in-place
181186
return row

eval_protocol/pytest/tracing_utils.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ def fetch_traces() -> List[EvaluationRow]:
2222
# Use EP_REMOTE_API_KEY for fetching remote traces, falling back to FIREWORKS_API_KEY
2323
api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY")
2424
adapter = FireworksTracingAdapter(base_url=base_url, api_key=api_key)
25-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
25+
return adapter.get_evaluation_rows(
26+
tags=[f"rollout_id:{config.rollout_id}"],
27+
max_retries=5,
28+
include_payloads=config.include_payloads,
29+
)
2630

2731
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
2832

@@ -148,13 +152,20 @@ def build_init_request(
148152

149153

150154
def update_row_with_remote_trace(
151-
row: EvaluationRow, output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader], model_base_url: str
155+
row: EvaluationRow,
156+
output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader],
157+
model_base_url: str,
158+
include_payloads: bool = False,
152159
) -> None:
153160
"""Update row with remote trace data using output_data_loader (shared logic)."""
154161
if not row.execution_metadata.rollout_id:
155162
return None
156163

157-
loader_config = DataLoaderConfig(rollout_id=row.execution_metadata.rollout_id, model_base_url=model_base_url)
164+
loader_config = DataLoaderConfig(
165+
rollout_id=row.execution_metadata.rollout_id,
166+
model_base_url=model_base_url,
167+
include_payloads=include_payloads,
168+
)
158169
data_loader = output_data_loader(loader_config)
159170
results = data_loader.load()
160171
output_rows: List[EvaluationRow] = [r for result in results for r in result.rows]

eval_protocol/types/remote_rollout_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class DataLoaderConfig(BaseModel):
3939

4040
rollout_id: str
4141
model_base_url: Optional[str] = None
42+
include_payloads: bool = False
4243

4344

4445
class InitRequest(BaseModel):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"deepdiff>=6.0.0",
4949
"websockets>=15.0.1",
5050
"fastapi>=0.116.1",
51+
"zstandard>=0.19.0",
5152
]
5253

5354
[project.urls]

0 commit comments

Comments
 (0)