Skip to content

Commit c1b42ca

Browse files
author
Pooya Moradi
committed
Add intermediate eval hook: fire evaluate() every eval_interval outer steps
`eval_interval` was a silently-dead config: even though it's plumbed into tunix's `RLTrainingConfig.eval_every_n_steps`, tunix's `_run_eval` is a no-op unless an `eval_ds` is passed to `trainer.train()`. And even if you do pass one, tunix's default GRPO eval re-runs the full sampled rollout (num_generations responses per prompt), which is ~3hr/eval and impractical for trajectory monitoring. Install a `tunix.sft.hooks.TrainingHooks` subclass that hooks `on_train_step_end`, checks `rl_cluster.global_steps % eval_interval`, and calls maxtext's own `evaluate(...)` (greedy decode + the configured scoring pipeline). Gives matched-step PRE / step_N / POST trajectory logging at near-zero cost beyond the eval itself (which is already fast when `eval_batch_size` is set per commit d536d13). No-op when eval_interval <= 0 or num_test_batches <= 0. Soft-skips with a warning if tunix.sft.hooks isn't importable, so the launcher still works against a stock-only tunix.
1 parent ff05b79 commit c1b42ca

1 file changed

Lines changed: 101 additions & 0 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,103 @@ def _use_raw_prompt(x):
352352
return train_dataset, test_dataset
353353

354354

355+
def _install_intermediate_eval_hook(
356+
rl_cluster: Any,
357+
trainer_config: Any,
358+
test_dataset: Any,
359+
) -> None:
360+
"""Fire `evaluate(...)` every `eval_interval` outer steps during training.
361+
362+
tunix's `eval_every_n_steps` in `RLTrainingConfig` is silently dead unless
363+
an `eval_ds` is passed to `trainer.train()`, and even then tunix's default
364+
`_run_eval` re-runs the full GRPO rollout (`num_generations` sampled per
365+
prompt), which is ~3hr/eval and impractical for trajectory monitoring.
366+
367+
This hook subclasses `tunix.sft.hooks.TrainingHooks` and at every
368+
`eval_interval` outer step (matched against `rl_cluster.global_steps`)
369+
calls maxtext's `evaluate(...)` — greedy decode + the configured scoring
370+
pipeline — and logs the result. Gives matched-step PRE/INTERMEDIATE/POST
371+
curves without any change to tunix.
372+
373+
No-op if `eval_interval <= 0` or `num_test_batches <= 0` or tunix's hooks
374+
module is unavailable.
375+
"""
376+
if trainer_config.num_test_batches <= 0:
377+
return
378+
eval_interval = int(getattr(trainer_config, "eval_interval", 0))
379+
if eval_interval <= 0:
380+
return
381+
try:
382+
# Soft-import: keeps the launcher usable against a stock-only tunix.
383+
from tunix.sft import hooks as _hk # pylint: disable=import-outside-toplevel
384+
except ImportError:
385+
max_logging.warning("[intermediate-eval] tunix.sft.hooks not importable; skipping hook" " install.")
386+
return
387+
388+
state: dict = {"last_step_evaluated": -1}
389+
390+
class _IntermediateEvalHook(_hk.TrainingHooks): # type: ignore[name-defined]
391+
"""Fires `evaluate(...)` every `eval_interval` outer steps."""
392+
393+
def on_train_start(self, train_ctx): # noqa: ARG002
394+
del train_ctx
395+
396+
def on_train_end(self, train_ctx): # noqa: ARG002
397+
del train_ctx
398+
399+
def on_train_step_start(self, train_ctx): # noqa: ARG002
400+
del train_ctx
401+
402+
def on_eval_step_start(self, train_ctx): # noqa: ARG002
403+
del train_ctx
404+
405+
def on_eval_step_end(self, train_ctx, *args, **kwargs): # noqa: ARG002
406+
del train_ctx, args, kwargs
407+
408+
def on_train_step_end(self, trainer, step, loss): # noqa: ARG002
409+
"""Fire `evaluate(...)` once per `eval_interval` outer steps."""
410+
del trainer, loss
411+
try:
412+
outer_step = int(rl_cluster.global_steps)
413+
except Exception: # pylint: disable=broad-exception-caught
414+
outer_step = int(step) if step is not None else -1
415+
if outer_step <= 0 or outer_step == state["last_step_evaluated"]:
416+
return
417+
if outer_step % eval_interval != 0:
418+
return
419+
state["last_step_evaluated"] = outer_step
420+
try:
421+
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
422+
trainer_config,
423+
test_dataset,
424+
rl_cluster=rl_cluster,
425+
num_passes=trainer_config.num_eval_passes,
426+
corr_lst=trainer_config.eval_corr_lst,
427+
make_lst=trainer_config.eval_make_lst,
428+
)
429+
max_logging.warning(
430+
f"Intermediate Eval (step={outer_step}): {corr=}, {total=},"
431+
f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
432+
)
433+
except Exception as e: # pylint: disable=broad-exception-caught
434+
max_logging.warning(f"[intermediate-eval] step={outer_step} failed: {e!r}")
435+
436+
# PeftTrainer composes a single training_hooks; install if free, else warn.
437+
try:
438+
actor = rl_cluster.actor_trainer
439+
if getattr(actor, "training_hooks", None) is None:
440+
actor.training_hooks = _IntermediateEvalHook()
441+
max_logging.warning(
442+
"[intermediate-eval] hook installed: evaluate(...) will fire every" f" {eval_interval} outer steps."
443+
)
444+
else:
445+
max_logging.warning(
446+
"[intermediate-eval] actor.training_hooks already set; skipping" " install (chain manually if you need both)."
447+
)
448+
except Exception as e: # pylint: disable=broad-exception-caught
449+
max_logging.warning(f"[intermediate-eval] install failed: {e!r}")
450+
451+
355452
def create_rl_components(
356453
trainer_config,
357454
sampler_config,
@@ -693,6 +790,10 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
693790
max_logging.log("Capturing reference model state before training.")
694791
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
695792

793+
# Wire intermediate eval: fire greedy `evaluate(...)` every `eval_interval`
794+
# outer steps. No-op when eval_interval <= 0 or num_test_batches <= 0.
795+
_install_intermediate_eval_hook(rl_cluster, trainer_config, test_dataset)
796+
696797
max_logging.warning("Starting RL training...")
697798
rl_trainer.train(train_dataset)
698799

0 commit comments

Comments
 (0)