Skip to content

Commit aee752b

Browse files
Decode prompt token trace payloads
Hydrate prompt_token_ids from Fireworks tracing payloads so RemoteRolloutProcessor can pass token-native prompt IDs through assistant turn metadata. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 1bd5447 commit aee752b

6 files changed

Lines changed: 402 additions & 0 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
1818
from .base import BaseAdapter
1919
from .lp_deserializer import decompress_and_parse_lp
20+
from .pti_deserializer import decompress_and_parse_pti
2021
from .r3_deserializer import decompress_and_parse_r3
2122
from .utils import extract_messages_from_data
2223
from ..common_utils import get_user_agent
@@ -142,6 +143,21 @@ def convert_trace_dict_to_evaluation_row(
142143
e,
143144
)
144145

146+
prompt_ids_payload = payloads.get("prompt_token_ids")
147+
if isinstance(prompt_ids_payload, dict) and prompt_ids_payload.get("data"):
148+
try:
149+
prompt_token_ids, pti_meta = decompress_and_parse_pti(prompt_ids_payload["data"])
150+
if execution_metadata.extra is None:
151+
execution_metadata.extra = {}
152+
execution_metadata.extra["prompt_token_ids"] = prompt_token_ids
153+
execution_metadata.extra["prompt_token_ids_metadata"] = pti_meta
154+
except Exception as e:
155+
logger.warning(
156+
"Failed to decompress prompt token IDs payload for trace %s: %s",
157+
trace.get("id"),
158+
e,
159+
)
160+
145161
return EvaluationRow(
146162
messages=messages,
147163
tools=tools,
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""PTI/v1 binary deserializer for prompt token ID payloads.
2+
3+
Implements the inverse of the tracing gateway's
4+
``prompt_token_ids_serializer.serialize_prompt_token_ids``.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import base64
10+
import struct
11+
from typing import Any, Dict, List, Tuple
12+
13+
import zstandard as zstd
14+
15+
# PTI/v1 wire-format constants. Every struct format below is little-endian.
MAGIC = b"PTI1"
HEADER_VERSION = 1

# Header fields, in order: 4-byte magic, u8 version, u8 flags, u16 reserved,
# u32 token_count, u32 body_byte_length, u64 reserved.
HEADER_FORMAT = "<4sBBHIIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 24 bytes

# The body is a flat run of signed 32-bit token IDs.
ENTRY_FORMAT = "<i"
ENTRY_SIZE = struct.calcsize(ENTRY_FORMAT)  # 4 bytes
21+
22+
23+
def _parse_header(raw: bytes) -> Dict[str, Any]:
    """Decode and validate the fixed-size PTI/v1 header at the start of ``raw``.

    Args:
        raw: Uncompressed PTI/v1 payload bytes (header followed by body).

    Returns:
        A dict with the keys ``flags``, ``reserved_u16``, ``token_count``,
        ``body_byte_length`` and ``reserved_u64``.

    Raises:
        ValueError: If ``raw`` is shorter than the header, the magic bytes do
            not match, or the header version is not supported.
    """
    available = len(raw)
    if available < HEADER_SIZE:
        raise ValueError(f"Payload too short for PTI/v1 header: {available} < {HEADER_SIZE}")

    # The first two fields are only used for validation; the rest are returned.
    magic, version, *remaining = struct.unpack_from(HEADER_FORMAT, raw)

    if magic != MAGIC:
        raise ValueError(f"Bad PTI/v1 magic: {magic!r}")
    if version != HEADER_VERSION:
        raise ValueError(f"Unsupported PTI/v1 header version: {version}")

    field_names = (
        "flags",
        "reserved_u16",
        "token_count",
        "body_byte_length",
        "reserved_u64",
    )
    return dict(zip(field_names, remaining))
49+
50+
51+
def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]:
    """Parse uncompressed PTI/v1 bytes into prompt token IDs and metadata.

    Args:
        raw: Complete uncompressed payload: the PTI/v1 header immediately
            followed by ``token_count`` fixed-width token ID entries.

    Returns:
        ``(token_ids, metadata)``. ``metadata`` holds the decoded header
        fields plus ``"scope": "prompt_only"`` and ``token_count``.

    Raises:
        ValueError: If the header is invalid, ``token_count`` is zero,
            ``body_byte_length`` disagrees with ``token_count``, or the
            payload length does not equal header size plus body length.
    """
    header = _parse_header(raw)
    token_count = header["token_count"]
    body_byte_length = header["body_byte_length"]

    if token_count == 0:
        raise ValueError("PTI/v1 token_count must be > 0")
    expected_body_length = token_count * ENTRY_SIZE
    if body_byte_length != expected_body_length:
        raise ValueError(
            f"body_byte_length ({body_byte_length}) != token_count * {ENTRY_SIZE} "
            f"({expected_body_length})"
        )

    expected_len = HEADER_SIZE + body_byte_length
    if len(raw) != expected_len:
        raise ValueError(f"PTI/v1 payload length mismatch: {len(raw)} != {expected_len}")

    # Decode the whole body with a single repeat-count format (e.g. "<3i")
    # instead of slicing and unpacking one entry at a time. The length checks
    # above guarantee the buffer holds exactly ``token_count`` entries.
    byte_order, type_code = ENTRY_FORMAT[0], ENTRY_FORMAT[1:]
    body_format = f"{byte_order}{token_count}{type_code}"
    token_ids: List[int] = list(struct.unpack_from(body_format, raw, HEADER_SIZE))

    # Header fields come first; "token_count" keeps its header position and
    # "scope" is appended, matching the original key order.
    metadata: Dict[str, Any] = {
        **header,
        "scope": "prompt_only",
        "token_count": token_count,
    }
    return token_ids, metadata
82+
83+
84+
def decompress_and_parse_pti(data_b64: str) -> Tuple[List[int], Dict[str, Any]]:
    """Decode a base64 + zstd wrapped PTI/v1 blob into prompt token IDs.

    Args:
        data_b64: Base64-encoded zstd-compressed PTI binary blob from
            ``payloads.prompt_token_ids.data``.

    Returns:
        ``(token_ids, metadata)`` exactly as produced by
        ``parse_prompt_token_ids`` for the decompressed bytes; ``metadata``
        includes ``token_count``.
    """
    # Undo the transport layers in order: base64 first, then zstd.
    raw = zstd.ZstdDecompressor().decompress(base64.b64decode(data_b64))
    return parse_prompt_token_ids(raw)

eval_protocol/pytest/tracing_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[Eval
6363
for key in (
6464
"completion_logprobs",
6565
"completion_token_ids",
66+
"prompt_token_ids",
67+
"prompt_token_ids_metadata",
6668
"logprobs_metadata",
6769
"routing_matrices",
6870
"routing_metadata",
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
#!/usr/bin/env python3
2+
"""E2E check: RemoteRolloutProcessor reads prompt_token_ids trace payloads.
3+
4+
This starts a tiny local `/init` server, sends one chat completion through the
5+
Fireworks tracing gateway with `return_token_ids`, and verifies that
6+
RemoteRolloutProcessor hydrates `assistant_turn_payloads[*].prompt_token_ids`.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import argparse
12+
import asyncio
13+
import logging
14+
import os
15+
import sys
16+
import socket
17+
import threading
18+
import time
19+
from pathlib import Path
20+
from typing import Any
21+
22+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
23+
24+
import uvicorn
25+
from fastapi import FastAPI
26+
from openai import OpenAI
27+
28+
from eval_protocol import FireworksTracingHttpHandler, InitRequest, RolloutIdFilter, Status
29+
from eval_protocol.models import EvaluationRow, Message
30+
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
31+
from eval_protocol.pytest.types import RolloutProcessorConfig
32+
33+
logger = logging.getLogger("remote_rollout_prompt_token_ids")
34+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
35+
36+
37+
def _free_port() -> int:
38+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
39+
sock.bind(("127.0.0.1", 0))
40+
return int(sock.getsockname()[1])
41+
42+
43+
def _message_to_dict(message: Message | dict[str, Any]) -> dict[str, Any]:
    """Convert one conversation message into a plain dict for the chat API.

    ``Message`` instances are dumped through their own request serializer.
    Anything else is treated as a mapping and copied with ``None``-valued
    entries dropped.
    """
    if not isinstance(message, Message):
        return {key: value for key, value in dict(message).items() if value is not None}
    # NOTE(review): the method name spelling ("mdoel") is kept exactly as in
    # the original call; confirm it matches the Message API.
    return message.dump_mdoel_for_chat_completion_request()
47+
48+
49+
def _make_app(gateway_url: str) -> FastAPI:
    """Build the local FastAPI app that stands in for a remote rollout server.

    The app serves ``GET /`` (used as a readiness probe) and ``POST /init``,
    which runs one chat completion in a background thread and reports the
    outcome through a per-rollout logger.

    Args:
        gateway_url: Passed as ``gateway_base_url`` to the
            ``FireworksTracingHttpHandler`` attached to each rollout logger.
    """
    app = FastAPI()
    server_log = logging.getLogger(f"{__name__}.server")
    server_log.setLevel(logging.INFO)

    # Route functions deliberately carry no docstrings: FastAPI would publish
    # them as endpoint descriptions.
    @app.get("/")
    def health() -> dict[str, str]:
        return {"status": "ok"}

    @app.post("/init")
    def init(req: InitRequest) -> dict[str, str]:
        rollout_id = req.metadata.rollout_id
        rollout_logger = logging.getLogger(f"{__name__}.{rollout_id}")
        rollout_logger.addFilter(RolloutIdFilter(rollout_id))

        # Attach the tracing handler only once per logger.
        has_tracing_handler = any(
            isinstance(existing, FireworksTracingHttpHandler) for existing in rollout_logger.handlers
        )
        if not has_tracing_handler:
            rollout_logger.addHandler(FireworksTracingHttpHandler(gateway_base_url=gateway_url))
        rollout_logger.setLevel(logging.INFO)

        def _worker() -> None:
            try:
                conversation = [_message_to_dict(item) for item in (req.messages or [])]

                params = dict(req.completion_params or {})
                params.pop("base_url", None)

                # Always request token IDs, keeping any caller-supplied extra_body.
                extra_body = dict(params.get("extra_body") or {})
                extra_body["return_token_ids"] = True
                params["extra_body"] = extra_body

                params.setdefault("temperature", 0)
                params.setdefault("max_tokens", 8)

                if not req.model_base_url:
                    raise ValueError("model_base_url is required")
                if not params.get("model"):
                    raise ValueError("completion_params.model is required")

                client = OpenAI(base_url=req.model_base_url, api_key=req.api_key)
                response = client.chat.completions.create(messages=conversation, **params)
                content = response.choices[0].message.content or ""
                logger.info("remote server generated content=%r", content)

                rollout_logger.info(
                    "rollout %s finished",
                    rollout_id,
                    extra={"status": Status.rollout_finished()},
                )
            except Exception as exc:
                # Any failure is reported through the rollout logger with an
                # error status rather than raised out of the thread.
                rollout_logger.exception(
                    "rollout %s failed",
                    rollout_id,
                    extra={"status": Status.rollout_unknown_error(str(exc))},
                )

        # Respond immediately; the completion runs in a daemon thread.
        background = threading.Thread(target=_worker, daemon=True)
        background.start()
        return {"status": "started"}

    return app
104+
105+
106+
def _wait_ready(url: str, timeout_seconds: float = 30.0) -> None:
    """Poll ``url`` until it answers HTTP 200 or the timeout elapses.

    Args:
        url: Address probed with ``GET`` (2-second timeout per attempt).
        timeout_seconds: Total time budget for polling.

    Raises:
        TimeoutError: If no attempt returned status 200 within the budget.
            The last request failure, if any, is chained as the cause.
    """
    import requests

    # A monotonic clock keeps the deadline unaffected by wall-clock changes.
    deadline = time.monotonic() + timeout_seconds
    last_error: Exception | None = None
    while time.monotonic() < deadline:
        try:
            resp = requests.get(url, timeout=2)
        except requests.RequestException as exc:
            # Expected while the server is still starting; keep polling.
            last_error = exc
        else:
            if resp.status_code == 200:
                return
        time.sleep(0.2)
    raise TimeoutError(f"server not ready: {url}") from last_error
119+
120+
121+
async def _run(args: argparse.Namespace) -> None:
    """Run the end-to-end check once and raise if prompt token IDs are missing.

    Starts the local ``/init`` server in a daemon thread, sends a single row
    through ``RemoteRolloutProcessor``, prints a summary of the result, and
    raises ``AssertionError`` when no non-empty ``prompt_token_ids`` list is
    found on the completed row.

    Raises:
        ValueError: If no API key is given via ``--api-key`` or environment.
        AssertionError: If ``prompt_token_ids`` was not hydrated.
    """
    api_key = args.api_key or os.getenv("FIREWORKS_DEV_API_KEY") or os.getenv("FIREWORKS_API_KEY")
    if not api_key:
        raise ValueError("Set FIREWORKS_DEV_API_KEY or FIREWORKS_API_KEY")

    # FireworksTracingHttpHandler reads FIREWORKS_API_KEY.
    os.environ["FIREWORKS_API_KEY"] = api_key
    os.environ["EP_REMOTE_API_KEY"] = api_key

    # Bring up the local server and wait until its health route answers.
    port = args.port or _free_port()
    remote_base_url = f"http://127.0.0.1:{port}"
    server = uvicorn.Server(
        uvicorn.Config(_make_app(args.gateway_url), host="127.0.0.1", port=port, log_level="warning")
    )
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    _wait_ready(f"{remote_base_url}/")

    # Build the single evaluation row that drives the rollout.
    rollout_id = f"rrp-prompt-ids-{int(time.time())}"
    row = EvaluationRow(
        messages=[Message(role="user", content="Reply with exactly: ok")],
    )
    row.input_metadata.row_id = "row-0"
    row.input_metadata.completion_params = {
        "model": args.model,
        "base_url": args.api_base_url,
        "temperature": 0,
        "max_tokens": 8,
    }
    execution = row.execution_metadata
    execution.rollout_id = rollout_id
    execution.invocation_id = "inv-0"
    execution.experiment_id = "fir2-1747-rrp-e2e"
    execution.run_id = "run-0"

    processor = RemoteRolloutProcessor(
        remote_base_url=remote_base_url,
        model_base_url=args.gateway_url,
        include_payloads=True,
        timeout_seconds=args.timeout_seconds,
        poll_interval=args.poll_interval,
    )
    try:
        tasks = processor(
            [row],
            RolloutProcessorConfig(
                completion_params=row.input_metadata.completion_params,
                mcp_config_path="",
                semaphore=asyncio.Semaphore(1),
                steps=1,
            ),
        )
        task = tasks[0]
        completed = await task
    finally:
        # Clean up the processor and stop the server whether or not the
        # rollout succeeded.
        await processor.acleanup()
        server.should_exit = True
        thread.join(timeout=5)

    extra = completed.execution_metadata.extra or {}
    turn_payloads = extra.get("assistant_turn_payloads") or []

    # Prefer the first assistant turn's IDs; fall back to the row-level value.
    prompt_ids = turn_payloads[0].get("prompt_token_ids") if turn_payloads else None
    if prompt_ids is None:
        prompt_ids = extra.get("prompt_token_ids")

    ids_are_list = isinstance(prompt_ids, list)
    print(f"rollout_id={rollout_id}")
    print(f"messages={len(completed.messages)}")
    print(f"assistant_turn_payloads={turn_payloads}")
    print(f"prompt_token_ids_len={len(prompt_ids) if ids_are_list else None}")
    print(f"prompt_token_ids_head={prompt_ids[:8] if ids_are_list else None}")

    if not (ids_are_list and prompt_ids):
        raise AssertionError("RemoteRolloutProcessor did not hydrate prompt_token_ids")
194+
195+
196+
def main() -> None:
    """Parse command-line options and run the end-to-end check.

    The gateway URL, API base URL and model defaults are read from
    ``EP_MODEL_BASE_URL``, ``FIREWORKS_API_BASE_URL`` and
    ``TRACING_E2E_MODEL`` respectively, with hard-coded fallbacks.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--gateway-url",
        default=os.getenv("EP_MODEL_BASE_URL", "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app"),
    )
    parser.add_argument(
        "--api-base-url",
        default=os.getenv("FIREWORKS_API_BASE_URL", "https://dev.api.fireworks.ai/inference/v1"),
    )
    parser.add_argument(
        "--model",
        default=os.getenv(
            "TRACING_E2E_MODEL",
            "accounts/pyroworks-dev/deployments/malaysia2-intended-butterfly",
        ),
    )
    parser.add_argument("--api-key", default=None)
    parser.add_argument("--port", type=int, default=0)
    parser.add_argument("--timeout-seconds", type=float, default=180.0)
    parser.add_argument("--poll-interval", type=float, default=2.0)

    cli_args = parser.parse_args()
    asyncio.run(_run(cli_args))


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)