"""Training hooks for post-train RL."""

from typing import Any

from tunix.sft import hooks as _tunix_hooks

from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
from maxtext.utils import max_logging


class RLTrainingHooks(_tunix_hooks.TrainingHooks):
  """Tunix `TrainingHooks` subclass that fires `evaluate(...)` every
  `eval_interval` outer steps during RL training.

  tunix's `eval_every_n_steps` in `RLTrainingConfig` is silently dead unless
  an `eval_ds` is passed to `trainer.train()`, and even then tunix's default
  `_run_eval` re-runs the full GRPO rollout (`num_generations` sampled per
  prompt), which is ~3hr/eval and impractical for trajectory monitoring.

  This class overrides `on_train_step_end`, checks
  `rl_cluster.global_steps % eval_interval`, and calls maxtext's
  `evaluate(...)` — greedy decode + the configured scoring pipeline —
  logging the result. Gives matched-step PRE/INTERMEDIATE/POST curves
  without any change to tunix.

  Args:
    rl_cluster: Object exposing `global_steps`; also forwarded to `evaluate`.
    trainer_config: Config object; `num_eval_passes`, `eval_corr_lst` and
      `eval_make_lst` are read from it when an evaluation fires.
    test_dataset: Dataset forwarded unchanged to `evaluate`.
    eval_interval: Number of outer steps between evaluations. A non-positive
      value disables evaluation entirely.
  """

  def __init__(
      self,
      rl_cluster: Any,
      trainer_config: Any,
      test_dataset: Any,
      eval_interval: int,
  ):
    self._rl_cluster = rl_cluster
    self._trainer_config = trainer_config
    self._test_dataset = test_dataset
    self._eval_interval = eval_interval
    # Last outer step an evaluation was attempted on; used to de-duplicate
    # repeated callbacks that report the same outer step.
    self._last_step_evaluated = -1

  # The five lifecycle methods below are abstract in `tunix.sft.hooks.TrainingHooks`,
  # so subclasses MUST implement them. We have no per-step / per-train work to do here
  # outside `on_train_step_end`, so they're no-op stubs.
  def on_train_start(self, train_ctx):  # noqa: ARG002
    del train_ctx

  def on_train_end(self, train_ctx):  # noqa: ARG002
    del train_ctx

  def on_train_step_start(self, train_ctx):  # noqa: ARG002
    del train_ctx

  def on_eval_step_start(self, train_ctx):  # noqa: ARG002
    del train_ctx

  def on_eval_step_end(self, train_ctx, *args, **kwargs):  # noqa: ARG002
    del train_ctx, args, kwargs

  def _resolve_outer_step(self, step) -> int:
    """Return the current outer step, or -1 when it cannot be determined."""
    try:
      return int(self._rl_cluster.global_steps)
    except Exception:  # pylint: disable=broad-exception-caught
      # `global_steps` was unreadable; fall back to the callback's `step` below.
      pass
    try:
      return int(step)
    except (TypeError, ValueError):
      # Covers `step is None` as well as values that are not int-convertible.
      return -1

  def on_train_step_end(self, trainer, step, loss):  # noqa: ARG002
    """Fire `evaluate(...)` once per `eval_interval` outer steps.

    Never raises: a failing evaluation is logged and swallowed so that it
    cannot break the training step.
    """
    del trainer, loss
    # A non-positive interval means "disabled". Checking it first also keeps
    # the modulo below from dividing by zero (interval == 0) or matching on a
    # negative interval.
    if self._eval_interval <= 0:
      return
    outer_step = self._resolve_outer_step(step)
    if outer_step <= 0 or outer_step == self._last_step_evaluated:
      return
    if outer_step % self._eval_interval != 0:
      return
    # Recorded before evaluating so a failed evaluation is not retried when
    # the same outer step is reported again.
    self._last_step_evaluated = outer_step
    try:
      tc = self._trainer_config
      (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
          tc,
          self._test_dataset,
          rl_cluster=self._rl_cluster,
          num_passes=tc.num_eval_passes,
          corr_lst=tc.eval_corr_lst,
          make_lst=tc.eval_make_lst,
      )
      max_logging.warning(
          f"Intermediate Eval (step={outer_step}): {corr=}, {total=},"
          f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
      )
    except Exception as e:  # pylint: disable=broad-exception-caught
      max_logging.warning(f"[intermediate-eval] step={outer_step} failed: {e!r}")
def install_training_hooks(
    rl_cluster: Any,
    trainer_config: Any,
    test_dataset: Any,
) -> None:
  """Install maxtext's `RLTrainingHooks` on the actor trainer.

  No-op if `eval_interval <= 0` (or is missing, `None`, or not convertible to
  an int) or `num_test_batches <= 0` or tunix's hooks module is unavailable.
  An already-installed `training_hooks` on the actor trainer is left untouched.

  Args:
    rl_cluster: Object whose `actor_trainer.training_hooks` receives the hook.
    trainer_config: Config object; `num_test_batches` and `eval_interval` are
      read here, the rest is consumed by the hook itself.
    test_dataset: Dataset handed to the hook for intermediate evaluation.
  """
  if trainer_config.num_test_batches <= 0:
    return

  raw_interval = getattr(trainer_config, "eval_interval", 0)
  if raw_interval is None:
    # An unset interval means "disabled", same as a missing attribute.
    return
  try:
    eval_interval = int(raw_interval)
  except (TypeError, ValueError):
    # Installing the hook is best-effort; a malformed value must not crash the launcher.
    max_logging.warning(f"[intermediate-eval] invalid eval_interval={raw_interval!r}; skipping hook install.")
    return
  if eval_interval <= 0:
    return

  try:
    # `hooks` hard-imports `tunix.sft.hooks`. If that's missing (stock-only
    # tunix without the SFT hooks API), the import below raises and we
    # soft-skip rather than crash the launcher.
    from maxtext.trainers.post_train.rl.hooks import RLTrainingHooks  # pylint: disable=import-outside-toplevel
  except ImportError:
    max_logging.warning("[intermediate-eval] tunix.sft.hooks not importable; skipping hook install.")
    return

  # PeftTrainer composes a single training_hooks; install if free, else warn.
  try:
    actor = rl_cluster.actor_trainer
    if getattr(actor, "training_hooks", None) is None:
      actor.training_hooks = RLTrainingHooks(rl_cluster, trainer_config, test_dataset, eval_interval)
      max_logging.warning(
          f"[intermediate-eval] hook installed: evaluate(...) will fire every {eval_interval} outer steps."
      )
    else:
      max_logging.warning(
          "[intermediate-eval] actor.training_hooks already set; skipping install (chain manually if you need both)."
      )
  except Exception as e:  # pylint: disable=broad-exception-caught
    max_logging.warning(f"[intermediate-eval] install failed: {e!r}")
"""Tests for post-train RL training hooks."""

import unittest
from types import SimpleNamespace
from unittest import mock

import pytest

from maxtext.trainers.post_train.rl import hooks as rl_hooks
from maxtext.trainers.post_train.rl import utils_rl

# Module-level markers; kept below the imports so every import stays at the top of the file.
pytestmark = [pytest.mark.cpu_only, pytest.mark.post_training]


def _make_trainer_config(**overrides):
  """Build a SimpleNamespace with the trainer-config attributes hooks reads."""
  defaults = {
      "num_test_batches": 5,
      "eval_interval": 10,
      "num_eval_passes": 1,
      "eval_corr_lst": False,
      "eval_make_lst": False,
  }
  defaults.update(overrides)
  return SimpleNamespace(**defaults)


def _make_rl_cluster(global_steps=0):
  """Build a stand-in rl_cluster with `global_steps` and an actor trainer that has no hooks yet."""
  cluster = SimpleNamespace()
  cluster.global_steps = global_steps
  cluster.actor_trainer = SimpleNamespace(training_hooks=None)
  return cluster


class RLTrainingHooksTest(unittest.TestCase):
  """Verify `RLTrainingHooks.on_train_step_end` step-gating + evaluate dispatch."""

  def setUp(self):
    super().setUp()
    eval_patcher = mock.patch.object(rl_hooks, "evaluate")
    self.mock_evaluate = eval_patcher.start()
    self.addCleanup(eval_patcher.stop)
    # evaluate(...) returns ((corr, total, acc, partial_acc, fmt_acc), _).
    self.mock_evaluate.return_value = ((1, 2, 50.0, 50.0, 100.0), None)

  def _build_hook(self, eval_interval=10, global_steps=0):
    cluster = _make_rl_cluster(global_steps=global_steps)
    cfg = _make_trainer_config(eval_interval=eval_interval)
    return rl_hooks.RLTrainingHooks(cluster, cfg, test_dataset=None, eval_interval=eval_interval)

  def test_fires_on_matching_step(self):
    hook = self._build_hook(eval_interval=10, global_steps=10)
    hook.on_train_step_end(trainer=None, step=10, loss=None)
    self.mock_evaluate.assert_called_once()

  def test_skips_when_step_not_multiple_of_interval(self):
    hook = self._build_hook(eval_interval=10, global_steps=7)
    hook.on_train_step_end(trainer=None, step=7, loss=None)
    self.mock_evaluate.assert_not_called()

  def test_skips_when_step_is_zero(self):
    hook = self._build_hook(eval_interval=10, global_steps=0)
    hook.on_train_step_end(trainer=None, step=0, loss=None)
    self.mock_evaluate.assert_not_called()

  def test_dedupes_repeat_calls_on_same_step(self):
    hook = self._build_hook(eval_interval=10, global_steps=10)
    hook.on_train_step_end(trainer=None, step=10, loss=None)
    hook.on_train_step_end(trainer=None, step=10, loss=None)
    self.assertEqual(self.mock_evaluate.call_count, 1)

  def test_swallows_evaluate_exception(self):
    """A failing evaluate shouldn't propagate and break the training step."""
    self.mock_evaluate.side_effect = RuntimeError("boom")
    hook = self._build_hook(eval_interval=10, global_steps=10)
    hook.on_train_step_end(trainer=None, step=10, loss=None)  # must not raise
    # Guard against a vacuous pass: the failing evaluate must actually have been reached.
    self.mock_evaluate.assert_called_once()

  def test_falls_back_to_step_arg_when_global_steps_unreadable(self):
    """When rl_cluster.global_steps raises, use the `step` arg instead."""

    class _ClusterWithBadGlobalSteps:
      """Stand-in rl_cluster whose `global_steps` property always raises."""

      def __init__(self):
        self.actor_trainer = SimpleNamespace(training_hooks=None)

      @property
      def global_steps(self):
        raise RuntimeError("not ready")

    bad_cluster = _ClusterWithBadGlobalSteps()
    cfg = _make_trainer_config(eval_interval=10)
    hook = rl_hooks.RLTrainingHooks(bad_cluster, cfg, test_dataset=None, eval_interval=10)
    hook.on_train_step_end(trainer=None, step=10, loss=None)
    self.mock_evaluate.assert_called_once()


class InstallTrainingHooksTest(unittest.TestCase):
  """Verify `utils_rl.install_training_hooks` gating + attach behavior."""

  def test_noop_when_num_test_batches_nonpositive(self):
    cluster = _make_rl_cluster()
    cfg = _make_trainer_config(num_test_batches=0, eval_interval=10)
    utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
    self.assertIsNone(cluster.actor_trainer.training_hooks)

  def test_noop_when_eval_interval_nonpositive(self):
    cluster = _make_rl_cluster()
    cfg = _make_trainer_config(num_test_batches=5, eval_interval=0)
    utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
    self.assertIsNone(cluster.actor_trainer.training_hooks)

  def test_noop_when_eval_interval_attr_missing(self):
    cluster = _make_rl_cluster()
    cfg = SimpleNamespace(num_test_batches=5)  # no eval_interval attribute
    utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
    self.assertIsNone(cluster.actor_trainer.training_hooks)

  def test_attaches_hook_on_happy_path(self):
    cluster = _make_rl_cluster()
    cfg = _make_trainer_config(num_test_batches=5, eval_interval=10)
    utils_rl.install_training_hooks(cluster, cfg, test_dataset="dummy")
    self.assertIsInstance(cluster.actor_trainer.training_hooks, rl_hooks.RLTrainingHooks)

  def test_does_not_overwrite_existing_training_hooks(self):
    cluster = _make_rl_cluster()
    sentinel = object()
    cluster.actor_trainer.training_hooks = sentinel
    cfg = _make_trainer_config(num_test_batches=5, eval_interval=10)
    utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
    self.assertIs(cluster.actor_trainer.training_hooks, sentinel)

  def test_swallows_importerror_when_hooks_module_missing(self):
    """If `from .hooks import RLTrainingHooks` fails, install soft-skips.

    Setting `sys.modules[name] = None` makes Python's import system raise
    ImportError on the next import attempt for that name (documented behavior).
    """
    cluster = _make_rl_cluster()
    cfg = _make_trainer_config(num_test_batches=5, eval_interval=10)
    with mock.patch.dict("sys.modules", {"maxtext.trainers.post_train.rl.hooks": None}):
      utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
    self.assertIsNone(cluster.actor_trainer.training_hooks)


if __name__ == "__main__":
  unittest.main()