Skip to content

Commit b6c1af9

Browse files
committed
feat: preserve message token ids
1 parent b3b02c8 commit b6c1af9

4 files changed

Lines changed: 64 additions & 3 deletions

File tree

eval_protocol/models.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from openai.types.chat.chat_completion_message_tool_call import (
1515
ChatCompletionMessageToolCall,
1616
)
17-
from pydantic import BaseModel, ConfigDict, Field
17+
from pydantic import BaseModel, ConfigDict, Field, model_validator
1818

1919
from eval_protocol.get_pep440_version import get_pep440_version
2020
from eval_protocol.human_id import generate_id
@@ -517,6 +517,13 @@ class Message(BaseModel):
517517
function_call: Optional[FunctionCall] = None
518518
control_plane_step: Optional[Dict[str, Any]] = None
519519
weight: Optional[int] = None
520+
token_ids: Optional[List[int]] = Field(
521+
default=None,
522+
description=(
523+
"Optional token IDs for this message. When set on assistant messages, "
524+
"these should come from the same generation call as logprobs."
525+
),
526+
)
520527
logprobs: Optional[Any] = Field(
521528
default=None,
522529
description=(
@@ -529,9 +536,21 @@ def dump_mdoel_for_chat_completion_request(self):
529536
"""Only keep chat completion accepted fields"""
530537
return self.model_dump(
531538
exclude_none=True,
532-
exclude={"control_plane_step", "reasoning_content", "weight", "logprobs"},
539+
exclude={"control_plane_step", "reasoning_content", "weight", "token_ids", "logprobs"},
533540
)
534541

542+
@model_validator(mode="after")
543+
def _validate_token_ids_logprobs_alignment(self) -> "Message":
544+
if self.token_ids is None or self.logprobs is None:
545+
return self
546+
if isinstance(self.logprobs, list) and all(isinstance(lp, (int, float)) for lp in self.logprobs):
547+
if len(self.token_ids) != len(self.logprobs):
548+
raise ValueError(
549+
"token_ids and float logprobs must have the same length "
550+
f"(got {len(self.token_ids)} token_ids and {len(self.logprobs)} logprobs)"
551+
)
552+
return self
553+
535554
@classmethod
536555
def model_validate(cls, obj, *args, **kwargs):
537556
if isinstance(obj, dict):

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,28 @@ def _serialize_logprobs(logprobs: Any) -> Any:
4444
return logprobs
4545

4646

47+
def _extract_token_ids_from_logprobs(logprobs: Any) -> List[int] | None:
48+
"""Extract token IDs from a serialized provider logprobs payload when present."""
49+
50+
if not isinstance(logprobs, dict):
51+
return None
52+
53+
content = logprobs.get("content")
54+
if isinstance(content, list) and content:
55+
token_ids: List[int] = []
56+
for item in content:
57+
if not isinstance(item, dict) or item.get("token_id") is None:
58+
return None
59+
token_ids.append(int(item["token_id"]))
60+
return token_ids
61+
62+
raw_token_ids = logprobs.get("token_ids")
63+
if isinstance(raw_token_ids, list) and raw_token_ids:
64+
return [int(token_id) for token_id in raw_token_ids]
65+
66+
return None
67+
68+
4769
class SingleTurnRolloutProcessor(RolloutProcessor):
4870
"""Single turn rollout processor for direct LLM calls."""
4971

@@ -136,6 +158,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
136158
assistant_message = response.choices[0].message
137159
finish_reason = getattr(response.choices[0], "finish_reason", None)
138160
assistant_logprobs = _serialize_logprobs(getattr(response.choices[0], "logprobs", None))
161+
assistant_token_ids = _extract_token_ids_from_logprobs(assistant_logprobs)
139162

140163
# Extract content
141164
assistant_content = assistant_message.content or ""
@@ -190,6 +213,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
190213
content=assistant_content,
191214
reasoning_content=reasoning_content,
192215
tool_calls=converted_tool_calls,
216+
token_ids=assistant_token_ids,
193217
logprobs=assistant_logprobs,
194218
)
195219
]

tests/test_eval_protocol_import.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,23 @@ def test_message_creation(self):
262262
assert msg.role == "user"
263263
assert msg.content == "Test message"
264264

265+
def test_message_preserves_token_ids(self):
266+
"""Test token IDs round-trip on messages."""
267+
from eval_protocol import Message
268+
269+
msg = Message(role="assistant", content="Hi", token_ids=[1, 2], logprobs=[-0.1, -0.2])
270+
assert msg.model_dump()["token_ids"] == [1, 2]
271+
272+
def test_message_rejects_misaligned_float_logprobs(self):
273+
"""Test token IDs and flat float logprobs must align."""
274+
import pytest
275+
from pydantic import ValidationError
276+
277+
from eval_protocol import Message
278+
279+
with pytest.raises(ValidationError):
280+
Message(role="assistant", content="Hi", token_ids=[1, 2], logprobs=[-0.1])
281+
265282
def test_utility_functions(self):
266283
"""Test that utility functions work through eval_protocol."""
267284
from eval_protocol import create_llm_resource, load_jsonl

tests/test_rollout_logprobs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_single_turn_rollout_captures_logprobs(monkeypatch):
2828
async def fake_acompletion(**kwargs):
2929
assert kwargs["logprobs"] is True
3030
assert kwargs["top_logprobs"] == 2
31-
logprobs = {"content": [{"token": "hello", "logprob": -0.1, "top_logprobs": []}]}
31+
logprobs = {"content": [{"token": "hello", "token_id": 15339, "logprob": -0.1, "top_logprobs": []}]}
3232
return ModelResponse(
3333
id="resp-1",
3434
choices=[
@@ -53,6 +53,7 @@ async def _run() -> None:
5353
assistant_logprobs = completed_rows[0].messages[-1].logprobs
5454
assert isinstance(assistant_logprobs, dict)
5555
assert assistant_logprobs["content"][0]["token"] == "hello"
56+
assert completed_rows[0].messages[-1].token_ids == [15339]
5657
assert assistant_logprobs["content"][0]["logprob"] == -0.1
5758

5859
asyncio.run(_run())

0 commit comments

Comments
 (0)