You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add mean_reward to evaluate() via reusing user-provided reward functions
evaluate() previously returned only binary correctness metrics (corr,
accuracy, partial_accuracy, format_accuracy). For RL training the most
important metric is the actual reward signal — but that was only
available at training time, not during PRE / POST / intermediate eval.
Extend evaluate() to accept an optional `reward_fns` list and compute
the per-example sum of all reward functions, returning the mean as a
6th element `mean_reward`. The reward functions used are the same ones
that drive training (per `reward_functions_path` + `reward_functions`
CLI knobs from commit 50eb2ca) — so eval-time mean_reward is exactly
what training optimizes for. No task-specific code is added to maxtext;
whatever scoring scheme the user plugs in becomes both the training
signal AND the eval-time reward metric.
Plumbed through to all three eval call sites in train_rl.py:
* Pre RL Training log line now includes mean_reward=...
* Post RL Training log line now includes mean_reward=...
* Intermediate Eval (step=N) log line now includes mean_reward=...
When no reward_fns is configured, mean_reward is reported as 0.0 and
the rest of evaluate() works exactly as before (backward compatible).
create_rl_components() return signature extended to also return
reward_fns so the eval call sites can pass them along.
0 commit comments