Skip to content

Add mean_reward to evaluate() via reusing user-provided reward functions#4083

Open
py4 wants to merge 1 commit into
mainfrom
pr/mean-reward-evaluate
Open

Add mean_reward to evaluate() via reusing user-provided reward functions#4083
py4 wants to merge 1 commit into
mainfrom
pr/mean-reward-evaluate

Conversation

@py4
Copy link
Copy Markdown
Collaborator

@py4 py4 commented Jun 5, 2026

Summary

evaluate() previously returned only binary correctness metrics (corr, accuracy, partial_accuracy, format_accuracy). For RL training the most important metric is the actual reward signal, but that was only available at training time, not during PRE / POST / intermediate eval. This PR extends evaluate() to also compute the per-example mean reward using the same reward functions that drive training.

Changes

  1. evaluate() accepts an optional reward_fns list. When provided, the per-example score is the SUM across all reward functions (matching tunix's GRPO aggregation), and the per-example mean is returned as a 6th tuple element mean_reward. When None or empty, mean_reward is 0.0 and the rest of evaluate() is unchanged (bit-for-bit backward compatible).

  2. The reward functions used are the same ones the training stack builds (via reward_functions_path + reward_functions CLI knobs in Add reward_functions_path + reward_functions CLI knobs for custom rewards #4045-to-be), so eval-time mean_reward is exactly what training optimizes for. No task-specific code is added to maxtext; whatever scoring scheme the user plugs in becomes both the training signal AND the eval-time reward metric.

  3. Plumbed through to all three eval call sites in train_rl.py:

    • Pre RL Training log line now includes mean_reward=...
    • Post RL Training log line now includes mean_reward=...
    • Intermediate Eval (step=N) log line now includes mean_reward=...
  4. create_rl_components() return signature extended to also return reward_fns so the eval call sites can pass them along.

  5. install_training_hooks / RLTrainingHooks (added in Add intermediate eval hook: fire evaluate() every eval_interval outer steps #4044) extended to accept reward_fns and forward it to the intermediate evaluate() call.

Backward compatibility

Default reward_fns=None everywhere preserves existing behavior. Existing recipes see mean_reward=0.0 in the eval tuple and no other change.

Files

File Δ
src/maxtext/trainers/post_train/rl/evaluate_rl.py +43 / -6: add reward_fns param + 6th tuple element
src/maxtext/trainers/post_train/rl/train_rl.py +12 / -8: wire reward_fns through 3 eval call sites
src/maxtext/trainers/post_train/rl/hooks.py +8 / -2: RLTrainingHooks.__init__ takes reward_fns; on_train_step_end forwards
src/maxtext/trainers/post_train/rl/utils_rl.py +6 / -1: install_training_hooks accepts and forwards reward_fns
tests/post_training/unit/rl_hooks_test.py +18 / -1: extend mock + 2 new tests covering the plumbing

Tests

The two new tests on RLTrainingHooksTest:

  • test_reward_fns_plumbed_to_evaluate: a sentinel reward_fns passed to the hook is forwarded to evaluate(...).
  • test_reward_fns_default_none: when reward_fns is omitted, evaluate(...) receives reward_fns=None.

Existing tests' mock return updated to the new 6-tuple shape (added the mean_reward=0.42 6th element).

Checklist

  • Pyink-clean (--pyink-indentation=2 --line-length=122)
  • Backward compatible: default reward_fns=None everywhere
  • No effect on non-RL paths (only RL eval + RL trainer touched)
  • Tests cover the reward_fns plumbing through the hook

@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 5, 2026

Codecov Report

❌ Patch coverage is 51.42857% with 17 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/trainers/post_train/rl/evaluate_rl.py 53.84% 11 Missing and 1 partial ⚠️
src/maxtext/trainers/post_train/rl/train_rl.py 0.00% 5 Missing ⚠️

📢 Thoughts on this report? Let us know!

@py4 py4 force-pushed the pr/mean-reward-evaluate branch from ae6e674 to 0b93e88 Compare June 5, 2026 19:12
evaluate() previously returned only binary correctness metrics (corr,
accuracy, partial_accuracy, format_accuracy). For RL training the most
important metric is the actual reward signal — but that was only
available at training time, not during PRE / POST / intermediate eval.

Extend evaluate() to accept an optional `reward_fns` list and compute
the per-example sum of all reward functions, returning the mean as a
6th element `mean_reward`. The reward functions used are the same ones
that drive training (per `reward_functions_path` + `reward_functions`
CLI knobs from commit 50eb2ca) — so eval-time mean_reward is exactly
what training optimizes for. No task-specific code is added to maxtext;
whatever scoring scheme the user plugs in becomes both the training
signal AND the eval-time reward metric.

Plumbed through to all three eval call sites in train_rl.py:
  * Pre RL Training log line now includes mean_reward=...
  * Post RL Training log line now includes mean_reward=...
  * Intermediate Eval (step=N) log line now includes mean_reward=...

When no reward_fns is configured, mean_reward is reported as 0.0 and
the rest of evaluate() works exactly as before (backward compatible).

create_rl_components() return signature extended to also return
reward_fns so the eval call sites can pass them along.
@py4 py4 force-pushed the pr/mean-reward-evaluate branch from 0b93e88 to b06b8b6 Compare June 5, 2026 19:25
Copy link
Copy Markdown
Collaborator

@khatwanimohit khatwanimohit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants