Skip to content

Commit ac2df2f

Browse files
abrichr and claude authored
fix: TelemetryCallback __bases__ crash + 12 TRL integration tests (#231)
The dynamic `__bases__` assignment that injected TrainerCallback as a base class fails in Python with the error "deallocator differs from object". Fixed by creating a proper subclass at definition time instead.

12 new tests:
- Mock rollout_func: correct keys, count, reward variance
- Config separation: TrainingConfig has no TRL fields, wrapper accepts trl_config
- Wrapper construction: all callback combinations, trl_config passthrough
- TelemetryCallback: importable, fires events

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0acd411 commit ac2df2f

2 files changed

Lines changed: 251 additions & 9 deletions

File tree

openadapt_evals/integrations/trl_callbacks.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,16 +162,17 @@ def on_train_end(
162162

163163

164164
# Register as a TrainerCallback subclass at import time so TRL recognizes it.
165-
# If transformers is not installed, the class still works as a plain object
166-
# (the callback methods are called by name, not by inheritance check in recent
167-
# TRL versions).
165+
# If transformers is installed, wrap with proper inheritance.
166+
# We can't patch __bases__ after the fact (Python doesn't allow it when
167+
# deallocators differ), so we create a subclass instead.
168168
try:
    from transformers import TrainerCallback as _TrainerCallback
except ImportError:
    # transformers is optional: without it TelemetryCallback remains a plain,
    # duck-typed callback object.
    logger.debug(
        "transformers not installed; TelemetryCallback will work as a "
        "duck-typed callback but won't inherit from TrainerCallback"
    )
else:

    class _TelemetryCallbackWithBase(TelemetryCallback, _TrainerCallback):
        """TelemetryCallback with proper TrainerCallback inheritance.

        TelemetryCallback is listed FIRST on purpose: TrainerCallback defines
        every ``on_*`` hook itself, so if it came first in the MRO its hooks
        would shadow the telemetry implementations and no events would be
        emitted.
        """

    # Replace the module-level name so imports get the subclass
    TelemetryCallback = _TelemetryCallbackWithBase  # type: ignore[misc]

tests/test_trl_integration.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
"""Tests for TRL GRPOTrainer integration.
2+
3+
Validates the rollout_func, mock mode, config separation, and wrapper
4+
without requiring a GPU, real model, or WAA server.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from unittest.mock import MagicMock, patch
10+
import pytest
11+
12+
13+
# ---------------------------------------------------------------------------
14+
# Mock rollout_func tests
15+
# ---------------------------------------------------------------------------
16+
17+
18+
class TestMockRolloutFunc:
    """Test the mock rollout function from train_trl_grpo.py."""

    @staticmethod
    def _load_training_script():
        """Load ``scripts/train_trl_grpo.py`` as a module and return it.

        The path is resolved relative to this test file instead of the
        current working directory, so the tests do not depend on where pytest
        is invoked from.
        NOTE(review): assumes this file lives in ``tests/`` next to a sibling
        ``scripts/`` directory -- confirm against the repo layout.
        """
        import importlib.util
        from pathlib import Path

        repo_root = Path(__file__).resolve().parent.parent
        script_path = repo_root / "scripts" / "train_trl_grpo.py"
        spec = importlib.util.spec_from_file_location("train_trl_grpo", script_path)
        # spec_from_file_location may return None (or a spec without a
        # loader); fail with a readable message instead of an AttributeError.
        assert spec is not None and spec.loader is not None, (
            f"Cannot load training script from {script_path}"
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def _make_task_configs(self, n=3):
        """Create ``n`` simple task configs as ``TaskConfig``-spec'd mocks."""
        from openadapt_evals.task_config import TaskConfig

        configs = []
        for i in range(n):
            tc = MagicMock(spec=TaskConfig)
            tc.name = f"Task {i}"
            tc.id = f"task-{i}"
            tc.milestones = [MagicMock() for _ in range(2)]
            tc.max_steps = 10
            configs.append(tc)
        return configs

    def _make_rollout(self, num_generations):
        """Return ``(rollout_func, mock_trainer)`` for the mock rollout.

        ``mock_trainer.args.num_generations`` is set to ``num_generations``.
        """
        mod = self._load_training_script()
        rollout_func = mod.create_mock_rollout_func(self._make_task_configs())

        mock_trainer = MagicMock()
        mock_trainer.args.num_generations = num_generations
        return rollout_func, mock_trainer

    def test_mock_returns_correct_keys(self):
        """Mock rollout returns prompt_ids, completion_ids, logprobs, env_reward."""
        rollout_func, mock_trainer = self._make_rollout(num_generations=4)

        result = rollout_func(["Task 0", "Task 1"], mock_trainer)

        for key in ("prompt_ids", "completion_ids", "logprobs", "env_reward"):
            assert key in result

    def test_mock_returns_correct_count(self):
        """Mock returns num_prompts * num_generations entries."""
        rollout_func, mock_trainer = self._make_rollout(num_generations=4)

        result = rollout_func(["Task 0", "Task 1"], mock_trainer)

        expected = 2 * 4  # 2 prompts * 4 generations
        assert len(result["env_reward"]) == expected
        assert len(result["prompt_ids"]) == expected

    def test_mock_has_reward_variance(self):
        """Mock produces different reward values (needed for GRPO)."""
        # Build the rollout function once and reuse it for every call below.
        rollout_func, mock_trainer = self._make_rollout(num_generations=8)

        # Run multiple times to get reward variance (randomized)
        all_rewards = []
        for _ in range(5):
            result = rollout_func(["Task 0"], mock_trainer)
            all_rewards.extend(result["env_reward"])

        unique_rewards = set(all_rewards)
        assert len(unique_rewards) > 1, (
            f"Mock should produce reward variance, got {unique_rewards}"
        )
104+
105+
106+
# ---------------------------------------------------------------------------
107+
# Config separation tests
108+
# ---------------------------------------------------------------------------
109+
110+
111+
class TestConfigSeparation:
    """Check that TrainingConfig and TRL's GRPOConfig stay cleanly separated."""

    def test_training_config_has_no_trl_fields(self):
        """TrainingConfig must not expose loss_type, gradient_accumulation, etc."""
        from openadapt_evals.training.standalone.config import TrainingConfig

        config = TrainingConfig()
        # Field name -> assertion message; these belong to TRL's GRPOConfig.
        trl_only_fields = {
            "loss_type": "loss_type belongs in GRPOConfig",
            "gradient_accumulation_steps": "belongs in GRPOConfig",
            "per_device_train_batch_size": "belongs in GRPOConfig",
            "bf16": "belongs in GRPOConfig",
            "report_to": "belongs in GRPOConfig",
            "use_vllm": "belongs in GRPOConfig",
        }
        for field_name, message in trl_only_fields.items():
            assert not hasattr(config, field_name), message

    def test_training_config_has_our_fields(self):
        """TrainingConfig must expose the OpenAdapt-specific fields."""
        from openadapt_evals.training.standalone.config import TrainingConfig

        config = TrainingConfig()
        own_fields = (
            "server_url",
            "task_dir",
            "constrained_decoding",
            "max_new_tokens",
            "vision_loss_mode",
            "model_name",
            "use_unsloth",
            "weave_project",
        )
        for field_name in own_fields:
            assert hasattr(config, field_name)

    def test_wrapper_accepts_trl_config(self):
        """The TRL wrapper takes a trl_config keyword argument."""
        from openadapt_evals.training.trl_wrapper import GRPOTrainer
        from openadapt_evals.training.standalone.config import TrainingConfig

        config = TrainingConfig(task_dir="tasks/")

        # Construction must succeed; this test only checks that the value is
        # stored on the wrapper unchanged.
        wrapper = GRPOTrainer(config, trl_config="mock_grpo_config")
        assert wrapper._trl_config == "mock_grpo_config"

    def test_wrapper_defaults_without_trl_config(self):
        """With no trl_config given, the wrapper stores None."""
        from openadapt_evals.training.trl_wrapper import GRPOTrainer
        from openadapt_evals.training.standalone.config import TrainingConfig

        config = TrainingConfig(task_dir="tasks/")
        wrapper = GRPOTrainer(config)
        assert wrapper._trl_config is None  # will build defaults in train()
160+
161+
162+
# ---------------------------------------------------------------------------
163+
# TRL wrapper construction tests
164+
# ---------------------------------------------------------------------------
165+
166+
167+
class TestTRLWrapperConstruction:
    """Construct the wrapper with each combination of lifecycle callbacks."""

    def test_no_callbacks(self):
        """Callbacks that are not supplied are stored as None."""
        from openadapt_evals.training.trl_wrapper import GRPOTrainer
        from openadapt_evals.training.standalone.config import TrainingConfig

        wrapper = GRPOTrainer(TrainingConfig())
        assert wrapper._on_model_loaded is None
        assert wrapper._on_step_complete is None

    def test_all_callbacks(self):
        """Each supplied callback is stored on the matching private attribute."""
        from openadapt_evals.training.trl_wrapper import GRPOTrainer
        from openadapt_evals.training.standalone.config import TrainingConfig

        def noop(*a, **kw):
            return None

        hook_names = (
            "on_model_loaded",
            "on_before_collect",
            "on_rollout_complete",
            "on_step_complete",
        )
        wrapper = GRPOTrainer(
            TrainingConfig(),
            **{hook: noop for hook in hook_names},
        )
        for hook in hook_names:
            assert getattr(wrapper, f"_{hook}") is noop

    def test_trl_config_passthrough(self):
        """TRL config is stored as-is, not translated."""
        from openadapt_evals.training.trl_wrapper import GRPOTrainer
        from openadapt_evals.training.standalone.config import TrainingConfig

        fake_trl_config = MagicMock(loss_type="dapo", output_dir="/tmp/test")

        wrapper = GRPOTrainer(TrainingConfig(), trl_config=fake_trl_config)
        stored = wrapper._trl_config
        assert stored.loss_type == "dapo"
        assert stored.output_dir == "/tmp/test"
207+
208+
209+
# ---------------------------------------------------------------------------
210+
# TelemetryCallback tests
211+
# ---------------------------------------------------------------------------
212+
213+
214+
class TestTelemetryCallback:
    """Test the TRL TelemetryCallback."""

    @staticmethod
    def _callback_class():
        """Return ``TelemetryCallback``, skipping the test if it can't be imported.

        Only the import is guarded. An ImportError raised later, for example
        while constructing the callback, must fail the test instead of being
        reported as a skip.
        """
        try:
            from openadapt_evals.integrations.trl_callbacks import TelemetryCallback
        except ImportError:
            pytest.skip("trl_callbacks not available")
        return TelemetryCallback

    def test_callback_importable(self):
        """The callback can be imported and constructed with no arguments."""
        cb = self._callback_class()()
        assert cb is not None

    def test_callback_fires_events(self):
        """The train-begin and step-end hooks run without a real trainer."""
        cb = self._callback_class()()
        # These should not crash even without a real trainer
        args = MagicMock()
        state = MagicMock()
        state.global_step = 5
        state.log_history = [{"loss": 0.5, "reward_mean": 0.7}]
        control = MagicMock()

        # NOTE(review): capture_event is patched but never asserted on, so
        # this only proves the hooks do not raise. Consider asserting the mock
        # was called once the callback's import style for capture_event is
        # confirmed.
        with patch("openadapt_evals.telemetry.capture_event"):
            cb.on_train_begin(args, state, control)
            cb.on_step_end(args, state, control)

0 commit comments

Comments
 (0)