-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathtest_single_turn_rollout_processor.py
More file actions
165 lines (125 loc) · 5.71 KB
/
Copy pathtest_single_turn_rollout_processor.py
File metadata and controls
165 lines (125 loc) · 5.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import asyncio
from types import SimpleNamespace
import pytest
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import SingleTurnRolloutProcessor
class _DummyConfig:
    """Minimal stand-in for the rollout config object handed to the processor.

    Only the two attributes assigned in ``__init__`` are provided:
    ``completion_params`` and ``semaphore``.
    """

    def __init__(self, *, completion_params=None, max_concurrency=10):
        """Build the config.

        Args:
            completion_params: Mapping stored (as a shallow copy) on
                ``self.completion_params``. When omitted, the fake-model
                parameters used throughout this test module are used.
            max_concurrency: Initial value of the ``asyncio.Semaphore``
                stored on ``self.semaphore``.
        """
        if completion_params is None:
            completion_params = {"model": "fake-model", "temperature": 0}
        # Copy so that a caller-supplied dict is never shared between configs.
        self.completion_params = dict(completion_params)
        self.semaphore = asyncio.Semaphore(max_concurrency)
@pytest.mark.asyncio
async def test_single_turn_drops_trailing_assistant_by_default(monkeypatch):
    """A trailing assistant turn is removed from the request by default.

    The processor module's ``ModelResponse``, ``Choices`` and ``acompletion``
    are replaced with local stubs; the stubbed completion records the
    ``messages`` keyword argument it receives so the request payload can be
    inspected after the rollout finishes.
    """
    # Arrange: a conversation that already ends with an assistant message.
    row = EvaluationRow(
        messages=[
            Message(role="user", content="What is 2+2?"),
            Message(role="assistant", content="Old response"),
        ]
    )

    # Filled in by the stubbed completion call below.
    recorded = {}

    # The symbols to patch live on the processor module itself.
    import eval_protocol.pytest.default_single_turn_rollout_process as mod

    class StubChoices:
        """Bare attribute holder standing in for the module's ``Choices``."""

    class StubModelResponse:
        """Response stub with a ``choices[0].message`` and a ``usage`` block."""

        def __init__(self, text: str):
            choice = StubChoices()
            # OpenAI-style message fields.
            choice.message = SimpleNamespace(content=text, tool_calls=None)
            self.choices = [choice]
            # Smallest usage payload with all three counters present.
            self.usage = SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2)

    async def fake_acompletion(**kwargs):
        """Record the outgoing messages and answer with "4"."""
        payload = kwargs.get("messages", [])
        assert payload, "Expected non-empty messages payload"
        recorded["messages"] = payload
        # The old assistant reply must already be gone at this point.
        assert payload[-1]["role"] != "assistant", "Trailing assistant should be dropped by default"
        return StubModelResponse(text="4")

    # Swap in the stubs so the test does not depend on litellm types.
    replacements = (
        ("ModelResponse", StubModelResponse),
        ("Choices", StubChoices),
        ("acompletion", fake_acompletion),
    )
    for attr_name, stub in replacements:
        monkeypatch.setattr(mod, attr_name, stub, raising=True)

    processor = SingleTurnRolloutProcessor()
    config = _DummyConfig()

    # Act
    tasks = processor([row], config)
    out = await tasks[0]

    # Assert: only the user turn was sent.
    sent = recorded["messages"]
    assert len(sent) == 1
    assert sent[0]["role"] == "user"

    # Assert: the fresh reply is the last message of the result.
    last_message = out.messages[-1]
    assert last_message.role == "assistant"
    assert last_message.content == "4"

    # The old assistant turn must not be carried over next to the new one.
    assert [m.role for m in out.messages] == ["user", "assistant"]
@pytest.mark.asyncio
async def test_single_turn_keeps_trailing_assistant_when_disabled(monkeypatch):
    """With ``drop_trailing_assistant_messages=False`` the assistant turn is sent.

    The processor module's ``ModelResponse``, ``Choices`` and ``acompletion``
    are replaced with local stubs; the stubbed completion records the
    ``messages`` keyword argument so the request payload can be checked, and
    the resulting row is expected to hold the two original messages followed
    by the new assistant reply.
    """
    # Arrange dataset row with trailing assistant message
    row = EvaluationRow(
        messages=[
            Message(role="user", content="Say hi"),
            Message(role="assistant", content="Hi!"),
        ]
    )

    # Capture the messages payload passed to the LLM call
    captured = {}

    import eval_protocol.pytest.default_single_turn_rollout_process as mod

    class StubChoices:
        pass

    class StubModelResponse:
        def __init__(self, text: str):
            self.choices = [StubChoices()]
            # Emulate OpenAI-like response.message fields
            self.choices[0].message = SimpleNamespace(content=text, tool_calls=None)
            # Minimal usage payload
            self.usage = SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2)

    async def fake_acompletion(**kwargs):
        msgs = kwargs.get("messages", [])
        # Guard first: indexing an empty payload would raise a bare IndexError
        # and hide the real reason for the failure.
        assert msgs, "Expected non-empty messages payload"
        captured["messages"] = msgs
        # With opt-out, trailing assistant is preserved
        assert msgs[-1]["role"] == "assistant", "Trailing assistant should be kept when dropping is disabled"
        return StubModelResponse(text="Hello again")

    # Monkeypatch the processor module's symbols to avoid dependency on litellm types
    monkeypatch.setattr(mod, "ModelResponse", StubModelResponse, raising=True)
    monkeypatch.setattr(mod, "Choices", StubChoices, raising=True)
    monkeypatch.setattr(mod, "acompletion", fake_acompletion, raising=True)

    processor = SingleTurnRolloutProcessor(drop_trailing_assistant_messages=False)
    config = _DummyConfig()

    # Act
    tasks = processor([row], config)
    out = await tasks[0]

    # Assert: both original messages plus new assistant
    sent_msgs = captured["messages"]
    assert [m["role"] for m in sent_msgs] == ["user", "assistant"]
    assert [m.role for m in out.messages] == ["user", "assistant", "assistant"]
    assert out.messages[-1].content == "Hello again"
@pytest.mark.asyncio
async def test_single_turn_handles_missing_usage_block(monkeypatch):
    """A response whose ``usage`` is ``None`` still yields a completed row.

    The stubbed response sets ``usage`` to ``None``. The test then checks that
    the assistant reply is appended and that ``execution_metadata.usage`` on
    the resulting row is ``None``.
    """
    row = EvaluationRow(messages=[Message(role="user", content="Describe the picture")])

    import eval_protocol.pytest.default_single_turn_rollout_process as mod

    class StubChoices:
        """Bare attribute holder standing in for the module's ``Choices``."""

    class StubModelResponse:
        """Response stub carrying one choice and no usage information."""

        def __init__(self, text: str):
            choice = StubChoices()
            choice.message = SimpleNamespace(content=text, tool_calls=None)
            self.choices = [choice]
            # The case under test: the provider reported no usage.
            self.usage = None

    async def fake_acompletion(**kwargs):
        """Return a canned reply regardless of the request."""
        return StubModelResponse(text="It looks like creme brulee")

    class StubLogger:
        """In-memory logger that keeps every logged row in a list."""

        def __init__(self):
            self.logged = []

        def log(self, row):
            self.logged.append(row)

        def read(self, rollout_id=None):
            # Hand back a copy so callers cannot mutate the stored list.
            return list(self.logged)

    stub_logger = StubLogger()

    # These three symbols must exist on the module, hence raising=True.
    required_patches = (
        ("ModelResponse", StubModelResponse),
        ("Choices", StubChoices),
        ("acompletion", fake_acompletion),
    )
    for attr_name, stub in required_patches:
        monkeypatch.setattr(mod, attr_name, stub, raising=True)
    # raising=False: set the logger even if the module lacks this attribute.
    monkeypatch.setattr(mod, "default_logger", stub_logger, raising=False)

    processor = SingleTurnRolloutProcessor()
    config = _DummyConfig()

    tasks = processor([row], config)
    out = await tasks[0]

    roles = [m.role for m in out.messages]
    assert roles == ["user", "assistant"]
    assert out.messages[-1].content == "It looks like creme brulee"

    # Usage should remain unset when the provider omits it
    assert out.execution_metadata.usage is None