Skip to content

Commit fc40bf4

Browse files
abrichr and claude
authored
fix: wire on_before_collect and on_rollout_complete callbacks through rollout_func (#243)
* fix: add truncation warning to TRL generate paths

Add a truncation check after both generation paths (Outlines constrained and HF unconstrained) in generate_fn. When the output length reaches max_new_tokens - 1, a warning is logged suggesting to increase max_new_tokens or enable constrained_decoding. This helps diagnose cases where the model generates excessively long reasoning that gets cut off before producing a parseable action.

Also replaced the tautological truncation tests in test_trl_robustness.py (which reimplemented the check logic inline) with tests that exercise the actual generate_fn code path by calling it through the rollout function with mocked torch and model.generate.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: wire on_before_collect and on_rollout_complete callbacks through rollout_func

The GRPOTrainer wrapper accepted on_before_collect and on_rollout_complete callbacks but silently ignored them. HookBridge stored them but only implemented on_step_end (for on_step_complete). TRL has no pre-rollout callback event, so these must fire from within make_waa_rollout_func.

Changes:
- Add on_before_collect and on_rollout_complete params to make_waa_rollout_func
- Fire on_before_collect(task_id, env) before each episode
- Fire on_rollout_complete(rollout_dict, gen_idx) after each episode
- Wrap both in try/except so broken callbacks cannot crash training
- Pass callbacks from GRPOTrainer.train() to make_waa_rollout_func
- Remove these two callbacks from HookBridge (keep only on_step_complete)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 36ac839 commit fc40bf4

4 files changed

Lines changed: 268 additions & 35 deletions

File tree

openadapt_evals/training/trl_rollout.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
from openadapt_evals.training.standalone.prompt import SYSTEM_PROMPT # noqa: E402
8888

8989
# ---------------------------------------------------------------------------
90-
# Constrained decoding regex ported from standalone trainer
90+
# Constrained decoding regex -- ported from standalone trainer
9191
# ---------------------------------------------------------------------------
9292
# Matches the ``Thought: <reasoning>\nAction: <action>`` format.
9393
# All repetitions use unbounded quantifiers (+, *) instead of bounded ({1,N})
@@ -106,12 +106,12 @@
106106
# ---------------------------------------------------------------------------
107107
# When the model is SFT'd on JSON format (not DSL), switch constrained
108108
# decoding to: outlines.json(model, _AgentOutput) instead of regex.
109-
# This is NOT the default the default uses DSL regex (ACTION_REGEX).
109+
# This is NOT the default -- the default uses DSL regex (ACTION_REGEX).
110110
class _AgentOutput(BaseModel):
111111
"""Pydantic schema for Outlines JSON-mode constrained decoding.
112112
113113
Use with: ``outlines.json(model, _AgentOutput)`` once the model has
114-
been SFT'd on JSON action format. Currently unused default is DSL
114+
been SFT'd on JSON action format. Currently unused -- default is DSL
115115
regex via ACTION_REGEX.
116116
"""
117117

@@ -349,7 +349,7 @@ def _run_episode(
349349
if step_result.done:
350350
break
351351

352-
# Evaluate dense rewards if milestones, binary otherwise
352+
# Evaluate -- dense rewards if milestones, binary otherwise
353353
reward = env.evaluate_dense()
354354

355355
return prompt_ids, all_completion_ids, all_logprobs, reward
@@ -365,6 +365,8 @@ def make_waa_rollout_func(
365365
screenshot_retries: int = 3,
366366
screenshot_retry_delay: float = 1.0,
367367
stuck_threshold: int = 3,
368+
on_before_collect: Optional[Callable] = None,
369+
on_rollout_complete: Optional[Callable] = None,
368370
) -> Callable:
369371
"""Create a TRL-compatible rollout_func for WAA environments.
370372
@@ -389,6 +391,14 @@ def make_waa_rollout_func(
389391
stuck_threshold: Number of consecutive identical screenshots before
390392
breaking an episode early. Set to 0 to disable stuck detection.
391393
Ported from the standalone trainer's WAADirect.is_stuck().
394+
on_before_collect: ``(task_id, env) -> None`` callback fired before
395+
each episode begins. Useful for health checks, logging, or
396+
pre-rollout setup. A raised exception is caught and logged as
397+
a warning (does not abort the episode).
398+
on_rollout_complete: ``(rollout, index) -> None`` callback fired
399+
after each episode completes. ``rollout`` is a dict with keys
400+
``prompt``, ``task_id``, ``reward``, ``gen_idx``. A raised
401+
exception is caught and logged as a warning.
392402
393403
Returns:
394404
A callable suitable for GRPOTrainer(rollout_func=...).
@@ -423,10 +433,6 @@ def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
423433
num_generations = getattr(trainer.args, "num_generations", 8)
424434

425435
# --- Pre-rollout health check (P0) ---
426-
# Verify WAA server is responsive before committing GPU time to a
427-
# full batch of rollouts. Ported from standalone trainer's
428-
# _collect_group() which calls probe() before each group.
429-
# Skip for mock adapters (unittest.mock.MagicMock or WAAMockAdapter).
430436
_mod = getattr(type(adapter), "__module__", "") or ""
431437
_name = type(adapter).__name__.lower()
432438
_is_mock = "mock" in _name or "mock" in _mod
@@ -470,10 +476,6 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
470476
"""Generate action tokens from screenshot + instruction."""
471477
from PIL import Image
472478

473-
# --- Corrupt screenshot retry (P0) ---
474-
# On Azure VMs with QEMU, ~1-5% of screenshots are corrupt.
475-
# Retry with a brief delay rather than crashing the entire
476-
# rollout. Ported from standalone trainer's _collect_rollout().
477479
img = None
478480
for attempt in range(screenshot_retries):
479481
try:
@@ -507,7 +509,6 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
507509
]},
508510
]
509511

510-
# Tokenize with processor
511512
import torch
512513

513514
text_input = processor.apply_chat_template(
@@ -524,14 +525,10 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
524525
max_new_tokens=max_new_tokens,
525526
temperature=temperature,
526527
)
527-
# Tokenize the decoded text to get token IDs
528528
inner_tok = getattr(processor, "tokenizer", processor)
529529
completion_ids = inner_tok.encode(
530530
decoded, add_special_tokens=False,
531531
)
532-
# Outlines does not return per-token logprobs, so we
533-
# return empty logprobs. TRL recomputes logprobs from
534-
# the model during the training step anyway.
535532
logprobs: list[float] = []
536533

537534
# Truncation warning — detect when output was cut off
@@ -551,7 +548,6 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
551548
return_tensors="pt", padding=True,
552549
).to(device)
553550

554-
# Generate
555551
with torch.no_grad():
556552
outputs = model.generate(
557553
**inputs,
@@ -562,19 +558,16 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
562558
output_scores=True,
563559
)
564560

565-
# Extract completion tokens (everything after prompt)
566561
prompt_len = inputs["input_ids"].shape[1]
567562
completion_ids = outputs.sequences[0][prompt_len:].tolist()
568563

569-
# Compute per-token logprobs from scores
570564
logprobs = []
571565
if hasattr(outputs, "scores") and outputs.scores:
572566
for i, score in enumerate(outputs.scores):
573567
probs = torch.nn.functional.log_softmax(score[0], dim=-1)
574568
if i < len(completion_ids):
575569
logprobs.append(probs[completion_ids[i]].item())
576570

577-
# Decode text
578571
text = processor.decode(completion_ids, skip_special_tokens=True)
579572

580573
# Truncation warning — detect when output was cut off
@@ -589,14 +582,24 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
589582
return text, completion_ids, logprobs
590583

591584
for prompt in prompts:
592-
# Find matching task config
593585
tc = config_map.get(prompt)
594586

595587
for gen_idx in range(num_generations):
596588
env = RLEnvironment(adapter, task_config=tc)
597589

598590
task_id = tc.id if tc else "default"
599591

592+
# --- on_before_collect callback ---
593+
if on_before_collect is not None:
594+
try:
595+
on_before_collect(task_id, env)
596+
except Exception as exc:
597+
logger.warning(
598+
"on_before_collect callback raised for "
599+
"task_id=%s gen=%d: %s",
600+
task_id, gen_idx, exc,
601+
)
602+
600603
try:
601604
p_ids, c_ids, lps, reward = _run_episode(
602605
env, generate_fn, prompt, task_id, max_steps,
@@ -608,6 +611,25 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
608611
)
609612
p_ids, c_ids, lps, reward = [], [], [], 0.0
610613

614+
# --- on_rollout_complete callback ---
615+
if on_rollout_complete is not None:
616+
try:
617+
on_rollout_complete(
618+
{
619+
"prompt": prompt,
620+
"task_id": task_id,
621+
"reward": reward,
622+
"gen_idx": gen_idx,
623+
},
624+
gen_idx,
625+
)
626+
except Exception as exc:
627+
logger.warning(
628+
"on_rollout_complete callback raised for "
629+
"task_id=%s gen=%d: %s",
630+
task_id, gen_idx, exc,
631+
)
632+
611633
all_prompt_ids.append(p_ids)
612634
all_completion_ids.append(c_ids)
613635
all_logprobs.append(lps)

openadapt_evals/training/trl_wrapper.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def train(self) -> str:
166166
constrained_decoding=getattr(self._config, "constrained_decoding", False),
167167
max_new_tokens=self._config.max_new_tokens,
168168
temperature=self._config.temperature,
169+
on_before_collect=self._on_before_collect,
170+
on_rollout_complete=self._on_rollout_complete,
169171
)
170172

171173
# --- Reward ---
@@ -185,25 +187,25 @@ def env_reward_fn(completions, **kwargs):
185187
except ImportError:
186188
pass
187189

188-
if any([self._on_before_collect, self._on_rollout_complete,
189-
self._on_step_complete]):
190+
# on_before_collect and on_rollout_complete are passed directly to
191+
# make_waa_rollout_func (above) because TRL has no pre-rollout
192+
# callback. Only on_step_complete maps to TRL's on_step_end.
193+
if self._on_step_complete:
190194
try:
191195
from transformers import TrainerCallback
192196

193197
class HookBridge(TrainerCallback):
194-
def __init__(self, hooks):
195-
self._hooks = hooks
198+
def __init__(self, on_step_complete):
199+
self._on_step_complete = on_step_complete
196200

197201
def on_step_end(self, args, state, control, **kwargs):
198-
fn = self._hooks.get("on_step_complete")
199-
if fn:
200-
fn(state.global_step, [], kwargs.get("metrics", {}))
201-
202-
callbacks.append(HookBridge({
203-
"on_before_collect": self._on_before_collect,
204-
"on_rollout_complete": self._on_rollout_complete,
205-
"on_step_complete": self._on_step_complete,
206-
}))
202+
if self._on_step_complete:
203+
self._on_step_complete(
204+
state.global_step, [],
205+
kwargs.get("metrics", {}),
206+
)
207+
208+
callbacks.append(HookBridge(self._on_step_complete))
207209
except ImportError:
208210
pass
209211

tests/test_trl_integration.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,58 @@ def test_callback_fires_events(self):
239239
with patch("openadapt_evals.telemetry.capture_event"):
240240
cb.on_train_begin(args, state, control)
241241
cb.on_step_end(args, state, control)
242+
243+
244+
# ---------------------------------------------------------------------------
245+
# Wrapper callback passthrough tests
246+
# ---------------------------------------------------------------------------
247+
248+
249+
class TestWrapperPassesCallbacks:
250+
"""Verify GRPOTrainer passes on_before_collect and on_rollout_complete
251+
through to make_waa_rollout_func, not into HookBridge."""
252+
253+
def test_wrapper_passes_callbacks_to_rollout_func(self):
254+
"""Verify on_before_collect and on_rollout_complete are forwarded
255+
to make_waa_rollout_func as keyword arguments.
256+
257+
Since train() has local imports of heavy dependencies (datasets, trl,
258+
torch), we verify by inspecting the source code of train() to confirm
259+
the kwargs are passed. This avoids needing GPU/torch in CI.
260+
"""
261+
from openadapt_evals.training.trl_wrapper import GRPOTrainer
262+
from openadapt_evals.training.standalone.config import TrainingConfig
263+
import inspect
264+
265+
before_fn = lambda task_id, env: None
266+
complete_fn = lambda rollout, index: None
267+
268+
trainer = GRPOTrainer(
269+
TrainingConfig(task_dir="tasks/"),
270+
on_before_collect=before_fn,
271+
on_rollout_complete=complete_fn,
272+
)
273+
274+
# 1. Verify the stored callbacks match what was passed.
275+
assert trainer._on_before_collect is before_fn
276+
assert trainer._on_rollout_complete is complete_fn
277+
278+
# 2. Verify train() source passes callbacks to make_waa_rollout_func.
279+
source = inspect.getsource(GRPOTrainer.train)
280+
assert "on_before_collect=self._on_before_collect" in source, (
281+
"train() must pass on_before_collect to make_waa_rollout_func"
282+
)
283+
assert "on_rollout_complete=self._on_rollout_complete" in source, (
284+
"train() must pass on_rollout_complete to make_waa_rollout_func"
285+
)
286+
287+
# 3. Verify HookBridge no longer stores these callbacks.
288+
hookbridge_section = source.split("class HookBridge")[1].split(
289+
"callbacks.append"
290+
)[0] if "class HookBridge" in source else ""
291+
assert "on_before_collect" not in hookbridge_section, (
292+
"HookBridge should not store on_before_collect"
293+
)
294+
assert "on_rollout_complete" not in hookbridge_section, (
295+
"HookBridge should not store on_rollout_complete"
296+
)

0 commit comments

Comments (0)