Skip to content

Commit 1bd5447

Browse files
Preserve per-turn trace payloads (#453)
* Add per-turn payload merge for remote traces Co-authored-by: Cursor <cursoragent@cursor.com> * Remove manual logprobs e2e additions Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 7c82b28 commit 1bd5447

2 files changed

Lines changed: 106 additions & 1 deletion

File tree

eval_protocol/pytest/tracing_utils.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,55 @@ def fetch_traces() -> List[EvaluationRow]:
2828
include_payloads=config.include_payloads,
2929
)
3030

31-
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
31+
def preprocess_traces(rows: List[EvaluationRow]) -> List[EvaluationRow]:
32+
filtered_rows = filter_longest_conversation(rows)
33+
if config.include_payloads and filtered_rows:
34+
_merge_payloads_into_longest_row(filtered_rows[0], rows)
35+
return filtered_rows
36+
37+
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=preprocess_traces)
38+
39+
40+
def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[EvaluationRow]) -> None:
41+
"""
42+
Preserve per-turn payload-derived metadata after selecting the longest trace row.
43+
44+
Each trace row carries payloads for its final assistant turn. The longest row
45+
keeps the full conversation, while its top-level execution metadata remains
46+
the payload metadata for the final completion for backward compatibility.
47+
"""
48+
target_assistants = longest_row.get_assistant_messages()
49+
assistant_turn_payloads = []
50+
51+
for row in sorted(rows, key=lambda item: len(item.messages)):
52+
source = row.last_assistant_message()
53+
source_turn_index = len(row.get_assistant_messages()) - 1
54+
if source_turn_index < 0 or source_turn_index >= len(target_assistants):
55+
continue
56+
57+
if source and source.logprobs and not target_assistants[source_turn_index].logprobs:
58+
target_assistants[source_turn_index].logprobs = source.logprobs
59+
60+
extra = row.execution_metadata.extra or {}
61+
turn_payload = {
62+
key: extra[key]
63+
for key in (
64+
"completion_logprobs",
65+
"completion_token_ids",
66+
"logprobs_metadata",
67+
"routing_matrices",
68+
"routing_metadata",
69+
)
70+
if key in extra
71+
}
72+
if turn_payload:
73+
turn_payload["assistant_turn_index"] = source_turn_index
74+
assistant_turn_payloads.append(turn_payload)
75+
76+
if assistant_turn_payloads:
77+
if longest_row.execution_metadata.extra is None:
78+
longest_row.execution_metadata.extra = {}
79+
longest_row.execution_metadata.extra["assistant_turn_payloads"] = assistant_turn_payloads
3280

3381

3482
def build_fireworks_tracing_url(

tests/pytest/test_tracing_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from eval_protocol.models import EvaluationRow, ExecutionMetadata, Message
2+
from eval_protocol.pytest.tracing_utils import _merge_payloads_into_longest_row
3+
4+
5+
def test_merge_payloads_into_longest_row_preserves_each_assistant_turn():
6+
first_turn_logprobs = {"content": [{"logprob": -0.1}, {"logprob": -0.2}]}
7+
second_turn_logprobs = {"content": [{"logprob": -0.3}]}
8+
first_turn = EvaluationRow(
9+
messages=[
10+
Message(role="user", content="What is 2+2?"),
11+
Message(role="assistant", content="4", logprobs=first_turn_logprobs),
12+
],
13+
execution_metadata=ExecutionMetadata(
14+
extra={
15+
"completion_logprobs": [-0.1, -0.2],
16+
"routing_matrices": ["first-matrix"],
17+
"routing_metadata": {"total_token_count": 1},
18+
},
19+
),
20+
)
21+
second_turn = EvaluationRow(
22+
messages=[
23+
Message(role="user", content="What is 2+2?"),
24+
Message(role="assistant", content="4"),
25+
Message(role="user", content="Use that in a sentence."),
26+
Message(role="assistant", content="4", logprobs=second_turn_logprobs),
27+
],
28+
execution_metadata=ExecutionMetadata(
29+
extra={
30+
"completion_logprobs": [-0.3],
31+
"routing_matrices": ["second-matrix"],
32+
"routing_metadata": {"total_token_count": 1},
33+
},
34+
),
35+
)
36+
37+
_merge_payloads_into_longest_row(second_turn, [first_turn, second_turn])
38+
39+
assistant_messages = second_turn.get_assistant_messages()
40+
assert assistant_messages[0].logprobs == first_turn_logprobs
41+
assert assistant_messages[1].logprobs == second_turn_logprobs
42+
assert second_turn.execution_metadata.extra is not None
43+
assert second_turn.execution_metadata.extra["routing_matrices"] == ["second-matrix"]
44+
assert second_turn.execution_metadata.extra["assistant_turn_payloads"] == [
45+
{
46+
"assistant_turn_index": 0,
47+
"completion_logprobs": [-0.1, -0.2],
48+
"routing_matrices": ["first-matrix"],
49+
"routing_metadata": {"total_token_count": 1},
50+
},
51+
{
52+
"assistant_turn_index": 1,
53+
"completion_logprobs": [-0.3],
54+
"routing_matrices": ["second-matrix"],
55+
"routing_metadata": {"total_token_count": 1},
56+
},
57+
]

0 commit comments

Comments
 (0)