Skip to content

Commit 28b0193

Browse files
abrichr and claude authored
feat: add callback hooks to standalone GRPO trainer (#198)
Four optional callback hooks eliminate the need for monkey-patching:

- on_model_loaded(model, processor): Custom model setup (gradient checkpointing on specific submodules, hook attachment)
- on_before_collect(task_id, env): WAA health checks, tunnel verification, task-specific setup before rollout collection
- on_rollout_complete(rollout, index): Per-rollout W&B logging, screenshot/thought capture
- on_step_complete(step, rollouts, metrics): Per-step W&B logging, early stopping, custom evaluation

All callbacks are keyword-only with None defaults (no-op). Eliminates 3 of 6 monkey-patches reported by customer.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f7a7199 commit 28b0193

1 file changed

Lines changed: 49 additions & 1 deletion

File tree

openadapt_evals/training/standalone/trainer.py

Lines changed: 49 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,14 +45,49 @@ def policy_gradient_loss(current_logps, old_logps, advantages, epsilon=0.2):
4545
class GRPOTrainer:
4646
"""Standalone GRPO trainer with direct WAA HTTP integration."""
4747

48-
def __init__(self, config: TrainingConfig) -> None:
48+
def __init__(
49+
self,
50+
config: TrainingConfig,
51+
*,
52+
on_model_loaded: Any | None = None,
53+
on_before_collect: Any | None = None,
54+
on_rollout_complete: Any | None = None,
55+
on_step_complete: Any | None = None,
56+
) -> None:
57+
"""Initialize the trainer.
58+
59+
Args:
60+
config: Training configuration.
61+
on_model_loaded: ``(model, processor) -> None``
62+
Called after model and processor are loaded but before
63+
training starts. Use for custom setup like enabling
64+
gradient checkpointing on specific submodules or
65+
attaching hooks.
66+
on_before_collect: ``(task_id: str, env: WAADirect) -> None``
67+
Called before each rollout group collection. Use for
68+
WAA health checks, tunnel verification, or task-specific
69+
setup.
70+
on_rollout_complete: ``(rollout: Rollout, index: int) -> None``
71+
Called after each individual rollout. Use for capturing
72+
screenshots, thought traces, or per-rollout W&B logging.
73+
on_step_complete: ``(step: int, rollouts: list[Rollout], metrics: dict) -> None``
74+
Called after each training step with all rollouts and
75+
computed metrics (reward_mean, loss, etc.). Use for
76+
W&B step logging, early stopping, or custom eval.
77+
"""
4978
self._config = config
5079
self._model: Any = None
5180
self._processor: Any = None
5281
self._optimizer: Any = None
5382
self._env: WAADirect | None = None
5483
self._task_configs: dict[str, Any] = {}
5584

85+
# Callback hooks (all optional, default None = no-op)
86+
self._on_model_loaded = on_model_loaded
87+
self._on_before_collect = on_before_collect
88+
self._on_rollout_complete = on_rollout_complete
89+
self._on_step_complete = on_step_complete
90+
5691
# --- Constrained decoding -------------------------------------------
5792

5893
# Regex that matches ALL valid action formats. Allows a free-form
@@ -213,6 +248,9 @@ def _collect_group(self, task_id: str) -> list[Rollout]:
213248
"""Collect N rollouts for one GRPO gradient step."""
214249
assert self._env is not None
215250

251+
if self._on_before_collect is not None:
252+
self._on_before_collect(task_id, self._env)
253+
216254
# Pre-rollout health check: verify WAA is responsive before committing
217255
# to a full group of rollouts (avoids wasting time on a dead server).
218256
probe = self._env.probe()
@@ -237,6 +275,8 @@ def _collect_group(self, task_id: str) -> list[Rollout]:
237275
r = self._collect_rollout(task_id, instruction)
238276
rollouts.append(r)
239277
logger.info("Rollout %d: %d steps, reward=%.2f", i + 1, len(r.steps), r.reward)
278+
if self._on_rollout_complete is not None:
279+
self._on_rollout_complete(r, i)
240280
return rollouts
241281

242282
def _compute_rollout_loss(self, rollout: Rollout, advantage: float, scale: float) -> float:
@@ -384,6 +424,10 @@ def train(self) -> str:
384424
self._config.model_name, load_in_4bit=self._config.load_in_4bit,
385425
lora_r=self._config.lora_r, lora_alpha=self._config.lora_alpha,
386426
lora_checkpoint=self._config.lora_checkpoint)
427+
428+
if self._on_model_loaded is not None:
429+
self._on_model_loaded(self._model, self._processor)
430+
387431
self._optimizer = torch.optim.AdamW(
388432
[p for p in self._model.parameters() if p.requires_grad], lr=self._config.learning_rate)
389433
self._env = WAADirect(server_url=self._config.server_url, screen_size=self._config.screen_size)
@@ -408,6 +452,10 @@ def train(self) -> str:
408452
m.update({"step": step, "task_id": task_id, "elapsed": time.time() - t0, "step_time": time.time() - ts})
409453
logger.info("Step %d/%d: reward=%.2f loss=%.4f time=%.1fs",
410454
step + 1, self._config.num_training_steps, m.get("reward_mean", 0), m.get("loss", 0), m["step_time"])
455+
456+
if self._on_step_complete is not None:
457+
self._on_step_complete(step, rollouts, m)
458+
411459
if (step + 1) % self._config.save_every_steps == 0:
412460
self._save_checkpoint(step + 1)
413461

0 commit comments

Comments
 (0)