Commit fc40bf4
fix: wire on_before_collect and on_rollout_complete callbacks through rollout_func (#243)
* fix: add truncation warning to TRL generate paths
Add a truncation check after both generation paths (Outlines constrained
and HF unconstrained) in generate_fn. When the output length reaches
max_new_tokens - 1, a warning is logged suggesting that the user increase
max_new_tokens or enable constrained_decoding. This helps diagnose
cases where the model generates excessively long reasoning that gets
cut off before producing a parseable action.
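The truncation check described above could look roughly like the sketch below. The helper name `warn_if_truncated` and the logger name are hypothetical; only the `max_new_tokens - 1` threshold and the two suggested remedies come from this commit message.

```python
import logging

logger = logging.getLogger("truncation_sketch")  # hypothetical logger name


def warn_if_truncated(num_generated_tokens: int, max_new_tokens: int) -> bool:
    """Log a warning and return True when generation likely hit the token cap."""
    # Threshold from the commit message: output length reaching max_new_tokens - 1.
    truncated = num_generated_tokens >= max_new_tokens - 1
    if truncated:
        logger.warning(
            "Output length %d reached max_new_tokens=%d; the action may have "
            "been cut off. Consider increasing max_new_tokens or enabling "
            "constrained_decoding.",
            num_generated_tokens,
            max_new_tokens,
        )
    return truncated
```

Returning the boolean as well as logging keeps the check easy to assert on without inspecting log output.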
Also replaced the tautological truncation tests in test_trl_robustness.py
(which reimplemented the check logic inline) with tests that exercise the
actual generate_fn code path by calling it through the rollout function
with mocked torch and model.generate.
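A minimal sketch of that testing approach, assuming a simplified stand-in for generate_fn: the test calls the function under test with a mocked `model.generate` instead of re-implementing the comparison inline. The signature and the `count_warnings` helper are hypothetical and are not the repository's actual test code.

```python
import logging
from unittest.mock import MagicMock

logger = logging.getLogger("generate_fn_sketch")  # hypothetical logger name


def generate_fn(model, prompt_ids, max_new_tokens):
    """Hypothetical stand-in for the code under test."""
    output_ids = model.generate(prompt_ids, max_new_tokens=max_new_tokens)
    new_tokens = output_ids[len(prompt_ids):]
    if len(new_tokens) >= max_new_tokens - 1:
        logger.warning("output may be truncated (max_new_tokens=%d)", max_new_tokens)
    return new_tokens


def count_warnings(num_new_tokens, max_new_tokens):
    """Run generate_fn against a mocked model and count the warnings it logs."""
    prompt_ids = [101, 102]
    model = MagicMock()
    # The mock returns the prompt followed by a fixed number of new tokens.
    model.generate.return_value = prompt_ids + [7] * num_new_tokens
    records = []
    handler = logging.Handler()
    handler.emit = records.append  # collect records instead of printing them
    logger.addHandler(handler)
    try:
        generate_fn(model, prompt_ids, max_new_tokens)
    finally:
        logger.removeHandler(handler)
    model.generate.assert_called_once()
    return sum(r.levelno == logging.WARNING for r in records)
```

Because the assertion is on what the real function logs, the test fails if the check is removed or its threshold changes, which a tautological inline check cannot detect.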
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: wire on_before_collect and on_rollout_complete callbacks through rollout_func
The GRPOTrainer wrapper accepted on_before_collect and on_rollout_complete
callbacks but silently ignored them. HookBridge stored them but only
implemented on_step_end (for on_step_complete). TRL has no pre-rollout
callback event, so these must fire from within make_waa_rollout_func.
Changes:
- Add on_before_collect and on_rollout_complete params to make_waa_rollout_func
- Fire on_before_collect(task_id, env) before each episode
- Fire on_rollout_complete(rollout_dict, gen_idx) after each episode
- Wrap both in try/except so broken callbacks cannot crash training
- Pass callbacks from GRPOTrainer.train() to make_waa_rollout_func
- Remove these two callbacks from HookBridge (keep only on_step_complete)
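The changes listed above can be sketched as follows. This is an illustration under stated assumptions, not the actual make_waa_rollout_func: the names `make_rollout_func_sketch`, `safe_call`, and `run_episode` are hypothetical, and only the callback argument shapes (`task_id, env` and `rollout_dict, gen_idx`) and the try/except guard are taken from this commit message.

```python
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger("rollout_callbacks_sketch")  # hypothetical logger name


def safe_call(callback: Optional[Callable[..., Any]], *args: Any) -> None:
    """Invoke an optional callback; log and swallow any exception it raises."""
    if callback is None:
        return
    try:
        callback(*args)
    except Exception:
        # A broken user callback must not crash the training loop.
        logger.exception("Callback %r failed; continuing", callback)


def make_rollout_func_sketch(run_episode, env,
                             on_before_collect=None, on_rollout_complete=None):
    """Build a rollout function that fires both callbacks around each episode."""
    def rollout_func(task_ids):
        rollouts = []
        for gen_idx, task_id in enumerate(task_ids):
            safe_call(on_before_collect, task_id, env)         # before the episode
            rollout = run_episode(task_id, env)                 # hypothetical runner
            safe_call(on_rollout_complete, rollout, gen_idx)    # after the episode
            rollouts.append(rollout)
        return rollouts
    return rollout_func
```

Firing the callbacks from inside the rollout function sidesteps the missing pre-rollout event in TRL, and the guard means a failing callback costs only a logged traceback.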
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 36ac839 commit fc40bf4
4 files changed
Lines changed: 268 additions & 35 deletions
File tree
- openadapt_evals/training
- tests