Skip to content

Commit 08e9b85

Browse files
Decode prompt token trace payloads (#455)
* Decode prompt token trace payloads

  Hydrate prompt_token_ids from Fireworks tracing payloads so RemoteRolloutProcessor can pass token-native prompt IDs through assistant turn metadata.

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Extract gateway payload decoders into standalone eval_protocol.tracing

  Move the per-payload binary deserializers out of adapters/ into a dependency-light eval_protocol/tracing package: a PayloadType StrEnum, DecodedPayload, and a decode_payloads registry (master decode) usable without EvaluationRow/rollout machinery. Refactor FireworksTracingAdapter to use the registry instead of three copy-pasted decode blocks, and decode pti/v1 as zstd(JSON int array) to match the gateway. Adds tests/tracing and a README.

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Trim verbose comment in fireworks_tracing adapter

  Replace the 4-line narrating comment over the payload-decode block with a single intent line; the code already shows what it does.

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Slim eval_protocol.tracing public API to the master decode surface

  Export only PayloadType, DecodedPayload, and decode_payloads/decode_payload/decode_trace from the package __init__. The per-type decoders and PAYLOAD_DECODERS are internal building blocks, still reachable via submodules.

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Inline base64+zstd decompress into prompt_token_ids; drop _decompress module

  The shared helper had a single caller (prompt_token_ids); logprobs/router_replay already inline the same base64+zstd step. Inline it there too for consistency and remove the dedicated _decompress.py file.

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Type DecodedPayload.token_ids instead of an untyped extras bag

  Replace the Dict[str, Any] `extras` field (only ever holding logprobs token_ids) with a typed `token_ids: Optional[List[int]]` field. Callers now get a real type (dp.token_ids) instead of Any from extras.get(...). Update adapter, tests, README.

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Keep deprecated lp/r3 deserializer shims for backward compatibility.

  Re-export the moved tracing decoders from their old adapter paths with a DeprecationWarning so existing imports keep working after the tracing refactor.

  Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 1bd5447 commit 08e9b85

19 files changed

Lines changed: 1173 additions & 328 deletions

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 35 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,8 @@
1515
import os
1616

1717
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
18+
from eval_protocol.tracing import PayloadType, decode_payloads
1819
from .base import BaseAdapter
19-
from .lp_deserializer import decompress_and_parse_lp
20-
from .r3_deserializer import decompress_and_parse_r3
2120
from .utils import extract_messages_from_data
2221
from ..common_utils import get_user_agent
2322

@@ -102,45 +101,42 @@ def convert_trace_dict_to_evaluation_row(
102101
):
103102
break # Break early if we've found all the metadata we need
104103

105-
# Extract router replay payloads when present
104+
# Decoding lives in eval_protocol.tracing; here we only map results onto the row.
106105
payloads = trace.get("payloads")
107106
if isinstance(payloads, dict):
108-
router_replay = payloads.get("router_replay")
109-
if isinstance(router_replay, dict) and router_replay.get("data"):
110-
try:
111-
matrices, r3_meta = decompress_and_parse_r3(router_replay["data"])
112-
if execution_metadata.extra is None:
113-
execution_metadata.extra = {}
114-
execution_metadata.extra["routing_matrices"] = matrices
115-
execution_metadata.extra["routing_metadata"] = r3_meta
116-
except Exception as e:
117-
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)
118-
119-
logprobs_payload = payloads.get("logprobs")
120-
if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"):
121-
try:
122-
logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"])
123-
if execution_metadata.extra is None:
124-
execution_metadata.extra = {}
125-
execution_metadata.extra["completion_logprobs"] = logprobs
126-
if token_ids is not None:
127-
execution_metadata.extra["completion_token_ids"] = token_ids
128-
execution_metadata.extra["logprobs_metadata"] = lp_meta
129-
130-
for i in range(len(messages) - 1, -1, -1):
131-
if messages[i].role == "assistant":
132-
content_entries = [{"logprob": lp} for lp in logprobs]
133-
if token_ids is not None:
134-
for entry, tid in zip(content_entries, token_ids):
135-
entry["token_id"] = tid
136-
messages[i].logprobs = {"content": content_entries}
137-
break
138-
except Exception as e:
139-
logger.warning(
140-
"Failed to decompress logprobs payload for trace %s: %s",
141-
trace.get("id"),
142-
e,
143-
)
107+
decoded = decode_payloads(
108+
payloads,
109+
on_error=lambda pt, e: logger.warning(
110+
"Failed to decode %s payload for trace %s: %s", pt.value, trace.get("id"), e
111+
),
112+
)
113+
if decoded and execution_metadata.extra is None:
114+
execution_metadata.extra = {}
115+
116+
if (dp := decoded.get(PayloadType.ROUTER_REPLAY)) is not None:
117+
execution_metadata.extra["routing_matrices"] = dp.value
118+
execution_metadata.extra["routing_metadata"] = dp.metadata
119+
120+
if (dp := decoded.get(PayloadType.LOGPROBS)) is not None:
121+
logprobs = dp.value
122+
token_ids = dp.token_ids
123+
execution_metadata.extra["completion_logprobs"] = logprobs
124+
if token_ids is not None:
125+
execution_metadata.extra["completion_token_ids"] = token_ids
126+
execution_metadata.extra["logprobs_metadata"] = dp.metadata
127+
128+
for i in range(len(messages) - 1, -1, -1):
129+
if messages[i].role == "assistant":
130+
content_entries = [{"logprob": lp} for lp in logprobs]
131+
if token_ids is not None:
132+
for entry, tid in zip(content_entries, token_ids):
133+
entry["token_id"] = tid
134+
messages[i].logprobs = {"content": content_entries}
135+
break
136+
137+
if (dp := decoded.get(PayloadType.PROMPT_TOKEN_IDS)) is not None:
138+
execution_metadata.extra["prompt_token_ids"] = dp.value
139+
execution_metadata.extra["prompt_token_ids_metadata"] = dp.metadata
144140

145141
return EvaluationRow(
146142
messages=messages,
eval_protocol/adapters/lp_deserializer.py

Lines changed: 36 additions & 104 deletions
Original file line number | Diff line number | Diff line change
@@ -1,109 +1,41 @@
1-
"""LP/v1 binary deserializer for per-token logprobs payloads.
1+
"""Deprecated compatibility shim for ``eval_protocol.tracing.logprobs``.
22
3-
Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``.
4-
See that module for the full header specification.
3+
Import from ``eval_protocol.tracing.logprobs`` (or ``decode_payloads`` from
4+
``eval_protocol.tracing``) instead. This module re-exports the LP/v1 helpers
5+
that lived here before the tracing package refactor.
56
"""
67

78
from __future__ import annotations
89

9-
import base64
10-
import struct
11-
from typing import Any, Dict, List, Optional, Tuple
12-
13-
import zstandard as zstd
14-
15-
MAGIC = b"LP01"
16-
HEADER_VERSION = 1
17-
MISSING_TOKEN_ID = -1
18-
ENTRY_FORMAT = "<if"
19-
ENTRY_SIZE = struct.calcsize(ENTRY_FORMAT) # 8 bytes
20-
HEADER_FORMAT = "<4sBBHIIQ"
21-
HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 24 bytes
22-
23-
24-
def _parse_header(raw: bytes) -> Dict[str, Any]:
25-
if len(raw) < HEADER_SIZE:
26-
raise ValueError(f"Payload too short for lp/v1 header: {len(raw)} < {HEADER_SIZE}")
27-
28-
(
29-
magic,
30-
version,
31-
flags,
32-
reserved_u16,
33-
token_count,
34-
body_byte_length,
35-
reserved_u64,
36-
) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE])
37-
38-
if magic != MAGIC:
39-
raise ValueError(f"Bad LP/v1 magic: {magic!r}")
40-
if version != HEADER_VERSION:
41-
raise ValueError(f"Unsupported lp/v1 header version: {version}")
42-
43-
return {
44-
"flags": flags,
45-
"reserved_u16": reserved_u16,
46-
"token_count": token_count,
47-
"body_byte_length": body_byte_length,
48-
"reserved_u64": reserved_u64,
49-
}
50-
51-
52-
def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
53-
"""Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata."""
54-
header = _parse_header(raw)
55-
token_count = header["token_count"]
56-
body_byte_length = header["body_byte_length"]
57-
58-
if token_count == 0:
59-
raise ValueError("LP/v1 token_count must be > 0")
60-
if body_byte_length != token_count * ENTRY_SIZE:
61-
raise ValueError(
62-
f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} "
63-
f"({token_count * ENTRY_SIZE})"
64-
)
65-
66-
expected_len = HEADER_SIZE + body_byte_length
67-
if len(raw) != expected_len:
68-
raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}")
69-
70-
logprobs: List[float] = []
71-
token_ids: List[int] = []
72-
all_token_ids_valid = True
73-
offset = HEADER_SIZE
74-
for _ in range(token_count):
75-
wire_id, logprob = struct.unpack(ENTRY_FORMAT, raw[offset : offset + ENTRY_SIZE])
76-
offset += ENTRY_SIZE
77-
logprobs.append(logprob)
78-
if wire_id == MISSING_TOKEN_ID:
79-
all_token_ids_valid = False
80-
token_ids.append(wire_id)
81-
else:
82-
token_ids.append(wire_id)
83-
84-
metadata: Dict[str, Any] = {
85-
"scope": "completion_only",
86-
"completion_token_count": token_count,
87-
"all_token_ids_valid": all_token_ids_valid,
88-
}
89-
header.update(metadata)
90-
ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None
91-
return logprobs, ids_out, header
92-
93-
94-
def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
95-
"""Decompress and unpack an LP/v1 payload into completion logprobs and token ids.
96-
97-
Args:
98-
data_b64: Base64-encoded zstd-compressed LP binary blob from
99-
``payloads.logprobs.data``.
100-
101-
Returns:
102-
``(logprobs, token_ids, metadata)`` where ``logprobs`` is per-completion-token
103-
scalars, ``token_ids`` is ``None`` if any wire id was ``MISSING_TOKEN_ID``,
104-
and ``metadata`` includes ``all_token_ids_valid`` and ``completion_token_count``.
105-
"""
106-
compressed = base64.b64decode(data_b64)
107-
decompressor = zstd.ZstdDecompressor()
108-
raw = decompressor.decompress(compressed)
109-
return parse_logprobs(raw)
10+
import warnings
11+
12+
warnings.warn(
13+
"eval_protocol.adapters.lp_deserializer is deprecated; "
14+
"import from eval_protocol.tracing.logprobs instead.",
15+
DeprecationWarning,
16+
stacklevel=2,
17+
)
18+
19+
from eval_protocol.tracing.logprobs import ( # noqa: E402
20+
ENTRY_FORMAT,
21+
ENTRY_SIZE,
22+
HEADER_FORMAT,
23+
HEADER_SIZE,
24+
HEADER_VERSION,
25+
MAGIC,
26+
MISSING_TOKEN_ID,
27+
decompress_and_parse_lp,
28+
parse_logprobs,
29+
)
30+
31+
__all__ = [
32+
"ENTRY_FORMAT",
33+
"ENTRY_SIZE",
34+
"HEADER_FORMAT",
35+
"HEADER_SIZE",
36+
"HEADER_VERSION",
37+
"MAGIC",
38+
"MISSING_TOKEN_ID",
39+
"decompress_and_parse_lp",
40+
"parse_logprobs",
41+
]

0 commit comments

Comments
 (0)