Skip to content

Commit b06b8b6

Browse files
author
Pooya Moradi
committed
Add mean_reward to evaluate() via reusing user-provided reward functions
evaluate() previously returned only binary correctness metrics (corr, accuracy, partial_accuracy, format_accuracy). For RL training the most important metric is the actual reward signal — but that was only available at training time, not during PRE / POST / intermediate eval.

Extend evaluate() to accept an optional `reward_fns` list and compute the per-example sum of all reward functions, returning the mean as a 6th element, `mean_reward`. The reward functions used are the same ones that drive training (per the `reward_functions_path` + `reward_functions` CLI knobs from commit 50eb2ca) — so eval-time mean_reward is exactly what training optimizes for. No task-specific code is added to maxtext; whatever scoring scheme the user plugs in becomes both the training signal AND the eval-time reward metric.

Plumbed through to all three eval call sites in train_rl.py:
* Pre RL Training log line now includes mean_reward=...
* Post RL Training log line now includes mean_reward=...
* Intermediate Eval (step=N) log line now includes mean_reward=...

When no `reward_fns` are configured, mean_reward is reported as 0.0 and the rest of evaluate() works exactly as before (backward compatible).

The create_rl_components() return signature is extended to also return reward_fns so the eval call sites can pass them along.
1 parent 5d45f2e commit b06b8b6

6 files changed

Lines changed: 196 additions & 21 deletions

File tree

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import collections
2020
import json
2121
import re
22-
from typing import Any
22+
from typing import Any, Callable, Optional
2323

2424
from tqdm.auto import tqdm
2525
from tunix.rl.rollout.base_rollout import RolloutConfig
@@ -183,13 +183,43 @@ def score_responses(tmvp_config, question, responses, answers):
183183
raise ValueError(f"Unknown eval_mode: {eval_mode!r}")
184184

185185

186+
def _compute_row_reward(reward_fns, prompt, responses, answer, row_idx):
187+
"""Sum the per-function reward scores across all sampled responses for one prompt.
188+
189+
Honors the sampling strategy `evaluate()` ran with: when `num_passes > 1`
190+
(or when a non-greedy `eval_sampling_strategy` is configured),
191+
`responses` contains one entry per pass for the same prompt, and this
192+
helper sums the reward across all of them. The caller divides the
193+
total by the number of (prompt, response) pairs to get the per-sample
194+
mean reward, mirroring tunix's GRPO per-rollout reward aggregation.
195+
196+
Returns a tuple `(score_sum, n_responses)`. On any exception the
197+
failure is logged and `(0.0, 0)` is returned so the caller's running
198+
mean is not corrupted.
199+
"""
200+
if not responses:
201+
return 0.0, 0
202+
try:
203+
score_sum = 0.0
204+
for resp in responses:
205+
for fn in reward_fns:
206+
scores = fn(prompts=[prompt], completions=[resp], answer=[answer])
207+
if scores:
208+
score_sum += float(scores[0])
209+
return score_sum, len(responses)
210+
except Exception as e: # pylint: disable=broad-exception-caught
211+
max_logging.log(f"[eval-reward] reward_fn failed on row {row_idx}: {e!r}")
212+
return 0.0, 0
213+
214+
186215
def evaluate(
187216
tmvp_config,
188217
dataset,
189218
rl_cluster,
190219
num_passes=1,
191220
corr_lst=False,
192221
make_lst=False,
222+
reward_fns: Optional[list[Callable[..., Any]]] = None,
193223
):
194224
"""
195225
Computes accuracy and percentage of outputs matching the format.
@@ -201,15 +231,27 @@ def evaluate(
201231
num_passes: Number of generation passes
202232
corr_lst: If True, only include correct responses in the list
203233
make_lst: If True, return a list of (question, answer, responses)
234+
reward_fns: Optional list of reward functions to also evaluate against
235+
the sampled responses (using whichever `eval_sampling_strategy`
236+
is configured). Each function must accept `prompts`,
237+
`completions`, `answer`, and return a list of floats (same signature
238+
as the training-time reward stack). When provided, the per-example
239+
score is the SUM across all reward functions (matching tunix's GRPO
240+
aggregation), and the per-example mean is returned as `mean_reward`.
241+
When None or empty, `mean_reward` is 0.0.
204242
205243
Returns:
206-
Tuple of statistics and optionally the response list
244+
Tuple (corr, total, accuracy, partial_accuracy, format_accuracy,
245+
mean_reward), response_lst
207246
"""
208247
response_lst = []
209248
corr = 0
210249
partially_corr = 0
211250
corr_format = 0
212251
total = 0
252+
reward_sum = 0.0
253+
reward_count = 0 # number of (prompt, sampled response) pairs scored
254+
use_reward = bool(reward_fns)
213255

214256
for batch in tqdm(dataset):
215257
answers = batch["answer"]
@@ -225,16 +267,26 @@ def evaluate(
225267
)
226268

227269
# Score each question-answer pair
228-
for question, responses, answer in zip(questions, multiple_call_responses, answers):
270+
for question, responses, answer, prompt in zip(questions, multiple_call_responses, answers, prompts):
229271
# decode the json-encoded list of acceptable answers
230-
answer = list(dict.fromkeys(json.loads(answer)))
272+
answer_list = list(dict.fromkeys(json.loads(answer)))
231273
is_correct, is_partially_correct, has_correct_format = score_responses(
232274
tmvp_config=tmvp_config,
233275
question=question,
234276
responses=responses,
235-
answers=answer,
277+
answers=answer_list,
236278
)
237279

280+
# Per-example reward (eval-time mirror of the training reward). The
281+
# total is accumulated across all sampled responses (across num_passes
282+
# and across the eval_sampling_strategy distribution) and divided by
283+
# the actual per-(prompt, response) count at the end. See
284+
# `_compute_row_reward` for details.
285+
if use_reward:
286+
row_sum, row_count = _compute_row_reward(reward_fns, prompt, responses, answer, total)
287+
reward_sum += row_sum
288+
reward_count += row_count
289+
238290
# Update counters. For "pass" and "maj" modes, scores are booleans
239291
# (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1]
240292
# representing the fraction of samples correct. Using += works for both:
@@ -245,9 +297,9 @@ def evaluate(
245297

246298
if make_lst:
247299
if corr_lst and is_correct:
248-
response_lst.append((question, answer, responses))
300+
response_lst.append((question, answer_list, responses))
249301
elif not corr_lst and not is_correct:
250-
response_lst.append((question, answer, responses))
302+
response_lst.append((question, answer_list, responses))
251303

252304
total += 1
253305

@@ -265,6 +317,7 @@ def evaluate(
265317
corr / total * 100 if total > 0 else 0,
266318
partially_corr / total * 100 if total > 0 else 0,
267319
corr_format / total * 100 if total > 0 else 0,
320+
reward_sum / reward_count if (use_reward and reward_count > 0) else 0.0,
268321
)
269322

270323
return to_return, response_lst

src/maxtext/trainers/post_train/rl/hooks.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Training hooks for post-train RL."""
1616

17-
from typing import Any
17+
from typing import Any, Callable, Optional
1818

1919
from tunix.sft import hooks as _tunix_hooks
2020

@@ -33,7 +33,8 @@ class RLTrainingHooks(_tunix_hooks.TrainingHooks):
3333
3434
This hook hooks `on_train_step_end`, checks
3535
`rl_cluster.global_steps % eval_interval`, and calls maxtext's
36-
`evaluate(...)` — greedy decode + the configured scoring pipeline —
36+
`evaluate(...)` (using whichever `eval_sampling_strategy` is configured
37+
in `generation_configs`) plus the configured scoring pipeline,
3738
logging the result. Gives matched-step PRE/INTERMEDIATE/POST curves
3839
without any change to tunix.
3940
"""
@@ -44,11 +45,13 @@ def __init__(
4445
trainer_config: Any,
4546
test_dataset: Any,
4647
eval_interval: int,
48+
reward_fns: Optional[list[Callable[..., Any]]] = None,
4749
):
4850
self._rl_cluster = rl_cluster
4951
self._trainer_config = trainer_config
5052
self._test_dataset = test_dataset
5153
self._eval_interval = eval_interval
54+
self._reward_fns = reward_fns
5255
self._last_step_evaluated = -1
5356

5457
# The five lifecycle methods below are abstract in `tunix.sft.hooks.TrainingHooks`,
@@ -83,17 +86,19 @@ def on_train_step_end(self, trainer, step, loss): # noqa: ARG002
8386
self._last_step_evaluated = outer_step
8487
try:
8588
tc = self._trainer_config
86-
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
89+
(corr, total, accuracy, partial_accuracy, format_accuracy, mean_reward), _ = evaluate(
8790
tc,
8891
self._test_dataset,
8992
rl_cluster=self._rl_cluster,
9093
num_passes=tc.num_eval_passes,
9194
corr_lst=tc.eval_corr_lst,
9295
make_lst=tc.eval_make_lst,
96+
reward_fns=self._reward_fns,
9397
)
9498
max_logging.warning(
9599
f"Intermediate Eval (step={outer_step}): {corr=}, {total=},"
96-
f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%"
100+
f" {accuracy=}%, {partial_accuracy=}%, {format_accuracy=}%,"
101+
f" {mean_reward=:.4f}"
97102
)
98103
except Exception as e: # pylint: disable=broad-exception-caught
99104
max_logging.warning(f"[intermediate-eval] step={outer_step} failed: {e!r}")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def _reward_fn(**kwargs):
579579
algo_config=grpo_config,
580580
)
581581

582-
return rl_cluster, rl_trainer, optimizer
582+
return rl_cluster, rl_trainer, optimizer, reward_fns
583583

584584

585585
def rl_train(argv: Sequence[str], kwargs: dict):
@@ -664,7 +664,7 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
664664
max_logging.log(f"Policy mesh shape: {actor_mesh.shape}")
665665
max_logging.log(f"Rollout_mesh shape: {rollout_mesh.shape}")
666666

667-
rl_cluster, rl_trainer, _ = create_rl_components(
667+
rl_cluster, rl_trainer, _, reward_fns = create_rl_components(
668668
trainer_config,
669669
sampler_config,
670670
sampler_devices,
@@ -682,16 +682,18 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
682682
# Update vllm with model parameters from checkpoint
683683
rl_cluster.rollout.update_params(nnx.state(actor_model))
684684

685-
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
685+
(corr, total, accuracy, partial_accuracy, format_accuracy, mean_reward), _ = evaluate(
686686
trainer_config,
687687
test_dataset,
688688
rl_cluster=rl_cluster,
689689
num_passes=trainer_config.num_eval_passes,
690690
corr_lst=trainer_config.eval_corr_lst,
691691
make_lst=trainer_config.eval_make_lst,
692+
reward_fns=reward_fns,
692693
)
693694
max_logging.warning(
694-
f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%"
695+
f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
696+
f" {format_accuracy=}%, {mean_reward=:.4f}"
695697
)
696698

697699
# Start training
@@ -701,7 +703,7 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
701703

702704
# Wire intermediate eval: fire `evaluate(...)` (with the configured `eval_sampling_strategy`) every `eval_interval`
703705
# outer steps. No-op when eval_interval <= 0 or num_test_batches <= 0.
704-
utils_rl.install_training_hooks(rl_cluster, trainer_config, test_dataset)
706+
utils_rl.install_training_hooks(rl_cluster, trainer_config, test_dataset, reward_fns)
705707

706708
max_logging.warning("Starting RL training...")
707709
rl_trainer.train(train_dataset)
@@ -718,16 +720,18 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
718720

719721
# Run evaluation after training
720722
if trainer_config.num_test_batches > 0:
721-
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
723+
(corr, total, accuracy, partial_accuracy, format_accuracy, mean_reward), _ = evaluate(
722724
trainer_config,
723725
test_dataset,
724726
rl_cluster=rl_cluster,
725727
num_passes=trainer_config.num_eval_passes,
726728
corr_lst=trainer_config.eval_corr_lst,
727729
make_lst=trainer_config.eval_make_lst,
730+
reward_fns=reward_fns,
728731
)
729732
max_logging.warning(
730-
f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%"
733+
f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
734+
f" {format_accuracy=}%, {mean_reward=:.4f}"
731735
)
732736

733737

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,9 +779,14 @@ def install_training_hooks(
779779
rl_cluster: Any,
780780
trainer_config: Any,
781781
test_dataset: Any,
782+
reward_fns: Optional[list[Callable[..., Any]]] = None,
782783
) -> None:
783784
"""Install maxtext's `RLTrainingHooks` on the actor trainer.
784785
786+
When `reward_fns` is provided, intermediate eval logs the per-example
787+
`mean_reward` alongside the correctness metrics, mirroring the training-time
788+
reward stack.
789+
785790
No-op if `eval_interval <= 0` or `num_test_batches <= 0` or tunix's hooks
786791
module is unavailable.
787792
"""
@@ -803,7 +808,7 @@ def install_training_hooks(
803808
try:
804809
actor = rl_cluster.actor_trainer
805810
if getattr(actor, "training_hooks", None) is None:
806-
actor.training_hooks = RLTrainingHooks(rl_cluster, trainer_config, test_dataset, eval_interval)
811+
actor.training_hooks = RLTrainingHooks(rl_cluster, trainer_config, test_dataset, eval_interval, reward_fns)
807812
max_logging.warning(
808813
f"[intermediate-eval] hook installed: evaluate(...) will fire every {eval_interval} outer steps."
809814
)

tests/post_training/unit/evaluate_rl_test.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Unit tests for evaluate_rl.py (CPU-only)."""
1616

17+
# pylint: disable=protected-access
18+
1719
import unittest
1820
import pytest
1921
from types import SimpleNamespace
@@ -208,5 +210,96 @@ def test_pass_at_1_all_correct(self):
208210
self.assertAlmostEqual(has_correct_format, 1.0)
209211

210212

213+
class TestComputeRowReward(unittest.TestCase):
  """Checks how _compute_row_reward aggregates eval-time rewards for a single prompt."""

  def _two_fns(self):
    """Build two summable reward functions: a constant one and a length-based one."""

    # Reward functions take `prompts`, `completions` and `answer` as keyword
    # arguments and give back one score per completion. The helper under test
    # calls them once per response with single-element lists, so every result
    # holds exactly one score.
    def constant_reward(prompts, completions, answer):  # pylint: disable=unused-argument
      return [1.0] * len(completions)

    def length_reward(prompts, completions, answer):  # pylint: disable=unused-argument
      return [float(len(text)) for text in completions]

    return [constant_reward, length_reward]

  @pytest.mark.cpu_only
  def test_single_response_single_fn(self):
    def flat_reward(prompts, completions, answer):  # pylint: disable=unused-argument
      return [2.5 for _ in completions]

    call_args = {
        "reward_fns": [flat_reward],
        "prompt": "p",
        "responses": ["abc"],
        "answer": "gold",
        "row_idx": 0,
    }
    total, n_scored = evaluate_rl._compute_row_reward(**call_args)
    self.assertAlmostEqual(total, 2.5)
    self.assertEqual(n_scored, 1)

  @pytest.mark.cpu_only
  def test_sums_across_reward_fns_for_single_response(self):
    call_args = {
        "reward_fns": self._two_fns(),
        "prompt": "p",
        "responses": ["abcd"],
        "answer": "gold",
        "row_idx": 0,
    }
    total, n_scored = evaluate_rl._compute_row_reward(**call_args)
    # Constant reward gives 1.0 and the length reward gives len("abcd") = 4.0,
    # so the single pass scores 5.0.
    self.assertAlmostEqual(total, 5.0)
    self.assertEqual(n_scored, 1)

  @pytest.mark.cpu_only
  def test_sums_across_passes_for_multiple_responses(self):
    """With several passes, every sampled response must be aggregated, not only the first."""
    call_args = {
        "reward_fns": self._two_fns(),
        "prompt": "p",
        "responses": ["a", "bcd", "ef"],
        "answer": "gold",
        "row_idx": 0,
    }
    total, n_scored = evaluate_rl._compute_row_reward(**call_args)
    # Pass scores are 1 + 1 = 2, 1 + 3 = 4 and 1 + 2 = 3, i.e. 9 over 3 passes.
    self.assertAlmostEqual(total, 9.0)
    self.assertEqual(n_scored, 3)

  @pytest.mark.cpu_only
  def test_empty_responses_returns_zero_and_zero_count(self):
    """No responses means no contribution to the running mean."""
    call_args = {
        "reward_fns": self._two_fns(),
        "prompt": "p",
        "responses": [],
        "answer": "gold",
        "row_idx": 0,
    }
    total, n_scored = evaluate_rl._compute_row_reward(**call_args)
    self.assertEqual(total, 0.0)
    self.assertEqual(n_scored, 0)

  @pytest.mark.cpu_only
  def test_exception_in_reward_fn_swallowed_and_returns_zero_count(self):
    """A failing reward_fn must neither propagate nor distort the mean's denominator."""

    def _boom(**kwargs):  # pylint: disable=unused-argument
      raise RuntimeError("reward failure")

    call_args = {
        "reward_fns": [_boom],
        "prompt": "p",
        "responses": ["abc"],
        "answer": "gold",
        "row_idx": 0,
    }
    total, n_scored = evaluate_rl._compute_row_reward(**call_args)
    self.assertEqual(total, 0.0)
    # A zero count keeps the failed row out of the caller's mean.
    self.assertEqual(n_scored, 0)
303+
211304
if __name__ == "__main__":
212305
unittest.main()

0 commit comments

Comments
 (0)