Skip to content

Commit 670f1e1

Browse files
author
Pooya Moradi
committed
Add intermediate eval hook: fire evaluate() every eval_interval outer steps
`eval_interval` was a silently-dead config: even though it's plumbed into tunix's `RLTrainingConfig.eval_every_n_steps`, tunix's `_run_eval` is a no-op unless an `eval_ds` is passed to `trainer.train()`. And even if you do pass one, tunix's default GRPO eval re-runs the full sampled rollout (num_generations responses per prompt), which is ~3hr/eval and impractical for trajectory monitoring. Install a `tunix.sft.hooks.TrainingHooks` subclass that hooks `on_train_step_end`, checks `rl_cluster.global_steps % eval_interval`, and calls maxtext's own `evaluate(...)` (greedy decode + the configured scoring pipeline). Gives matched-step PRE / step_N / POST trajectory logging at near-zero cost beyond the eval itself (which is already fast when `eval_batch_size` is set per commit d536d13). No-op when eval_interval <= 0 or num_test_batches <= 0. Soft-skips with a warning if tunix.sft.hooks isn't importable, so the launcher still works against a stock-only tunix.
1 parent ff05b79 commit 670f1e1

4 files changed

Lines changed: 304 additions & 0 deletions

File tree

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Training hooks for post-train RL."""
16+
17+
from typing import Any
18+
19+
from tunix.sft import hooks as _tunix_hooks
20+
21+
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
22+
from maxtext.utils import max_logging
23+
24+
25+
class RLTrainingHooks(_tunix_hooks.TrainingHooks):
  """Tunix `TrainingHooks` subclass that runs `evaluate(...)` periodically.

  tunix's `eval_every_n_steps` in `RLTrainingConfig` is silently dead unless
  an `eval_ds` is passed to `trainer.train()`, and even then tunix's default
  `_run_eval` re-runs the full GRPO rollout (`num_generations` sampled per
  prompt), which is ~3hr/eval and impractical for trajectory monitoring.

  This class overrides `on_train_step_end`, checks
  `rl_cluster.global_steps % eval_interval`, and calls maxtext's
  `evaluate(...)` — greedy decode + the configured scoring pipeline —
  logging the result. Gives matched-step PRE/INTERMEDIATE/POST curves
  without any change to tunix.

  Args:
    rl_cluster: Object exposing `global_steps`; also forwarded to `evaluate`.
    trainer_config: Config forwarded to `evaluate`; must expose
      `num_eval_passes`, `eval_corr_lst` and `eval_make_lst`.
    test_dataset: Dataset forwarded to `evaluate`.
    eval_interval: Evaluate whenever the outer step is a positive multiple of
      this value. A value `<= 0` disables the hook entirely.
  """

  def __init__(
      self,
      rl_cluster: Any,
      trainer_config: Any,
      test_dataset: Any,
      eval_interval: int,
  ):
    self._rl_cluster = rl_cluster
    self._trainer_config = trainer_config
    self._test_dataset = test_dataset
    self._eval_interval = eval_interval
    # Last outer step an evaluation was attempted on; -1 means "never".
    self._last_step_evaluated = -1

  # The five lifecycle methods below are abstract in `tunix.sft.hooks.TrainingHooks`,
  # so subclasses MUST implement them. We have no per-step / per-train work to do here
  # outside `on_train_step_end`, so they're no-op stubs.
  def on_train_start(self, train_ctx):  # noqa: ARG002
    del train_ctx

  def on_train_end(self, train_ctx):  # noqa: ARG002
    del train_ctx

  def on_train_step_start(self, train_ctx):  # noqa: ARG002
    del train_ctx

  def on_eval_step_start(self, train_ctx):  # noqa: ARG002
    del train_ctx

  def on_eval_step_end(self, train_ctx, *args, **kwargs):  # noqa: ARG002
    del train_ctx, args, kwargs

  def on_train_step_end(self, trainer, step, loss):  # noqa: ARG002
    """Fire `evaluate(...)` once per `eval_interval` outer steps.

    The outer step is read from `rl_cluster.global_steps`; if that read
    raises, the `step` argument is used instead. Nothing happens when the
    interval is non-positive, when the outer step is non-positive, when it is
    not a multiple of the interval, or when it was already handled.

    Any exception raised by `evaluate(...)` is logged and swallowed so that a
    failing evaluation cannot break the training step.

    Args:
      trainer: Unused.
      step: Fallback step counter, used only if `global_steps` is unreadable.
      loss: Unused.
    """
    del trainer, loss
    # A zero interval would raise ZeroDivisionError in the modulo below (which
    # is outside any try block), and a negative one would still fire on
    # multiples. Treat both as "disabled", matching the installer's contract.
    if self._eval_interval <= 0:
      return
    try:
      outer_step = int(self._rl_cluster.global_steps)
    except Exception:  # pylint: disable=broad-exception-caught
      outer_step = int(step) if step is not None else -1
    if outer_step <= 0 or outer_step == self._last_step_evaluated:
      return
    if outer_step % self._eval_interval != 0:
      return
    # Recorded before evaluating, so a failed eval is not retried on this step.
    self._last_step_evaluated = outer_step
    try:
      tc = self._trainer_config
      (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
          tc,
          self._test_dataset,
          rl_cluster=self._rl_cluster,
          num_passes=tc.num_eval_passes,
          corr_lst=tc.eval_corr_lst,
          make_lst=tc.eval_make_lst,
      )
      max_logging.warning(
          f"Intermediate Eval (step={outer_step}): {corr=}, {total=},"
          f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
      )
    except Exception as e:  # pylint: disable=broad-exception-caught
      max_logging.warning(f"[intermediate-eval] step={outer_step} failed: {e!r}")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,10 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
693693
max_logging.log("Capturing reference model state before training.")
694694
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
695695

696+
# Wire intermediate eval: fire greedy `evaluate(...)` every `eval_interval`
697+
# outer steps. No-op when eval_interval <= 0 or num_test_batches <= 0.
698+
utils_rl.install_training_hooks(rl_cluster, trainer_config, test_dataset)
699+
696700
max_logging.warning("Starting RL training...")
697701
rl_trainer.train(train_dataset)
698702

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,3 +760,43 @@ def parse(
760760
return super().parse(
761761
messages=formatted_messages, add_generation_prompt=add_generation_prompt, is_first_msg=is_first_msg
762762
)
763+
764+
765+
def install_training_hooks(
    rl_cluster: Any,
    trainer_config: Any,
    test_dataset: Any,
) -> None:
  """Install maxtext's `RLTrainingHooks` on the actor trainer.

  No-op if `num_test_batches <= 0`, if `eval_interval` is missing, `None` or
  `<= 0`, or if the hooks module cannot be imported. An already-populated
  `actor_trainer.training_hooks` is left untouched. Failures while attaching
  are logged as warnings and never raised.

  Args:
    rl_cluster: Cluster whose `actor_trainer.training_hooks` slot is filled.
    trainer_config: Config providing `num_test_batches` and, optionally,
      `eval_interval`; it is also handed to the hook.
    test_dataset: Dataset handed to the hook.
  """
  if trainer_config.num_test_batches <= 0:
    return
  # `or 0` makes an explicit `eval_interval=None` behave like a missing
  # attribute instead of raising TypeError from `int(None)`.
  eval_interval = int(getattr(trainer_config, "eval_interval", 0) or 0)
  if eval_interval <= 0:
    return
  try:
    # `hooks` hard-imports `tunix.sft.hooks`. If that's missing (stock-only
    # tunix without the SFT hooks API), the import below raises and we
    # soft-skip rather than crash the launcher.
    from maxtext.trainers.post_train.rl.hooks import RLTrainingHooks  # pylint: disable=import-outside-toplevel
  except ImportError as e:
    # Report the real exception: the failing import is not necessarily
    # `tunix.sft.hooks` itself.
    max_logging.warning(f"[intermediate-eval] tunix.sft.hooks not importable; skipping hook install. ({e!r})")
    return

  # PeftTrainer composes a single training_hooks; install if free, else warn.
  try:
    actor = rl_cluster.actor_trainer
    if getattr(actor, "training_hooks", None) is None:
      actor.training_hooks = RLTrainingHooks(rl_cluster, trainer_config, test_dataset, eval_interval)
      max_logging.warning(
          f"[intermediate-eval] hook installed: evaluate(...) will fire every {eval_interval} outer steps."
      )
    else:
      max_logging.warning(
          "[intermediate-eval] actor.training_hooks already set; skipping install (chain manually if you need both)."
      )
  except Exception as e:  # pylint: disable=broad-exception-caught
    max_logging.warning(f"[intermediate-eval] install failed: {e!r}")
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for post-train RL training hooks."""
16+
17+
import unittest
18+
from types import SimpleNamespace
19+
from unittest import mock
20+
21+
import pytest
22+
23+
pytestmark = [pytest.mark.cpu_only, pytest.mark.post_training]
24+
25+
from maxtext.trainers.post_train.rl import hooks as rl_hooks
26+
from maxtext.trainers.post_train.rl import utils_rl
27+
28+
29+
def _make_trainer_config(**overrides):
30+
"""Build a SimpleNamespace with the trainer-config attributes hooks reads."""
31+
defaults = {
32+
"num_test_batches": 5,
33+
"eval_interval": 10,
34+
"num_eval_passes": 1,
35+
"eval_corr_lst": False,
36+
"eval_make_lst": False,
37+
}
38+
defaults.update(overrides)
39+
return SimpleNamespace(**defaults)
40+
41+
42+
def _make_rl_cluster(global_steps=0):
43+
cluster = SimpleNamespace()
44+
cluster.global_steps = global_steps
45+
cluster.actor_trainer = SimpleNamespace(training_hooks=None)
46+
return cluster
47+
48+
49+
class RLTrainingHooksTest(unittest.TestCase):
  """Verify `RLTrainingHooks.on_train_step_end` step-gating + evaluate dispatch."""

  def setUp(self):
    super().setUp()
    # Replace the `evaluate` name bound in the hooks module so no real eval runs.
    eval_patcher = mock.patch.object(rl_hooks, "evaluate")
    self.mock_evaluate = eval_patcher.start()
    self.addCleanup(eval_patcher.stop)
    # evaluate(...) returns ((corr, total, acc, partial_acc, fmt_acc), _).
    self.mock_evaluate.return_value = ((1, 2, 50.0, 50.0, 100.0), None)

  def _build_hook(self, eval_interval=10, global_steps=0):
    """Build a hook over a stub cluster reporting `global_steps`."""
    cluster = _make_rl_cluster(global_steps=global_steps)
    cfg = _make_trainer_config(eval_interval=eval_interval)
    return rl_hooks.RLTrainingHooks(cluster, cfg, test_dataset=None, eval_interval=eval_interval)

  def test_fires_on_matching_step(self):
    hook = self._build_hook(eval_interval=10, global_steps=10)
    hook.on_train_step_end(trainer=None, step=10, loss=None)
    self.mock_evaluate.assert_called_once()

  def test_skips_when_step_not_multiple_of_interval(self):
    hook = self._build_hook(eval_interval=10, global_steps=7)
    hook.on_train_step_end(trainer=None, step=7, loss=None)
    self.mock_evaluate.assert_not_called()

  def test_skips_when_step_is_zero(self):
    hook = self._build_hook(eval_interval=10, global_steps=0)
    hook.on_train_step_end(trainer=None, step=0, loss=None)
    self.mock_evaluate.assert_not_called()

  def test_dedupes_repeat_calls_on_same_step(self):
    hook = self._build_hook(eval_interval=10, global_steps=10)
    hook.on_train_step_end(trainer=None, step=10, loss=None)
    hook.on_train_step_end(trainer=None, step=10, loss=None)
    self.assertEqual(self.mock_evaluate.call_count, 1)

  def test_swallows_evaluate_exception(self):
    """A failing evaluate shouldn't propagate and break the training step."""
    self.mock_evaluate.side_effect = RuntimeError("boom")
    hook = self._build_hook(eval_interval=10, global_steps=10)
    hook.on_train_step_end(trainer=None, step=10, loss=None)  # must not raise
    # Without this check the test would also pass if gating skipped the call
    # and the failure path were never reached.
    self.mock_evaluate.assert_called_once()

  def test_falls_back_to_step_arg_when_global_steps_unreadable(self):
    """When rl_cluster.global_steps raises, use the `step` arg instead."""

    class _ClusterWithBadGlobalSteps:
      """Stand-in rl_cluster whose `global_steps` property always raises."""

      def __init__(self):
        self.actor_trainer = SimpleNamespace(training_hooks=None)

      @property
      def global_steps(self):
        raise RuntimeError("not ready")

    bad_cluster = _ClusterWithBadGlobalSteps()
    cfg = _make_trainer_config(eval_interval=10)
    hook = rl_hooks.RLTrainingHooks(bad_cluster, cfg, test_dataset=None, eval_interval=10)
    hook.on_train_step_end(trainer=None, step=10, loss=None)
    self.mock_evaluate.assert_called_once()
110+
111+
112+
class InstallTrainingHooksTest(unittest.TestCase):
  """Exercise the gating and attach logic of `utils_rl.install_training_hooks`."""

  def _install_on_fresh_cluster(self, cfg, test_dataset=None):
    """Run the installer against a new stub cluster and return its `training_hooks`."""
    fresh = _make_rl_cluster()
    utils_rl.install_training_hooks(fresh, cfg, test_dataset=test_dataset)
    return fresh.actor_trainer.training_hooks

  def test_noop_when_num_test_batches_nonpositive(self):
    installed = self._install_on_fresh_cluster(_make_trainer_config(num_test_batches=0, eval_interval=10))
    self.assertIsNone(installed)

  def test_noop_when_eval_interval_nonpositive(self):
    installed = self._install_on_fresh_cluster(_make_trainer_config(num_test_batches=5, eval_interval=0))
    self.assertIsNone(installed)

  def test_noop_when_eval_interval_attr_missing(self):
    # Bare namespace: deliberately carries no `eval_interval` attribute.
    installed = self._install_on_fresh_cluster(SimpleNamespace(num_test_batches=5))
    self.assertIsNone(installed)

  def test_attaches_hook_on_happy_path(self):
    installed = self._install_on_fresh_cluster(
        _make_trainer_config(num_test_batches=5, eval_interval=10),
        test_dataset="dummy",
    )
    self.assertIsInstance(installed, rl_hooks.RLTrainingHooks)

  def test_does_not_overwrite_existing_training_hooks(self):
    occupied = _make_rl_cluster()
    preexisting = object()
    occupied.actor_trainer.training_hooks = preexisting
    utils_rl.install_training_hooks(
        occupied,
        _make_trainer_config(num_test_batches=5, eval_interval=10),
        test_dataset=None,
    )
    self.assertIs(occupied.actor_trainer.training_hooks, preexisting)

  def test_swallows_importerror_when_hooks_module_missing(self):
    """If `from .hooks import RLTrainingHooks` fails, install soft-skips.

    Setting `sys.modules[name] = None` makes Python's import system raise
    ImportError on the next import attempt for that name (documented behavior).
    """
    hidden_module = {"maxtext.trainers.post_train.rl.hooks": None}
    config = _make_trainer_config(num_test_batches=5, eval_interval=10)
    with mock.patch.dict("sys.modules", hidden_module):
      installed = self._install_on_fresh_cluster(config)
    self.assertIsNone(installed)
158+
159+
160+
if __name__ == "__main__":
161+
unittest.main()

0 commit comments

Comments
 (0)