@@ -45,14 +45,49 @@ def policy_gradient_loss(current_logps, old_logps, advantages, epsilon=0.2):
4545class GRPOTrainer :
4646 """Standalone GRPO trainer with direct WAA HTTP integration."""
4747
48- def __init__ (self , config : TrainingConfig ) -> None :
48+ def __init__ (
49+ self ,
50+ config : TrainingConfig ,
51+ * ,
52+ on_model_loaded : Any | None = None ,
53+ on_before_collect : Any | None = None ,
54+ on_rollout_complete : Any | None = None ,
55+ on_step_complete : Any | None = None ,
56+ ) -> None :
57+ """Initialize the trainer.
58+
59+ Args:
60+ config: Training configuration.
61+ on_model_loaded: ``(model, processor) -> None``
62+ Called after model and processor are loaded but before
63+ training starts. Use for custom setup like enabling
64+ gradient checkpointing on specific submodules or
65+ attaching hooks.
66+ on_before_collect: ``(task_id: str, env: WAADirect) -> None``
67+ Called before each rollout group collection. Use for
68+ WAA health checks, tunnel verification, or task-specific
69+ setup.
70+ on_rollout_complete: ``(rollout: Rollout, index: int) -> None``
71+ Called after each individual rollout. Use for capturing
72+ screenshots, thought traces, or per-rollout W&B logging.
73+ on_step_complete: ``(step: int, rollouts: list[Rollout], metrics: dict) -> None``
74+ Called after each training step with all rollouts and
75+ computed metrics (reward_mean, loss, etc.). Use for
76+ W&B step logging, early stopping, or custom eval.
77+ """
4978 self ._config = config
5079 self ._model : Any = None
5180 self ._processor : Any = None
5281 self ._optimizer : Any = None
5382 self ._env : WAADirect | None = None
5483 self ._task_configs : dict [str , Any ] = {}
5584
85+ # Callback hooks (all optional, default None = no-op)
86+ self ._on_model_loaded = on_model_loaded
87+ self ._on_before_collect = on_before_collect
88+ self ._on_rollout_complete = on_rollout_complete
89+ self ._on_step_complete = on_step_complete
90+
5691 # --- Constrained decoding -------------------------------------------
5792
5893 # Regex that matches ALL valid action formats. Allows a free-form
@@ -213,6 +248,9 @@ def _collect_group(self, task_id: str) -> list[Rollout]:
213248 """Collect N rollouts for one GRPO gradient step."""
214249 assert self ._env is not None
215250
251+ if self ._on_before_collect is not None :
252+ self ._on_before_collect (task_id , self ._env )
253+
216254 # Pre-rollout health check: verify WAA is responsive before committing
217255 # to a full group of rollouts (avoids wasting time on a dead server).
218256 probe = self ._env .probe ()
@@ -237,6 +275,8 @@ def _collect_group(self, task_id: str) -> list[Rollout]:
237275 r = self ._collect_rollout (task_id , instruction )
238276 rollouts .append (r )
239277 logger .info ("Rollout %d: %d steps, reward=%.2f" , i + 1 , len (r .steps ), r .reward )
278+ if self ._on_rollout_complete is not None :
279+ self ._on_rollout_complete (r , i )
240280 return rollouts
241281
242282 def _compute_rollout_loss (self , rollout : Rollout , advantage : float , scale : float ) -> float :
@@ -384,6 +424,10 @@ def train(self) -> str:
384424 self ._config .model_name , load_in_4bit = self ._config .load_in_4bit ,
385425 lora_r = self ._config .lora_r , lora_alpha = self ._config .lora_alpha ,
386426 lora_checkpoint = self ._config .lora_checkpoint )
427+
428+ if self ._on_model_loaded is not None :
429+ self ._on_model_loaded (self ._model , self ._processor )
430+
387431 self ._optimizer = torch .optim .AdamW (
388432 [p for p in self ._model .parameters () if p .requires_grad ], lr = self ._config .learning_rate )
389433 self ._env = WAADirect (server_url = self ._config .server_url , screen_size = self ._config .screen_size )
@@ -408,6 +452,10 @@ def train(self) -> str:
408452 m .update ({"step" : step , "task_id" : task_id , "elapsed" : time .time () - t0 , "step_time" : time .time () - ts })
409453 logger .info ("Step %d/%d: reward=%.2f loss=%.4f time=%.1fs" ,
410454 step + 1 , self ._config .num_training_steps , m .get ("reward_mean" , 0 ), m .get ("loss" , 0 ), m ["step_time" ])
455+
456+ if self ._on_step_complete is not None :
457+ self ._on_step_complete (step , rollouts , m )
458+
411459 if (step + 1 ) % self ._config .save_every_steps == 0 :
412460 self ._save_checkpoint (step + 1 )
413461
0 commit comments