Skip to content

Commit 4db581e

Browse files
author
Pooya Moradi
committed
Add intermediate eval hook: fire evaluate() every eval_interval outer steps
`eval_interval` was a silently-dead config: even though it's plumbed into tunix's `RLTrainingConfig.eval_every_n_steps`, tunix's `_run_eval` is a no-op unless an `eval_ds` is passed to `trainer.train()`. And even if you do pass one, tunix's default GRPO eval re-runs the full sampled rollout (num_generations responses per prompt), which is ~3hr/eval and impractical for trajectory monitoring. Install a `tunix.sft.hooks.TrainingHooks` subclass that hooks `on_train_step_end`, checks `rl_cluster.global_steps % eval_interval`, and calls maxtext's own `evaluate(...)` (greedy decode + the configured scoring pipeline). Gives matched-step PRE / step_N / POST trajectory logging at near-zero cost beyond the eval itself (which is already fast when `eval_batch_size` is set per commit d536d13). No-op when eval_interval <= 0 or num_test_batches <= 0. Soft-skips with a warning if tunix.sft.hooks isn't importable, so the launcher still works against a stock-only tunix.
1 parent ff05b79 commit 4db581e

2 files changed

Lines changed: 104 additions & 0 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,10 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
693693
max_logging.log("Capturing reference model state before training.")
694694
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
695695

696+
# Wire intermediate eval: fire greedy `evaluate(...)` every `eval_interval`
697+
# outer steps. No-op when eval_interval <= 0 or num_test_batches <= 0.
698+
utils_rl.install_intermediate_eval_hook(rl_cluster, trainer_config, test_dataset)
699+
696700
max_logging.warning("Starting RL training...")
697701
rl_trainer.train(train_dataset)
698702

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,3 +760,103 @@ def parse(
760760
return super().parse(
761761
messages=formatted_messages, add_generation_prompt=add_generation_prompt, is_first_msg=is_first_msg
762762
)
763+
764+
765+
def install_intermediate_eval_hook(
766+
rl_cluster: Any,
767+
trainer_config: Any,
768+
test_dataset: Any,
769+
) -> None:
770+
"""Fire `evaluate(...)` every `eval_interval` outer steps during training.
771+
772+
tunix's `eval_every_n_steps` in `RLTrainingConfig` is silently dead unless
773+
an `eval_ds` is passed to `trainer.train()`, and even then tunix's default
774+
`_run_eval` re-runs the full GRPO rollout (`num_generations` sampled per
775+
prompt), which is ~3hr/eval and impractical for trajectory monitoring.
776+
777+
This hook subclasses `tunix.sft.hooks.TrainingHooks` and at every
778+
`eval_interval` outer step (matched against `rl_cluster.global_steps`)
779+
calls maxtext's `evaluate(...)` — greedy decode + the configured scoring
780+
pipeline — and logs the result. Gives matched-step PRE/INTERMEDIATE/POST
781+
curves without any change to tunix.
782+
783+
No-op if `eval_interval <= 0` or `num_test_batches <= 0` or tunix's hooks
784+
module is unavailable.
785+
"""
786+
if trainer_config.num_test_batches <= 0:
787+
return
788+
eval_interval = int(getattr(trainer_config, "eval_interval", 0))
789+
if eval_interval <= 0:
790+
return
791+
try:
792+
# Soft-import: keeps the launcher usable against a stock-only tunix.
793+
from tunix.sft import hooks as _hk # pylint: disable=import-outside-toplevel
794+
except ImportError:
795+
max_logging.warning("[intermediate-eval] tunix.sft.hooks not importable; skipping hook install.")
796+
return
797+
# Lazy import to avoid the utils_rl <-> evaluate_rl circular import
798+
# (evaluate_rl imports utils_rl at module load time).
799+
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate # pylint: disable=import-outside-toplevel
800+
801+
state: dict = {"last_step_evaluated": -1}
802+
803+
class _IntermediateEvalHook(_hk.TrainingHooks): # type: ignore[name-defined]
804+
"""Fires `evaluate(...)` every `eval_interval` outer steps."""
805+
806+
def on_train_start(self, train_ctx): # noqa: ARG002
807+
del train_ctx
808+
809+
def on_train_end(self, train_ctx): # noqa: ARG002
810+
del train_ctx
811+
812+
def on_train_step_start(self, train_ctx): # noqa: ARG002
813+
del train_ctx
814+
815+
def on_eval_step_start(self, train_ctx): # noqa: ARG002
816+
del train_ctx
817+
818+
def on_eval_step_end(self, train_ctx, *args, **kwargs): # noqa: ARG002
819+
del train_ctx, args, kwargs
820+
821+
def on_train_step_end(self, trainer, step, loss): # noqa: ARG002
822+
"""Fire `evaluate(...)` once per `eval_interval` outer steps."""
823+
del trainer, loss
824+
try:
825+
outer_step = int(rl_cluster.global_steps)
826+
except Exception: # pylint: disable=broad-exception-caught
827+
outer_step = int(step) if step is not None else -1
828+
if outer_step <= 0 or outer_step == state["last_step_evaluated"]:
829+
return
830+
if outer_step % eval_interval != 0:
831+
return
832+
state["last_step_evaluated"] = outer_step
833+
try:
834+
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
835+
trainer_config,
836+
test_dataset,
837+
rl_cluster=rl_cluster,
838+
num_passes=trainer_config.num_eval_passes,
839+
corr_lst=trainer_config.eval_corr_lst,
840+
make_lst=trainer_config.eval_make_lst,
841+
)
842+
max_logging.warning(
843+
f"Intermediate Eval (step={outer_step}): {corr=}, {total=},"
844+
f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
845+
)
846+
except Exception as e: # pylint: disable=broad-exception-caught
847+
max_logging.warning(f"[intermediate-eval] step={outer_step} failed: {e!r}")
848+
849+
# PeftTrainer composes a single training_hooks; install if free, else warn.
850+
try:
851+
actor = rl_cluster.actor_trainer
852+
if getattr(actor, "training_hooks", None) is None:
853+
actor.training_hooks = _IntermediateEvalHook()
854+
max_logging.warning(
855+
f"[intermediate-eval] hook installed: evaluate(...) will fire every {eval_interval} outer steps."
856+
)
857+
else:
858+
max_logging.warning(
859+
"[intermediate-eval] actor.training_hooks already set; skipping install (chain manually if you need both)."
860+
)
861+
except Exception as e: # pylint: disable=broad-exception-caught
862+
max_logging.warning(f"[intermediate-eval] install failed: {e!r}")

0 commit comments

Comments
 (0)