Add mean_reward to evaluate() via reusing user-provided reward functions#4083
Open
py4 wants to merge 1 commit into
Open
Add mean_reward to evaluate() via reusing user-provided reward functions#4083py4 wants to merge 1 commit into
py4 wants to merge 1 commit into
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
ae6e674 to
0b93e88
Compare
evaluate() previously returned only binary correctness metrics (corr, accuracy, partial_accuracy, format_accuracy). For RL training the most important metric is the actual reward signal — but that was only available at training time, not during PRE / POST / intermediate eval. Extend evaluate() to accept an optional `reward_fns` list and compute the per-example sum of all reward functions, returning the mean as a 6th element `mean_reward`. The reward functions used are the same ones that drive training (per `reward_functions_path` + `reward_functions` CLI knobs from commit 50eb2ca) — so eval-time mean_reward is exactly what training optimizes for. No task-specific code is added to maxtext; whatever scoring scheme the user plugs in becomes both the training signal AND the eval-time reward metric. Plumbed through to all three eval call sites in train_rl.py: * Pre RL Training log line now includes mean_reward=... * Post RL Training log line now includes mean_reward=... * Intermediate Eval (step=N) log line now includes mean_reward=... When no reward_fns is configured, mean_reward is reported as 0.0 and the rest of evaluate() works exactly as before (backward compatible). create_rl_components() return signature extended to also return reward_fns so the eval call sites can pass them along.
0b93e88 to
b06b8b6
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
evaluate()previously returned only binary correctness metrics (corr,accuracy,partial_accuracy,format_accuracy). For RL training the most important metric is the actual reward signal, but that was only available at training time, not during PRE / POST / intermediate eval. This PR extendsevaluate()to also compute the per-example mean reward using the same reward functions that drive training.Changes
evaluate()accepts an optionalreward_fnslist. When provided, the per-example score is the SUM across all reward functions (matching tunix's GRPO aggregation), and the per-example mean is returned as a 6th tuple elementmean_reward. WhenNoneor empty,mean_rewardis0.0and the rest ofevaluate()is unchanged (bit-for-bit backward compatible).The reward functions used are the same ones the training stack builds (via
reward_functions_path+reward_functionsCLI knobs in Add reward_functions_path + reward_functions CLI knobs for custom rewards #4045-to-be), so eval-timemean_rewardis exactly what training optimizes for. No task-specific code is added to maxtext; whatever scoring scheme the user plugs in becomes both the training signal AND the eval-time reward metric.Plumbed through to all three eval call sites in
train_rl.py:mean_reward=...mean_reward=...mean_reward=...create_rl_components()return signature extended to also returnreward_fnsso the eval call sites can pass them along.install_training_hooks/RLTrainingHooks(added in Add intermediate eval hook: fire evaluate() every eval_interval outer steps #4044) extended to acceptreward_fnsand forward it to the intermediateevaluate()call.Backward compatibility
Default
reward_fns=Noneeverywhere preserves existing behavior. Existing recipes seemean_reward=0.0in the eval tuple and no other change.Files
src/maxtext/trainers/post_train/rl/evaluate_rl.pyreward_fnsparam + 6th tuple elementsrc/maxtext/trainers/post_train/rl/train_rl.pysrc/maxtext/trainers/post_train/rl/hooks.pyRLTrainingHooks.__init__takes reward_fns;on_train_step_endforwardssrc/maxtext/trainers/post_train/rl/utils_rl.pyinstall_training_hooksaccepts and forwards reward_fnstests/post_training/unit/rl_hooks_test.pyTests
The two new tests on
RLTrainingHooksTest:test_reward_fns_plumbed_to_evaluate: a sentinelreward_fnspassed to the hook is forwarded toevaluate(...).test_reward_fns_default_none: whenreward_fnsis omitted,evaluate(...)receivesreward_fns=None.Existing tests' mock return updated to the new 6-tuple shape (added the
mean_reward=0.426th element).Checklist
--pyink-indentation=2 --line-length=122)reward_fns=Noneeverywhere