Skip to content

fix: wire on_before_collect and on_rollout_complete callbacks through rollout_func#243

Merged
abrichr merged 3 commits into
mainfrom
fix/hookbridge-wire-callbacks
Mar 29, 2026
Merged

fix: wire on_before_collect and on_rollout_complete callbacks through rollout_func#243
abrichr merged 3 commits into
mainfrom
fix/hookbridge-wire-callbacks

Conversation

@abrichr
Copy link
Copy Markdown
Member

@abrichr abrichr commented Mar 29, 2026

Summary

  • Wire on_before_collect and on_rollout_complete callbacks through make_waa_rollout_func() instead of silently storing them in HookBridge (which only implemented on_step_end)
  • Fire on_before_collect(task_id, env) before each episode and on_rollout_complete(rollout_dict, gen_idx) after each episode, with try/except wrapping so broken callbacks cannot crash training
  • Clean up HookBridge to only handle on_step_complete (the only callback that maps to a TRL TrainerCallback event)

Test plan

  • test_on_before_collect_fires -- verify callback called with correct (task_id, env) args
  • test_on_rollout_complete_fires -- verify callback receives reward and gen_idx
  • test_callbacks_optional -- verify rollout_func works when callbacks are None
  • test_broken_callback_does_not_crash_training -- verify try/except wrapping
  • test_wrapper_passes_callbacks_to_rollout_func -- verify GRPOTrainer forwards callbacks to make_waa_rollout_func via source inspection
  • All 32 TRL tests pass, full suite 1474 passed (7 pre-existing failures unrelated)

🤖 Generated with Claude Code

abrichr and others added 3 commits March 29, 2026 14:44
Add a truncation check after both generation paths (Outlines constrained
and HF unconstrained) in generate_fn. When the output length reaches
max_new_tokens - 1, a warning is logged suggesting to increase
max_new_tokens or enable constrained_decoding. This helps diagnose
cases where the model generates excessively long reasoning that gets
cut off before producing a parseable action.

Also replaced the tautological truncation tests in test_trl_robustness.py
(which reimplemented the check logic inline) with tests that exercise the
actual generate_fn code path by calling it through the rollout function
with mocked torch and model.generate.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… rollout_func

The GRPOTrainer wrapper accepted on_before_collect and on_rollout_complete
callbacks but silently ignored them. HookBridge stored them but only
implemented on_step_end (for on_step_complete). TRL has no pre-rollout
callback event, so these must fire from within make_waa_rollout_func.

Changes:
- Add on_before_collect and on_rollout_complete params to make_waa_rollout_func
- Fire on_before_collect(task_id, env) before each episode
- Fire on_rollout_complete(rollout_dict, gen_idx) after each episode
- Wrap both in try/except so broken callbacks cannot crash training
- Pass callbacks from GRPOTrainer.train() to make_waa_rollout_func
- Remove these two callbacks from HookBridge (keep only on_step_complete)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@abrichr abrichr merged commit fc40bf4 into main Mar 29, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant