Skip to content

Commit 0f6177e

Browse files
author
Pooya Moradi
committed
Add intermediate eval hook: fire evaluate() every eval_interval outer steps
`eval_interval` was a silently-dead config: even though it's plumbed into tunix's `RLTrainingConfig.eval_every_n_steps`, tunix's `_run_eval` is a no-op unless an `eval_ds` is passed to `trainer.train()`. And even if you do pass one, tunix's default GRPO eval re-runs the full sampled rollout (num_generations responses per prompt), which is ~3hr/eval and impractical for trajectory monitoring. Install a `tunix.sft.hooks.TrainingHooks` subclass that hooks `on_train_step_end`, checks `rl_cluster.global_steps % eval_interval`, and calls maxtext's own `evaluate(...)` (greedy decode + the configured scoring pipeline). Gives matched-step PRE / step_N / POST trajectory logging at near-zero cost beyond the eval itself (which is already fast when `eval_batch_size` is set per commit d536d13). No-op when eval_interval <= 0 or num_test_batches <= 0. Soft-skips with a warning if tunix.sft.hooks isn't importable, so the launcher still works against a stock-only tunix.
1 parent ff05b79 commit 0f6177e

3 files changed

Lines changed: 140 additions & 0 deletions

File tree

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Training hooks for post-train RL."""
16+
17+
from typing import Any
18+
19+
from tunix.sft import hooks as _tunix_hooks
20+
21+
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
22+
from maxtext.utils import max_logging
23+
24+
25+
class RLTrainingHooks(_tunix_hooks.TrainingHooks):
26+
"""Tunix `TrainingHooks` subclass that fires `evaluate(...)` every
27+
`eval_interval` outer steps during RL training.
28+
29+
tunix's `eval_every_n_steps` in `RLTrainingConfig` is silently dead unless
30+
an `eval_ds` is passed to `trainer.train()`, and even then tunix's default
31+
`_run_eval` re-runs the full GRPO rollout (`num_generations` sampled per
32+
prompt), which is ~3hr/eval and impractical for trajectory monitoring.
33+
34+
This hook hooks `on_train_step_end`, checks
35+
`rl_cluster.global_steps % eval_interval`, and calls maxtext's
36+
`evaluate(...)` — greedy decode + the configured scoring pipeline —
37+
logging the result. Gives matched-step PRE/INTERMEDIATE/POST curves
38+
without any change to tunix.
39+
"""
40+
41+
def __init__(
42+
self,
43+
rl_cluster: Any,
44+
trainer_config: Any,
45+
test_dataset: Any,
46+
eval_interval: int,
47+
):
48+
self._rl_cluster = rl_cluster
49+
self._trainer_config = trainer_config
50+
self._test_dataset = test_dataset
51+
self._eval_interval = eval_interval
52+
self._last_step_evaluated = -1
53+
54+
def on_train_start(self, train_ctx): # noqa: ARG002
55+
del train_ctx
56+
57+
def on_train_end(self, train_ctx): # noqa: ARG002
58+
del train_ctx
59+
60+
def on_train_step_start(self, train_ctx): # noqa: ARG002
61+
del train_ctx
62+
63+
def on_eval_step_start(self, train_ctx): # noqa: ARG002
64+
del train_ctx
65+
66+
def on_eval_step_end(self, train_ctx, *args, **kwargs): # noqa: ARG002
67+
del train_ctx, args, kwargs
68+
69+
def on_train_step_end(self, trainer, step, loss): # noqa: ARG002
70+
"""Fire `evaluate(...)` once per `eval_interval` outer steps."""
71+
del trainer, loss
72+
try:
73+
outer_step = int(self._rl_cluster.global_steps)
74+
except Exception: # pylint: disable=broad-exception-caught
75+
outer_step = int(step) if step is not None else -1
76+
if outer_step <= 0 or outer_step == self._last_step_evaluated:
77+
return
78+
if outer_step % self._eval_interval != 0:
79+
return
80+
self._last_step_evaluated = outer_step
81+
try:
82+
tc = self._trainer_config
83+
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
84+
tc,
85+
self._test_dataset,
86+
rl_cluster=self._rl_cluster,
87+
num_passes=tc.num_eval_passes,
88+
corr_lst=tc.eval_corr_lst,
89+
make_lst=tc.eval_make_lst,
90+
)
91+
max_logging.warning(
92+
f"Intermediate Eval (step={outer_step}): {corr=}, {total=},"
93+
f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
94+
)
95+
except Exception as e: # pylint: disable=broad-exception-caught
96+
max_logging.warning(f"[intermediate-eval] step={outer_step} failed: {e!r}")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,10 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
693693
max_logging.log("Capturing reference model state before training.")
694694
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
695695

696+
# Wire intermediate eval: fire greedy `evaluate(...)` every `eval_interval`
697+
# outer steps. No-op when eval_interval <= 0 or num_test_batches <= 0.
698+
utils_rl.install_training_hooks(rl_cluster, trainer_config, test_dataset)
699+
696700
max_logging.warning("Starting RL training...")
697701
rl_trainer.train(train_dataset)
698702

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,3 +760,43 @@ def parse(
760760
return super().parse(
761761
messages=formatted_messages, add_generation_prompt=add_generation_prompt, is_first_msg=is_first_msg
762762
)
763+
764+
765+
def install_training_hooks(
766+
rl_cluster: Any,
767+
trainer_config: Any,
768+
test_dataset: Any,
769+
) -> None:
770+
"""Install maxtext's `RLTrainingHooks` on the actor trainer.
771+
772+
No-op if `eval_interval <= 0` or `num_test_batches <= 0` or tunix's hooks
773+
module is unavailable.
774+
"""
775+
if trainer_config.num_test_batches <= 0:
776+
return
777+
eval_interval = int(getattr(trainer_config, "eval_interval", 0))
778+
if eval_interval <= 0:
779+
return
780+
try:
781+
# Soft-import keeps the launcher usable against a stock-only tunix. The
782+
# hooks module hard-imports `tunix.sft.hooks`, so the ImportError surfaces
783+
# here when tunix doesn't have it.
784+
from maxtext.trainers.post_train.rl.hooks import RLTrainingHooks # pylint: disable=import-outside-toplevel
785+
except ImportError:
786+
max_logging.warning("[intermediate-eval] tunix.sft.hooks not importable; skipping hook install.")
787+
return
788+
789+
# PeftTrainer composes a single training_hooks; install if free, else warn.
790+
try:
791+
actor = rl_cluster.actor_trainer
792+
if getattr(actor, "training_hooks", None) is None:
793+
actor.training_hooks = RLTrainingHooks(rl_cluster, trainer_config, test_dataset, eval_interval)
794+
max_logging.warning(
795+
f"[intermediate-eval] hook installed: evaluate(...) will fire every {eval_interval} outer steps."
796+
)
797+
else:
798+
max_logging.warning(
799+
"[intermediate-eval] actor.training_hooks already set; skipping install (chain manually if you need both)."
800+
)
801+
except Exception as e: # pylint: disable=broad-exception-caught
802+
max_logging.warning(f"[intermediate-eval] install failed: {e!r}")

0 commit comments

Comments
 (0)