Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions src/maxtext/trainers/post_train/rl/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training hooks for post-train RL."""

from typing import Any

from tunix.sft import hooks as _tunix_hooks

from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
from maxtext.utils import max_logging


class RLTrainingHooks(_tunix_hooks.TrainingHooks):
"""Tunix `TrainingHooks` subclass that fires `evaluate(...)` every
`eval_interval` outer steps during RL training.

tunix's `eval_every_n_steps` in `RLTrainingConfig` is silently dead unless
an `eval_ds` is passed to `trainer.train()`, and even then tunix's default
`_run_eval` re-runs the full GRPO rollout (`num_generations` sampled per
prompt), which is ~3hr/eval and impractical for trajectory monitoring.

This hook hooks `on_train_step_end`, checks
`rl_cluster.global_steps % eval_interval`, and calls maxtext's
`evaluate(...)` — greedy decode + the configured scoring pipeline —
logging the result. Gives matched-step PRE/INTERMEDIATE/POST curves
without any change to tunix.
"""

def __init__(
self,
rl_cluster: Any,
trainer_config: Any,
test_dataset: Any,
eval_interval: int,
):
self._rl_cluster = rl_cluster
self._trainer_config = trainer_config
self._test_dataset = test_dataset
self._eval_interval = eval_interval
self._last_step_evaluated = -1

# The five lifecycle methods below are abstract in `tunix.sft.hooks.TrainingHooks`,
# so subclasses MUST implement them. We have no per-step / per-train work to do here
# outside `on_train_step_end`, so they're no-op stubs.
def on_train_start(self, train_ctx): # noqa: ARG002
del train_ctx

def on_train_end(self, train_ctx): # noqa: ARG002
del train_ctx

def on_train_step_start(self, train_ctx): # noqa: ARG002
del train_ctx

def on_eval_step_start(self, train_ctx): # noqa: ARG002
del train_ctx

def on_eval_step_end(self, train_ctx, *args, **kwargs): # noqa: ARG002
del train_ctx, args, kwargs

def on_train_step_end(self, trainer, step, loss): # noqa: ARG002
"""Fire `evaluate(...)` once per `eval_interval` outer steps."""
del trainer, loss
try:
outer_step = int(self._rl_cluster.global_steps)
except Exception: # pylint: disable=broad-exception-caught
outer_step = int(step) if step is not None else -1
if outer_step <= 0 or outer_step == self._last_step_evaluated:
return
if outer_step % self._eval_interval != 0:
return
self._last_step_evaluated = outer_step
try:
tc = self._trainer_config
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
tc,
self._test_dataset,
rl_cluster=self._rl_cluster,
num_passes=tc.num_eval_passes,
corr_lst=tc.eval_corr_lst,
make_lst=tc.eval_make_lst,
)
max_logging.warning(
f"Intermediate Eval (step={outer_step}): {corr=}, {total=},"
f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
)
except Exception as e: # pylint: disable=broad-exception-caught
max_logging.warning(f"[intermediate-eval] step={outer_step} failed: {e!r}")
4 changes: 4 additions & 0 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,10 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
max_logging.log("Capturing reference model state before training.")
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))

# Wire intermediate eval: fire greedy `evaluate(...)` every `eval_interval`
# outer steps. No-op when eval_interval <= 0 or num_test_batches <= 0.
utils_rl.install_training_hooks(rl_cluster, trainer_config, test_dataset)

max_logging.warning("Starting RL training...")
rl_trainer.train(train_dataset)

Expand Down
40 changes: 40 additions & 0 deletions src/maxtext/trainers/post_train/rl/utils_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,3 +760,43 @@ def parse(
return super().parse(
messages=formatted_messages, add_generation_prompt=add_generation_prompt, is_first_msg=is_first_msg
)


def install_training_hooks(
rl_cluster: Any,
trainer_config: Any,
test_dataset: Any,
) -> None:
"""Install maxtext's `RLTrainingHooks` on the actor trainer.

No-op if `eval_interval <= 0` or `num_test_batches <= 0` or tunix's hooks
module is unavailable.
"""
if trainer_config.num_test_batches <= 0:
return
eval_interval = int(getattr(trainer_config, "eval_interval", 0))
if eval_interval <= 0:
return
try:
# `hooks` hard-imports `tunix.sft.hooks`. If that's missing (stock-only
# tunix without the SFT hooks API), the import below raises and we
# soft-skip rather than crash the launcher.
from maxtext.trainers.post_train.rl.hooks import RLTrainingHooks # pylint: disable=import-outside-toplevel
except ImportError:
max_logging.warning("[intermediate-eval] tunix.sft.hooks not importable; skipping hook install.")
return

# PeftTrainer composes a single training_hooks; install if free, else warn.
try:
actor = rl_cluster.actor_trainer
if getattr(actor, "training_hooks", None) is None:
actor.training_hooks = RLTrainingHooks(rl_cluster, trainer_config, test_dataset, eval_interval)
max_logging.warning(
f"[intermediate-eval] hook installed: evaluate(...) will fire every {eval_interval} outer steps."
)
else:
max_logging.warning(
"[intermediate-eval] actor.training_hooks already set; skipping install (chain manually if you need both)."
)
except Exception as e: # pylint: disable=broad-exception-caught
max_logging.warning(f"[intermediate-eval] install failed: {e!r}")
161 changes: 161 additions & 0 deletions tests/post_training/unit/rl_hooks_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for post-train RL training hooks."""

import unittest
from types import SimpleNamespace
from unittest import mock

import pytest

pytestmark = [pytest.mark.cpu_only, pytest.mark.post_training]

from maxtext.trainers.post_train.rl import hooks as rl_hooks
from maxtext.trainers.post_train.rl import utils_rl


def _make_trainer_config(**overrides):
"""Build a SimpleNamespace with the trainer-config attributes hooks reads."""
defaults = {
"num_test_batches": 5,
"eval_interval": 10,
"num_eval_passes": 1,
"eval_corr_lst": False,
"eval_make_lst": False,
}
defaults.update(overrides)
return SimpleNamespace(**defaults)


def _make_rl_cluster(global_steps=0):
cluster = SimpleNamespace()
cluster.global_steps = global_steps
cluster.actor_trainer = SimpleNamespace(training_hooks=None)
return cluster


class RLTrainingHooksTest(unittest.TestCase):
"""Verify `RLTrainingHooks.on_train_step_end` step-gating + evaluate dispatch."""

def setUp(self):
super().setUp()
eval_patcher = mock.patch.object(rl_hooks, "evaluate")
self.mock_evaluate = eval_patcher.start()
self.addCleanup(eval_patcher.stop)
# evaluate(...) returns ((corr, total, acc, partial_acc, fmt_acc), _).
self.mock_evaluate.return_value = ((1, 2, 50.0, 50.0, 100.0), None)

def _build_hook(self, eval_interval=10, global_steps=0):
cluster = _make_rl_cluster(global_steps=global_steps)
cfg = _make_trainer_config(eval_interval=eval_interval)
return rl_hooks.RLTrainingHooks(cluster, cfg, test_dataset=None, eval_interval=eval_interval)

def test_fires_on_matching_step(self):
hook = self._build_hook(eval_interval=10, global_steps=10)
hook.on_train_step_end(trainer=None, step=10, loss=None)
self.mock_evaluate.assert_called_once()

def test_skips_when_step_not_multiple_of_interval(self):
hook = self._build_hook(eval_interval=10, global_steps=7)
hook.on_train_step_end(trainer=None, step=7, loss=None)
self.mock_evaluate.assert_not_called()

def test_skips_when_step_is_zero(self):
hook = self._build_hook(eval_interval=10, global_steps=0)
hook.on_train_step_end(trainer=None, step=0, loss=None)
self.mock_evaluate.assert_not_called()

def test_dedupes_repeat_calls_on_same_step(self):
hook = self._build_hook(eval_interval=10, global_steps=10)
hook.on_train_step_end(trainer=None, step=10, loss=None)
hook.on_train_step_end(trainer=None, step=10, loss=None)
self.assertEqual(self.mock_evaluate.call_count, 1)

def test_swallows_evaluate_exception(self):
"""A failing evaluate shouldn't propagate and break the training step."""
self.mock_evaluate.side_effect = RuntimeError("boom")
hook = self._build_hook(eval_interval=10, global_steps=10)
hook.on_train_step_end(trainer=None, step=10, loss=None) # must not raise

def test_falls_back_to_step_arg_when_global_steps_unreadable(self):
"""When rl_cluster.global_steps raises, use the `step` arg instead."""

class _ClusterWithBadGlobalSteps:
"""Stand-in rl_cluster whose `global_steps` property always raises."""

def __init__(self):
self.actor_trainer = SimpleNamespace(training_hooks=None)

@property
def global_steps(self):
raise RuntimeError("not ready")

bad_cluster = _ClusterWithBadGlobalSteps()
cfg = _make_trainer_config(eval_interval=10)
hook = rl_hooks.RLTrainingHooks(bad_cluster, cfg, test_dataset=None, eval_interval=10)
hook.on_train_step_end(trainer=None, step=10, loss=None)
self.mock_evaluate.assert_called_once()


class InstallTrainingHooksTest(unittest.TestCase):
"""Verify `utils_rl.install_training_hooks` gating + attach behavior."""

def test_noop_when_num_test_batches_nonpositive(self):
cluster = _make_rl_cluster()
cfg = _make_trainer_config(num_test_batches=0, eval_interval=10)
utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
self.assertIsNone(cluster.actor_trainer.training_hooks)

def test_noop_when_eval_interval_nonpositive(self):
cluster = _make_rl_cluster()
cfg = _make_trainer_config(num_test_batches=5, eval_interval=0)
utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
self.assertIsNone(cluster.actor_trainer.training_hooks)

def test_noop_when_eval_interval_attr_missing(self):
cluster = _make_rl_cluster()
cfg = SimpleNamespace(num_test_batches=5) # no eval_interval attribute
utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
self.assertIsNone(cluster.actor_trainer.training_hooks)

def test_attaches_hook_on_happy_path(self):
cluster = _make_rl_cluster()
cfg = _make_trainer_config(num_test_batches=5, eval_interval=10)
utils_rl.install_training_hooks(cluster, cfg, test_dataset="dummy")
self.assertIsInstance(cluster.actor_trainer.training_hooks, rl_hooks.RLTrainingHooks)

def test_does_not_overwrite_existing_training_hooks(self):
cluster = _make_rl_cluster()
sentinel = object()
cluster.actor_trainer.training_hooks = sentinel
cfg = _make_trainer_config(num_test_batches=5, eval_interval=10)
utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
self.assertIs(cluster.actor_trainer.training_hooks, sentinel)

def test_swallows_importerror_when_hooks_module_missing(self):
"""If `from .hooks import RLTrainingHooks` fails, install soft-skips.

Setting `sys.modules[name] = None` makes Python's import system raise
ImportError on the next import attempt for that name (documented behavior).
"""
cluster = _make_rl_cluster()
cfg = _make_trainer_config(num_test_batches=5, eval_interval=10)
with mock.patch.dict("sys.modules", {"maxtext.trainers.post_train.rl.hooks": None}):
utils_rl.install_training_hooks(cluster, cfg, test_dataset=None)
self.assertIsNone(cluster.actor_trainer.training_hooks)


if __name__ == "__main__":
unittest.main()
Loading