Skip to content

Commit 7c82b28

Browse files
Add LP/v1 logprobs trace payload deserialization (FIR-21499) (#452)
* Add LP/v1 logprobs deserialization for tracing gateway payloads (FIR-21499). Decode logprobs payloads into completion_logprobs and Message.logprobs on EvaluationRow. Pop base_url from completion_params before OpenAI SDK calls so dev inference API can be encoded in gateway tracing URLs. Co-authored-by: Cursor <cursoragent@cursor.com> * Hoist r3 and lp deserializer imports to module level in fireworks_tracing. Co-authored-by: Cursor <cursoragent@cursor.com> * Keep base_url in completion_params; strip at OpenAI call site. Use get() in build_init_request so base_url remains available for gateway URL encoding, and filter it out in remote_server when calling the SDK. Co-authored-by: Cursor <cursoragent@cursor.com> * remove redundant function --------- Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 99e49fa commit 7c82b28

6 files changed

Lines changed: 315 additions & 5 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
1818
from .base import BaseAdapter
19+
from .lp_deserializer import decompress_and_parse_lp
20+
from .r3_deserializer import decompress_and_parse_r3
1921
from .utils import extract_messages_from_data
2022
from ..common_utils import get_user_agent
2123

@@ -106,8 +108,6 @@ def convert_trace_dict_to_evaluation_row(
106108
router_replay = payloads.get("router_replay")
107109
if isinstance(router_replay, dict) and router_replay.get("data"):
108110
try:
109-
from .r3_deserializer import decompress_and_parse_r3
110-
111111
matrices, r3_meta = decompress_and_parse_r3(router_replay["data"])
112112
if execution_metadata.extra is None:
113113
execution_metadata.extra = {}
@@ -116,6 +116,32 @@ def convert_trace_dict_to_evaluation_row(
116116
except Exception as e:
117117
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)
118118

119+
logprobs_payload = payloads.get("logprobs")
120+
if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"):
121+
try:
122+
logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"])
123+
if execution_metadata.extra is None:
124+
execution_metadata.extra = {}
125+
execution_metadata.extra["completion_logprobs"] = logprobs
126+
if token_ids is not None:
127+
execution_metadata.extra["completion_token_ids"] = token_ids
128+
execution_metadata.extra["logprobs_metadata"] = lp_meta
129+
130+
for i in range(len(messages) - 1, -1, -1):
131+
if messages[i].role == "assistant":
132+
content_entries = [{"logprob": lp} for lp in logprobs]
133+
if token_ids is not None:
134+
for entry, tid in zip(content_entries, token_ids):
135+
entry["token_id"] = tid
136+
messages[i].logprobs = {"content": content_entries}
137+
break
138+
except Exception as e:
139+
logger.warning(
140+
"Failed to decompress logprobs payload for trace %s: %s",
141+
trace.get("id"),
142+
e,
143+
)
144+
119145
return EvaluationRow(
120146
messages=messages,
121147
tools=tools,
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""LP/v1 binary deserializer for per-token logprobs payloads.
2+
3+
Implements the inverse of the tracing gateway's ``logprobs_serializer.serialize_logprobs``.
4+
See that module for the full header specification.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import base64
10+
import struct
11+
from typing import Any, Dict, List, Optional, Tuple
12+
13+
import zstandard as zstd
14+
15+
# LP/v1 wire-format constants. Every struct format below is little-endian ("<").
MAGIC = b"LP01"  # first four bytes of every LP/v1 payload
HEADER_VERSION = 1  # the only header version accepted by the parser
MISSING_TOKEN_ID = -1  # wire sentinel: entry carries no token id

# One body entry: int32 token id followed by float32 logprob.
ENTRY_FORMAT = "<if"
ENTRY_SIZE = struct.calcsize(ENTRY_FORMAT)  # 8 bytes

# Header layout: magic (4s), version (B), flags (B), reserved (H),
# token_count (I), body_byte_length (I), reserved (Q).
HEADER_FORMAT = "<4sBBHIIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 24 bytes
22+
23+
24+
def _parse_header(raw: bytes) -> Dict[str, Any]:
    """Validate and decode the fixed-size LP/v1 header at the start of ``raw``.

    Args:
        raw: Uncompressed LP/v1 payload bytes (header followed by body).

    Returns:
        A dict holding the decoded ``flags``, ``reserved_u16``, ``token_count``,
        ``body_byte_length`` and ``reserved_u64`` header fields.

    Raises:
        ValueError: If ``raw`` is shorter than ``HEADER_SIZE``, the magic bytes
            differ from ``MAGIC``, or the version differs from ``HEADER_VERSION``.
    """
    available = len(raw)
    if available < HEADER_SIZE:
        raise ValueError(f"Payload too short for lp/v1 header: {available} < {HEADER_SIZE}")

    # The first two fields are only validated; the remaining five are returned.
    magic, version, *remaining = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE])

    if magic != MAGIC:
        raise ValueError(f"Bad LP/v1 magic: {magic!r}")
    if version != HEADER_VERSION:
        raise ValueError(f"Unsupported lp/v1 header version: {version}")

    field_names = (
        "flags",
        "reserved_u16",
        "token_count",
        "body_byte_length",
        "reserved_u64",
    )
    return dict(zip(field_names, remaining))
50+
51+
52+
def parse_logprobs(raw: bytes) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Parse uncompressed LP/v1 bytes into logprobs, optional token ids, and metadata.

    Args:
        raw: Uncompressed LP/v1 payload: a header followed by ``token_count``
            fixed-size ``(token_id, logprob)`` entries.

    Returns:
        ``(logprobs, token_ids, metadata)``. ``logprobs`` has one float per
        entry. ``token_ids`` is the list of wire ids, or ``None`` if any entry
        carried ``MISSING_TOKEN_ID``. ``metadata`` is the parsed header dict
        extended with ``scope``, ``completion_token_count`` and
        ``all_token_ids_valid``.

    Raises:
        ValueError: If the header is invalid, ``token_count`` is zero,
            ``body_byte_length`` disagrees with ``token_count``, or the total
            payload length does not match the header.
    """
    header = _parse_header(raw)
    token_count = header["token_count"]
    body_byte_length = header["body_byte_length"]

    if token_count == 0:
        raise ValueError("LP/v1 token_count must be > 0")
    if body_byte_length != token_count * ENTRY_SIZE:
        raise ValueError(
            f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} "
            f"({token_count * ENTRY_SIZE})"
        )

    expected_len = HEADER_SIZE + body_byte_length
    if len(raw) != expected_len:
        raise ValueError(f"LP/v1 payload length mismatch: {len(raw)} != {expected_len}")

    # The length checks above guarantee the body is an exact multiple of
    # ENTRY_SIZE, which iter_unpack requires. The memoryview avoids copying
    # the body (or one slice per entry) just to decode it.
    entries = list(struct.iter_unpack(ENTRY_FORMAT, memoryview(raw)[HEADER_SIZE:]))
    token_ids: List[int] = [wire_id for wire_id, _ in entries]
    logprobs: List[float] = [logprob for _, logprob in entries]

    # A single missing id invalidates the whole id list for callers.
    all_token_ids_valid = MISSING_TOKEN_ID not in token_ids

    metadata: Dict[str, Any] = {
        "scope": "completion_only",
        "completion_token_count": token_count,
        "all_token_ids_valid": all_token_ids_valid,
    }
    header.update(metadata)
    ids_out: Optional[List[int]] = token_ids if all_token_ids_valid else None
    return logprobs, ids_out, header
92+
93+
94+
def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[int]], Dict[str, Any]]:
    """Decode, decompress, and parse an LP/v1 payload.

    The input is base64-decoded, zstd-decompressed, and then handed to
    ``parse_logprobs``.

    Args:
        data_b64: Base64-encoded zstd-compressed LP binary blob from
            ``payloads.logprobs.data``.

    Returns:
        The ``(logprobs, token_ids, metadata)`` tuple produced by
        ``parse_logprobs``: ``token_ids`` is ``None`` if any wire id was
        ``MISSING_TOKEN_ID``, and ``metadata`` includes ``all_token_ids_valid``
        and ``completion_token_count``.
    """
    zstd_frame = base64.b64decode(data_b64)
    lp_bytes = zstd.ZstdDecompressor().decompress(zstd_frame)
    return parse_logprobs(lp_bytes)

eval_protocol/pytest/tracing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def build_init_request(
103103
if not completion_params_dict.get("model"):
104104
raise ValueError("Model must be provided in completion_params")
105105

106-
# Extract base_url from completion_params
106+
# Extract base_url from completion_params for tracing-gateway URL encoding
107107
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")
108108

109109
# Strip non-OpenAI fields from messages
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Tests for logprobs payload handling in fireworks_tracing adapter."""
2+
3+
from __future__ import annotations
4+
5+
import base64
6+
import struct
7+
8+
import pytest
9+
import zstandard as zstd
10+
11+
pytest.importorskip("mcp")
12+
13+
from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row
14+
from eval_protocol.adapters.lp_deserializer import (
15+
ENTRY_FORMAT,
16+
ENTRY_SIZE,
17+
HEADER_FORMAT,
18+
MAGIC,
19+
MISSING_TOKEN_ID,
20+
)
21+
22+
23+
def _lp_b64(tokens: list[tuple[int, float]]) -> str:
    """Encode ``(token_id, logprob)`` pairs as a base64, zstd-compressed LP/v1 payload."""
    n_tokens = len(tokens)
    # Header fields in order: magic, version, flags, reserved, token_count,
    # body_byte_length, reserved.
    header_fields = (MAGIC, 1, 0, 0, n_tokens, n_tokens * ENTRY_SIZE, 0)
    header = struct.pack(HEADER_FORMAT, *header_fields)
    entries = [struct.pack(ENTRY_FORMAT, token_id, logprob) for token_id, logprob in tokens]
    frame = zstd.ZstdCompressor().compress(header + b"".join(entries))
    return base64.b64encode(frame).decode("ascii")
40+
41+
42+
def _base_trace(*, with_token_ids: bool = True) -> dict:
    """Build a minimal trace dict carrying a two-token LP/v1 logprobs payload.

    When ``with_token_ids`` is false, the first entry uses ``MISSING_TOKEN_ID``
    as its token id.
    """
    if with_token_ids:
        tokens = [(10, -0.1), (11, -0.2)]
    else:
        tokens = [(MISSING_TOKEN_ID, -0.1), (12, -0.2)]

    conversation = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    logprobs_payload = {
        "data": _lp_b64(tokens),
        "manifest": {"PayloadVersion": "lp/v1"},
    }
    return {
        "id": "trace-1",
        "input": {"messages": conversation},
        "output": {"role": "assistant", "content": "hello"},
        "payloads": {"logprobs": logprobs_payload},
    }
60+
61+
62+
class TestConvertTraceLogprobs:
    """Checks that ``convert_trace_dict_to_evaluation_row`` surfaces logprobs payloads."""

    def test_attaches_completion_logprobs_and_message_logprobs(self):
        """Logprobs and token ids land in ``extra`` and on the last assistant message."""
        row = convert_trace_dict_to_evaluation_row(_base_trace())
        assert row is not None

        extra = row.execution_metadata.extra
        assert extra is not None
        assert extra["completion_logprobs"] == pytest.approx([-0.1, -0.2])
        assert extra["completion_token_ids"] == [10, 11]

        last_message = row.messages[-1]
        assert last_message.role == "assistant"
        entries = last_message.logprobs["content"]
        assert len(entries) == len(extra["completion_logprobs"])
        first, second = entries[0], entries[1]
        assert first["token_id"] == 10
        assert second["token_id"] == 11
        assert first["logprob"] == pytest.approx(-0.1)
        assert second["logprob"] == pytest.approx(-0.2)

    def test_omits_token_id_keys_when_any_missing(self):
        """A single missing wire id drops token ids from ``extra`` and every entry."""
        row = convert_trace_dict_to_evaluation_row(_base_trace(with_token_ids=False))
        assert row is not None

        extra = row.execution_metadata.extra
        assert "completion_logprobs" in extra
        assert "completion_token_ids" not in extra

        entries = row.messages[-1].logprobs["content"]
        assert len(entries) == 2
        assert not any("token_id" in entry for entry in entries)
        first, second = entries[0], entries[1]
        assert first["logprob"] == pytest.approx(-0.1)
        assert second["logprob"] == pytest.approx(-0.2)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Tests for LP/v1 binary deserializer (gateway-compatible)."""
2+
3+
from __future__ import annotations
4+
5+
import base64
6+
import struct
7+
8+
import pytest
9+
import zstandard as zstd
10+
11+
from eval_protocol.adapters.lp_deserializer import (
12+
ENTRY_FORMAT,
13+
ENTRY_SIZE,
14+
HEADER_FORMAT,
15+
HEADER_SIZE,
16+
MAGIC,
17+
MISSING_TOKEN_ID,
18+
decompress_and_parse_lp,
19+
parse_logprobs,
20+
)
21+
22+
# Golden raw bytes: two tokens (7, -0.25) and (8, -0.5) — must match gateway serializer.
23+
GOLDEN_RAW_HEX = (
24+
"4c503031010000000200000010000000000000000000000007000000000080be"
25+
"08000000000000bf"
26+
)
27+
28+
29+
def _build_raw(tokens: list[tuple[int, float]]) -> bytes:
    """Pack ``(token_id, logprob)`` pairs into uncompressed LP/v1 bytes (header + body)."""
    n_tokens = len(tokens)
    # Header fields in order: magic, version, flags, reserved, token_count,
    # body_byte_length, reserved.
    header_fields = (MAGIC, 1, 0, 0, n_tokens, n_tokens * ENTRY_SIZE, 0)
    header = struct.pack(HEADER_FORMAT, *header_fields)
    entries = [struct.pack(ENTRY_FORMAT, token_id, logprob) for token_id, logprob in tokens]
    return header + b"".join(entries)
44+
45+
46+
def _compress_b64(raw: bytes) -> str:
    """Zstd-compress ``raw`` and return the result as an ASCII base64 string."""
    frame = zstd.ZstdCompressor().compress(raw)
    encoded = base64.b64encode(frame)
    return encoded.decode("ascii")
48+
49+
50+
class TestParseLogprobs:
    """Checks for ``parse_logprobs`` and ``decompress_and_parse_lp`` on LP/v1 bytes."""

    def test_golden_bytes_match_gateway(self):
        """The golden hex fixture decodes to the two expected entries."""
        logprobs, token_ids, meta = parse_logprobs(bytes.fromhex(GOLDEN_RAW_HEX))
        assert logprobs == [-0.25, -0.5]
        assert token_ids == [7, 8]
        assert meta["all_token_ids_valid"] is True
        assert meta["token_count"] == 2

    def test_missing_token_id_omits_token_ids_list(self):
        """Any ``MISSING_TOKEN_ID`` entry makes the returned token id list ``None``."""
        payload = _build_raw([(MISSING_TOKEN_ID, -0.3), (42, -0.4)])
        logprobs, token_ids, meta = parse_logprobs(payload)
        assert logprobs == pytest.approx([-0.3, -0.4])
        assert token_ids is None
        assert meta["all_token_ids_valid"] is False

    def test_decompress_and_parse_round_trip(self):
        """Compressing then base64-encoding the golden bytes round-trips through the decoder."""
        encoded = _compress_b64(bytes.fromhex(GOLDEN_RAW_HEX))
        logprobs, token_ids, meta = decompress_and_parse_lp(encoded)
        assert logprobs == [-0.25, -0.5]
        assert token_ids == [7, 8]
        assert meta["scope"] == "completion_only"

    def test_rejects_bad_magic(self):
        """Overwriting the four magic bytes must raise ``ValueError``."""
        valid = _build_raw([(1, -0.1)])
        corrupted = b"XXXX" + valid[4:]
        with pytest.raises(ValueError, match="Bad LP/v1 magic"):
            parse_logprobs(corrupted)

tests/remote_server/remote_server.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,12 @@ def _worker():
5555
md = {k: v for k, v in md.items() if v is not None}
5656
messages_payload.append(md)
5757

58-
# Spread all completion_params (model, temperature, max_tokens, etc.)
59-
completion_kwargs = {"messages": messages_payload, **req.completion_params}
58+
# Spread completion_params; omit base_url (client uses req.model_base_url; gateway
59+
# encodes inference base_url into the tracing path via build_init_request).
60+
completion_kwargs = {
61+
"messages": messages_payload,
62+
**{k: v for k, v in req.completion_params.items() if k != "base_url"},
63+
}
6064

6165
if req.tools:
6266
completion_kwargs["tools"] = req.tools

0 commit comments

Comments
 (0)