Skip to content

Commit 0bd055d

Browse files
Add per-turn payload merge for remote traces
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 7c82b28 commit 0bd055d

4 files changed

Lines changed: 527 additions & 1 deletion

File tree

eval_protocol/pytest/tracing_utils.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,55 @@ def fetch_traces() -> List[EvaluationRow]:
2828
include_payloads=config.include_payloads,
2929
)
3030

31-
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
31+
def preprocess_traces(rows: List[EvaluationRow]) -> List[EvaluationRow]:
32+
filtered_rows = filter_longest_conversation(rows)
33+
if config.include_payloads and filtered_rows:
34+
_merge_payloads_into_longest_row(filtered_rows[0], rows)
35+
return filtered_rows
36+
37+
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=preprocess_traces)
38+
39+
40+
def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[EvaluationRow]) -> None:
41+
"""
42+
Preserve per-turn payload-derived metadata after selecting the longest trace row.
43+
44+
Each trace row carries payloads for its final assistant turn. The longest row
45+
keeps the full conversation, while its top-level execution metadata remains
46+
the payload metadata for the final completion for backward compatibility.
47+
"""
48+
target_assistants = longest_row.get_assistant_messages()
49+
assistant_turn_payloads = []
50+
51+
for row in sorted(rows, key=lambda item: len(item.messages)):
52+
source = row.last_assistant_message()
53+
source_turn_index = len(row.get_assistant_messages()) - 1
54+
if source_turn_index < 0 or source_turn_index >= len(target_assistants):
55+
continue
56+
57+
if source and source.logprobs and not target_assistants[source_turn_index].logprobs:
58+
target_assistants[source_turn_index].logprobs = source.logprobs
59+
60+
extra = row.execution_metadata.extra or {}
61+
turn_payload = {
62+
key: extra[key]
63+
for key in (
64+
"completion_logprobs",
65+
"completion_token_ids",
66+
"logprobs_metadata",
67+
"routing_matrices",
68+
"routing_metadata",
69+
)
70+
if key in extra
71+
}
72+
if turn_payload:
73+
turn_payload["assistant_turn_index"] = source_turn_index
74+
assistant_turn_payloads.append(turn_payload)
75+
76+
if assistant_turn_payloads:
77+
if longest_row.execution_metadata.extra is None:
78+
longest_row.execution_metadata.extra = {}
79+
longest_row.execution_metadata.extra["assistant_turn_payloads"] = assistant_turn_payloads
3280

3381

3482
def build_fireworks_tracing_url(

tests/manual/test_logprobs_e2e.py

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
"""Minimal e2e test for logprobs trace payloads via RemoteRolloutProcessor.
2+
3+
Spins up the reference remote server locally, which makes the LLM call
4+
through litellm-gateway-dev. RemoteRolloutProcessor polls the dev gateway
5+
and fetches traces with include_payloads=True.
6+
7+
Run with:
8+
cd eval-protocol-python-sdk
9+
FIREWORKS_API_KEY="$FIREWORKS_DEV_API_KEY" \\
10+
pytest tests/manual/test_logprobs_e2e.py -v -s
11+
12+
Requires gateway+consumer dev deploy with logprobs payload support and deployment:
13+
accounts/pyroworks-dev/deployments/malaysia2-careful-paprika
14+
"""
15+
16+
import os
17+
import socket
18+
import subprocess
19+
import sys
20+
import time
21+
from typing import List
22+
23+
import pytest
24+
import requests
25+
26+
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
27+
from eval_protocol.models import EvaluationRow, EvaluateResult, Message, MetricResult
28+
from eval_protocol.pytest import evaluation_test
29+
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
30+
31+
DEPLOYMENT = "accounts/pyroworks-dev/deployments/malaysia2-careful-paprika"
32+
GATEWAY_DEV_URL = "https://litellm-gateway-dev-j4kzagdteq-uc.a.run.app"
33+
FIREWORKS_DEV_INFERENCE_BASE = "https://dev.api.fireworks.ai/inference/v1"
34+
35+
36+
def _find_available_port() -> int:
37+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
38+
s.bind(("", 0))
39+
return s.getsockname()[1]
40+
41+
42+
SERVER_PORT = _find_available_port()
43+
44+
45+
def _wait_for_server(port: int, timeout: int = 30):
46+
start = time.time()
47+
while time.time() - start < timeout:
48+
try:
49+
requests.get(f"http://127.0.0.1:{port}")
50+
return
51+
except requests.exceptions.ConnectionError:
52+
time.sleep(0.5)
53+
raise TimeoutError(f"Remote server did not start within {timeout}s")
54+
55+
56+
@pytest.fixture
57+
def remote_server_module(request):
58+
return getattr(request, "param", "tests.remote_server.remote_server")
59+
60+
61+
@pytest.fixture(autouse=True)
62+
def _remote_server(remote_server_module):
63+
env = os.environ.copy()
64+
env["FW_TRACING_GATEWAY_BASE_URL"] = GATEWAY_DEV_URL
65+
api_key = os.environ.get("FIREWORKS_API_KEY") or os.environ.get("FIREWORKS_DEV_API_KEY")
66+
if api_key:
67+
env["FIREWORKS_API_KEY"] = api_key
68+
proc = subprocess.Popen(
69+
[
70+
sys.executable,
71+
"-m",
72+
remote_server_module,
73+
"--host",
74+
"127.0.0.1",
75+
"--port",
76+
str(SERVER_PORT),
77+
],
78+
env=env,
79+
)
80+
_wait_for_server(SERVER_PORT)
81+
yield
82+
proc.terminate()
83+
proc.wait()
84+
85+
86+
def input_rows() -> List[EvaluationRow]:
87+
return [
88+
EvaluationRow(messages=[Message(role="user", content="What is 2+2?")]),
89+
]
90+
91+
92+
def two_turn_input_rows() -> List[EvaluationRow]:
93+
return [
94+
EvaluationRow(messages=[Message(role="user", content="What is 2+2?")]),
95+
]
96+
97+
98+
def _logprobs_content(message: Message) -> list:
99+
if not message.logprobs:
100+
return []
101+
return message.logprobs.get("content") or []
102+
103+
104+
@pytest.mark.parametrize(
105+
"completion_params",
106+
[
107+
{
108+
"model": DEPLOYMENT,
109+
"logprobs": True,
110+
"base_url": FIREWORKS_DEV_INFERENCE_BASE,
111+
}
112+
],
113+
)
114+
@evaluation_test(
115+
data_loaders=DynamicDataLoader(generators=[input_rows]),
116+
rollout_processor=RemoteRolloutProcessor(
117+
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
118+
model_base_url=GATEWAY_DEV_URL,
119+
include_payloads=True,
120+
timeout_seconds=180,
121+
),
122+
)
123+
async def test_logprobs_present(row: EvaluationRow) -> EvaluationRow:
124+
"""Verify completion logprobs and Message.logprobs after remote rollout."""
125+
126+
has_response = len(row.messages) > 1
127+
assistant_msg = row.messages[-1] if has_response else None
128+
129+
extra = row.execution_metadata.extra or {}
130+
completion_logprobs = extra.get("completion_logprobs") or []
131+
has_completion_logprobs = len(completion_logprobs) > 0
132+
133+
message_content = None
134+
if assistant_msg and assistant_msg.logprobs:
135+
message_content = assistant_msg.logprobs.get("content") or []
136+
137+
has_message_logprobs = message_content is not None and len(message_content) > 0
138+
lengths_match = (
139+
has_completion_logprobs
140+
and has_message_logprobs
141+
and len(message_content) == len(completion_logprobs)
142+
)
143+
144+
if has_completion_logprobs:
145+
print(
146+
f"\n Logprobs OK: {len(completion_logprobs)} completion tokens"
147+
f" | message.content len={len(message_content) if message_content else 0}"
148+
)
149+
else:
150+
print(f"\n No logprobs in extra={extra}")
151+
152+
score = 1.0 if (has_response and has_completion_logprobs and lengths_match) else 0.0
153+
reason_parts = []
154+
if not has_response:
155+
reason_parts.append("no assistant response")
156+
if not has_completion_logprobs:
157+
reason_parts.append("no completion_logprobs in execution_metadata.extra")
158+
if not lengths_match:
159+
reason_parts.append(
160+
f"message.logprobs content length ({len(message_content or [])}) "
161+
f"!= completion_logprobs ({len(completion_logprobs)})"
162+
)
163+
164+
reason = "All checks passed" if score == 1.0 else "; ".join(reason_parts)
165+
166+
row.evaluation_result = EvaluateResult(
167+
score=score,
168+
reason=reason,
169+
metrics={
170+
"has_response": MetricResult(
171+
score=float(has_response),
172+
is_score_valid=True,
173+
reason="got response" if has_response else "no response",
174+
),
175+
"has_completion_logprobs": MetricResult(
176+
score=float(has_completion_logprobs),
177+
is_score_valid=True,
178+
reason="present" if has_completion_logprobs else "missing",
179+
),
180+
"logprobs_lengths_match": MetricResult(
181+
score=float(lengths_match),
182+
is_score_valid=True,
183+
reason="match" if lengths_match else "mismatch",
184+
),
185+
},
186+
)
187+
188+
assert has_response, f"Expected assistant response. Messages: {row.messages}"
189+
assert has_completion_logprobs, (
190+
f"Expected completion_logprobs in extra but got: {row.execution_metadata.extra}"
191+
)
192+
assert lengths_match, (
193+
"Expected len(message.logprobs['content']) == len(completion_logprobs); "
194+
f"got {len(message_content or [])} vs {len(completion_logprobs)}"
195+
)
196+
197+
return row
198+
199+
200+
@pytest.mark.parametrize(
201+
"remote_server_module",
202+
["tests.remote_server.remote_server_two_turn_logprobs"],
203+
indirect=True,
204+
)
205+
@pytest.mark.parametrize(
206+
"completion_params",
207+
[
208+
{
209+
"model": DEPLOYMENT,
210+
"logprobs": True,
211+
"base_url": FIREWORKS_DEV_INFERENCE_BASE,
212+
}
213+
],
214+
)
215+
@evaluation_test(
216+
data_loaders=DynamicDataLoader(generators=[two_turn_input_rows]),
217+
rollout_processor=RemoteRolloutProcessor(
218+
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
219+
model_base_url=GATEWAY_DEV_URL,
220+
include_payloads=True,
221+
timeout_seconds=180,
222+
),
223+
)
224+
async def test_two_turn_logprobs_present(row: EvaluationRow) -> EvaluationRow:
225+
"""Verify each assistant turn in a two-turn remote rollout has logprobs."""
226+
227+
roles = [message.role for message in row.messages]
228+
assistant_messages = row.get_assistant_messages()
229+
logprob_lengths = [len(_logprobs_content(message)) for message in assistant_messages]
230+
231+
has_two_turn_shape = roles == ["user", "assistant", "user", "assistant"]
232+
has_two_assistant_turns = len(assistant_messages) == 2
233+
all_turns_have_logprobs = has_two_assistant_turns and all(length > 0 for length in logprob_lengths)
234+
235+
extra = row.execution_metadata.extra or {}
236+
final_completion_logprobs = extra.get("completion_logprobs") or []
237+
assistant_turn_payloads = extra.get("assistant_turn_payloads") or []
238+
final_lengths_match = (
239+
has_two_assistant_turns
240+
and len(final_completion_logprobs) > 0
241+
and len(final_completion_logprobs) == logprob_lengths[-1]
242+
)
243+
has_payloads_for_each_turn = len(assistant_turn_payloads) == len(assistant_messages)
244+
turn_payload_lengths_match = has_payloads_for_each_turn and all(
245+
payload.get("assistant_turn_index") == idx
246+
and len(payload.get("completion_logprobs") or []) == logprob_lengths[idx]
247+
for idx, payload in enumerate(assistant_turn_payloads)
248+
)
249+
250+
if all_turns_have_logprobs:
251+
print(f"\n Two-turn logprobs OK: assistant token counts={logprob_lengths}")
252+
else:
253+
print(f"\n Missing two-turn logprobs: roles={roles} token_counts={logprob_lengths}")
254+
255+
all_ok = (
256+
has_two_turn_shape
257+
and all_turns_have_logprobs
258+
and final_lengths_match
259+
and turn_payload_lengths_match
260+
)
261+
reason_parts = []
262+
if not has_two_turn_shape:
263+
reason_parts.append(f"expected user/assistant/user/assistant roles but got {roles}")
264+
if not has_two_assistant_turns:
265+
reason_parts.append(f"expected 2 assistant turns but got {len(assistant_messages)}")
266+
if has_two_assistant_turns and not all_turns_have_logprobs:
267+
reason_parts.append(f"missing assistant logprobs; token_counts={logprob_lengths}")
268+
if not final_lengths_match:
269+
reason_parts.append(
270+
"final assistant message logprobs length "
271+
f"({logprob_lengths[-1] if logprob_lengths else 0}) "
272+
f"!= completion_logprobs ({len(final_completion_logprobs)})"
273+
)
274+
if not has_payloads_for_each_turn:
275+
reason_parts.append(f"expected per-turn payloads for each assistant turn but got {assistant_turn_payloads}")
276+
if has_payloads_for_each_turn and not turn_payload_lengths_match:
277+
reason_parts.append(f"per-turn payload lengths do not match message logprobs: {assistant_turn_payloads}")
278+
279+
row.evaluation_result = EvaluateResult(
280+
score=1.0 if all_ok else 0.0,
281+
reason="All checks passed" if all_ok else "; ".join(reason_parts),
282+
metrics={
283+
"has_two_turn_shape": MetricResult(
284+
score=float(has_two_turn_shape),
285+
is_score_valid=True,
286+
reason="match" if has_two_turn_shape else "unexpected roles",
287+
),
288+
"all_turns_have_logprobs": MetricResult(
289+
score=float(all_turns_have_logprobs),
290+
is_score_valid=True,
291+
reason="present" if all_turns_have_logprobs else "missing",
292+
),
293+
"final_logprobs_lengths_match": MetricResult(
294+
score=float(final_lengths_match),
295+
is_score_valid=True,
296+
reason="match" if final_lengths_match else "mismatch",
297+
),
298+
"turn_payload_lengths_match": MetricResult(
299+
score=float(turn_payload_lengths_match),
300+
is_score_valid=True,
301+
reason="match" if turn_payload_lengths_match else "mismatch",
302+
),
303+
},
304+
)
305+
306+
assert has_two_turn_shape, f"Expected two-turn conversation but got roles: {roles}"
307+
assert all_turns_have_logprobs, (
308+
"Expected logprobs on both assistant turns; "
309+
f"token_counts={logprob_lengths}, messages={row.messages}"
310+
)
311+
assert final_lengths_match, (
312+
"Expected final assistant logprobs to match completion_logprobs; "
313+
f"got {logprob_lengths[-1] if logprob_lengths else 0} vs {len(final_completion_logprobs)}"
314+
)
315+
assert turn_payload_lengths_match, (
316+
"Expected assistant_turn_payloads to match each assistant turn's logprobs; "
317+
f"payloads={assistant_turn_payloads}, token_counts={logprob_lengths}"
318+
)
319+
320+
return row

0 commit comments

Comments
 (0)