Skip to content

Commit f2c32ac

Browse files
author
Pooya Moradi
committed
Add intermediate eval hook: fire evaluate() every eval_interval outer steps
`eval_interval` was a silently-dead config: even though it's plumbed into tunix's `RLTrainingConfig.eval_every_n_steps`, tunix's `_run_eval` is a no-op unless an `eval_ds` is passed to `trainer.train()`. And even if you do pass one, tunix's default GRPO eval re-runs the full sampled rollout (num_generations responses per prompt), which is ~3hr/eval and impractical for trajectory monitoring. Install a `tunix.sft.hooks.TrainingHooks` subclass that hooks `on_train_step_end`, checks `rl_cluster.global_steps % eval_interval`, and calls maxtext's own `evaluate(...)` (greedy decode + the configured scoring pipeline). Gives matched-step PRE / step_N / POST trajectory logging at near-zero cost beyond the eval itself (which is already fast when `eval_batch_size` is set per commit d536d13). No-op when eval_interval <= 0 or num_test_batches <= 0. Soft-skips with a warning if tunix.sft.hooks isn't importable, so the launcher still works against a stock-only tunix.
1 parent ff05b79 commit f2c32ac

2 files changed

Lines changed: 122 additions & 0 deletions

File tree

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Training hooks for post-train RL."""
16+
17+
from typing import Any
18+
19+
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
20+
from maxtext.utils import max_logging
21+
22+
23+
def install_intermediate_eval_hook(
24+
rl_cluster: Any,
25+
trainer_config: Any,
26+
test_dataset: Any,
27+
) -> None:
28+
"""Fire `evaluate(...)` every `eval_interval` outer steps during training.
29+
30+
tunix's `eval_every_n_steps` in `RLTrainingConfig` is silently dead unless
31+
an `eval_ds` is passed to `trainer.train()`, and even then tunix's default
32+
`_run_eval` re-runs the full GRPO rollout (`num_generations` sampled per
33+
prompt), which is ~3hr/eval and impractical for trajectory monitoring.
34+
35+
This hook subclasses `tunix.sft.hooks.TrainingHooks` and at every
36+
`eval_interval` outer step (matched against `rl_cluster.global_steps`)
37+
calls maxtext's `evaluate(...)` — greedy decode + the configured scoring
38+
pipeline — and logs the result. Gives matched-step PRE/INTERMEDIATE/POST
39+
curves without any change to tunix.
40+
41+
No-op if `eval_interval <= 0` or `num_test_batches <= 0` or tunix's hooks
42+
module is unavailable.
43+
"""
44+
if trainer_config.num_test_batches <= 0:
45+
return
46+
eval_interval = int(getattr(trainer_config, "eval_interval", 0))
47+
if eval_interval <= 0:
48+
return
49+
try:
50+
# Soft-import: keeps the launcher usable against a stock-only tunix.
51+
from tunix.sft import hooks as _hk # pylint: disable=import-outside-toplevel
52+
except ImportError:
53+
max_logging.warning("[intermediate-eval] tunix.sft.hooks not importable; skipping hook install.")
54+
return
55+
56+
state: dict = {"last_step_evaluated": -1}
57+
58+
class _IntermediateEvalHook(_hk.TrainingHooks): # type: ignore[name-defined]
59+
"""Fires `evaluate(...)` every `eval_interval` outer steps."""
60+
61+
def on_train_start(self, train_ctx): # noqa: ARG002
62+
del train_ctx
63+
64+
def on_train_end(self, train_ctx): # noqa: ARG002
65+
del train_ctx
66+
67+
def on_train_step_start(self, train_ctx): # noqa: ARG002
68+
del train_ctx
69+
70+
def on_eval_step_start(self, train_ctx): # noqa: ARG002
71+
del train_ctx
72+
73+
def on_eval_step_end(self, train_ctx, *args, **kwargs): # noqa: ARG002
74+
del train_ctx, args, kwargs
75+
76+
def on_train_step_end(self, trainer, step, loss): # noqa: ARG002
77+
"""Fire `evaluate(...)` once per `eval_interval` outer steps."""
78+
del trainer, loss
79+
try:
80+
outer_step = int(rl_cluster.global_steps)
81+
except Exception: # pylint: disable=broad-exception-caught
82+
outer_step = int(step) if step is not None else -1
83+
if outer_step <= 0 or outer_step == state["last_step_evaluated"]:
84+
return
85+
if outer_step % eval_interval != 0:
86+
return
87+
state["last_step_evaluated"] = outer_step
88+
try:
89+
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
90+
trainer_config,
91+
test_dataset,
92+
rl_cluster=rl_cluster,
93+
num_passes=trainer_config.num_eval_passes,
94+
corr_lst=trainer_config.eval_corr_lst,
95+
make_lst=trainer_config.eval_make_lst,
96+
)
97+
max_logging.warning(
98+
f"Intermediate Eval (step={outer_step}): {corr=}, {total=},"
99+
f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
100+
)
101+
except Exception as e: # pylint: disable=broad-exception-caught
102+
max_logging.warning(f"[intermediate-eval] step={outer_step} failed: {e!r}")
103+
104+
# PeftTrainer composes a single training_hooks; install if free, else warn.
105+
try:
106+
actor = rl_cluster.actor_trainer
107+
if getattr(actor, "training_hooks", None) is None:
108+
actor.training_hooks = _IntermediateEvalHook()
109+
max_logging.warning(
110+
f"[intermediate-eval] hook installed: evaluate(...) will fire every {eval_interval} outer steps."
111+
)
112+
else:
113+
max_logging.warning(
114+
"[intermediate-eval] actor.training_hooks already set; skipping install (chain manually if you need both)."
115+
)
116+
except Exception as e: # pylint: disable=broad-exception-caught
117+
max_logging.warning(f"[intermediate-eval] install failed: {e!r}")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def _no_bf16_to_f32_cast(val, tgt_dtype, src_key):
118118
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
119119
from maxtext.integration.vllm.maxtext_vllm_rollout import MaxTextVllmRollout
120120
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
121+
from maxtext.trainers.post_train.rl import hooks as rl_hooks
121122
from maxtext.trainers.post_train.rl import utils_rl
122123
from maxtext.input_pipeline.instruction_data_processing import load_data_template_from_file
123124
from maxtext.utils import max_logging, max_utils, model_creation_utils
@@ -693,6 +694,10 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
693694
max_logging.log("Capturing reference model state before training.")
694695
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
695696

697+
# Wire intermediate eval: fire greedy `evaluate(...)` every `eval_interval`
698+
# outer steps. No-op when eval_interval <= 0 or num_test_batches <= 0.
699+
rl_hooks.install_intermediate_eval_hook(rl_cluster, trainer_config, test_dataset)
700+
696701
max_logging.warning("Starting RL training...")
697702
rl_trainer.train(train_dataset)
698703

0 commit comments

Comments
 (0)