
Add eval_batch_size knob for faster post-train RL evaluation #4030

Merged
copybara-service[bot] merged 1 commit into main from pr/eval-batch-size on Jun 9, 2026

py4 commented Jun 1, 2026

Collaborator

Post-train RL evaluation is currently batched at trainer_config.batch_size, which for GRPO is intentionally small (e.g. 4 prompts per step × 8 generations = 32 trajectories — set by the GRPO recipe to keep trainer HBM workable for the backward pass). At eval time this is wasteful: eval is greedy decode only (no backward), so the trainer's per-step memory budget doesn't apply, and vLLM rollout has many DP replicas sitting idle when only 4 prompts are dispatched per call.
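To make the idle-replica argument concrete, here is a minimal sketch of how many data-parallel replicas receive work in a single dispatch. The helper `replica_load` is hypothetical and assumes prompts are spread as evenly as possible across replicas; the real vLLM scheduling may differ.

```python
import math


def replica_load(num_prompts: int, num_replicas: int) -> tuple[int, int]:
    """Return (busy_replicas, max_prompts_per_replica) for one dispatch.

    Assumption: prompts are distributed as evenly as possible, so a
    replica only receives work if there are enough prompts to go around.
    """
    busy = min(num_prompts, num_replicas)
    per_replica = math.ceil(num_prompts / num_replicas)
    return busy, per_replica


# 4 prompts on 8 replicas: at most half of the replicas receive a prompt.
small = replica_load(4, 8)
# 128 prompts on 8 replicas: every replica receives 16 prompts.
large = replica_load(128, 8)
print(small, large)
```

Under this assumption, a batch of 4 prompts cannot occupy more than 4 of the 8 replicas, whereas a batch of 128 keeps all of them busy.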

This PR adds an rl.eval_batch_size knob (default -1 = use batch_size, preserving old behavior) that overrides the batch dimension during dataset preparation for the test split. Setting it to e.g. 128 on a sampler with 8 DP replicas gives a ~32× eval throughput improvement on TPU without affecting training behavior.

Changes (3 files, +21/-2 lines):

  • src/maxtext/configs/types.py: new Pydantic field RLDataset.eval_batch_size: int = -1
  • src/maxtext/configs/post_train/rl.yml: default eval_batch_size: -1 + comment
  • src/maxtext/trainers/post_train/rl/train_rl.py:prepare_datasets: when set and positive, use eval_batch_size for the test split's slice + .batch(...) call
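The fallback and the test-split batching described in the list above can be sketched roughly as follows. This is not the code from `train_rl.py`: the function names are hypothetical, and a plain Python list stands in for the dataset object whose slice and `.batch(...)` call the PR modifies.

```python
from typing import Sequence


def resolve_eval_batch_size(batch_size: int, eval_batch_size: int = -1) -> int:
    """Use eval_batch_size when it is set and positive, else fall back to batch_size."""
    return eval_batch_size if eval_batch_size > 0 else batch_size


def batch_test_split(
    examples: Sequence[int],
    num_test_batches: int,
    batch_size: int,
    eval_batch_size: int = -1,
) -> list[list[int]]:
    """Slice the test split to num_test_batches * bs examples, then batch it."""
    bs = resolve_eval_batch_size(batch_size, eval_batch_size)
    sliced = examples[: num_test_batches * bs]
    # List-based stand-in for the dataset's .batch(bs) call.
    return [list(sliced[i : i + bs]) for i in range(0, len(sliced), bs)]


test_split = list(range(1408))
default_batches = batch_test_split(test_split, num_test_batches=11, batch_size=4)
large_batches = batch_test_split(
    test_split, num_test_batches=11, batch_size=4, eval_batch_size=128
)
print(len(default_batches), len(default_batches[0]))
print(len(large_batches), len(large_batches[0]))
```

With the default of -1 the sketch batches at `batch_size`, which mirrors the backward-compatible behavior the PR describes.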

NOTE: total eval examples = num_test_batches * eval_batch_size, so users adjusting eval_batch_size should adjust num_test_batches to keep total eval set size constant.
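As a worked example of that adjustment, the sketch below rescales `num_test_batches` so the total number of eval examples stays the same. The helper name is hypothetical, and the starting value of 352 batches is an assumption derived from 1408 examples at a batch size of 4; the PR does not state the value actually used.

```python
def rescale_num_test_batches(
    num_test_batches: int, old_batch_size: int, new_batch_size: int
) -> int:
    """Return the num_test_batches that keeps the total eval set size constant."""
    total_examples = num_test_batches * old_batch_size
    if total_examples % new_batch_size != 0:
        raise ValueError(
            f"{total_examples} examples do not divide evenly into "
            f"batches of {new_batch_size}"
        )
    return total_examples // new_batch_size


# 352 batches of 4 prompts cover 1408 examples (assumed starting point).
new_num_test_batches = rescale_num_test_batches(352, 4, 128)
print(new_num_test_batches)
```

Raising an error on an uneven split is a design choice of this sketch; rounding up or down would change the eval set size, which is what the note warns against.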

Backward compatible: default -1 falls back to batch_size (identical to old behavior). No effect on training path.

Checklist

  • Tested locally on TPU v6e 8×8: with eval_batch_size=128 (8 DP replicas), eval over 1408 examples completes in ~3 min vs ~30+ min at batch_size=4
  • Backward compatible: default -1 preserves existing behavior bit-for-bit
  • No effect on training path (only eval-side dataset preparation touched)
  • No effect on non-RL paths (only RL Pydantic config + RL trainer touched)
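For reference, the timing in the first checklist item can be turned into an approximate throughput comparison. This is only arithmetic on the reported numbers; because the baseline is given as "30+ min", the resulting speedup should be read as a lower bound rather than a precise measurement.

```python
def examples_per_second(num_examples: int, minutes: float) -> float:
    """Average eval throughput for a run of the given wall-clock duration."""
    return num_examples / (minutes * 60.0)


# Reported run with eval_batch_size=128: 1408 examples in about 3 minutes.
fast = examples_per_second(1408, 3)
# Reported baseline at batch_size=4: at least 30 minutes, so this is an upper bound.
slow = examples_per_second(1408, 30)
speedup_lower_bound = fast / slow
print(f"{fast:.2f} vs {slow:.2f} examples/s, speedup >= {speedup_lower_bound:.1f}x")
```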

codecov Bot commented Jun 1, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Comment thread src/maxtext/configs/types.py Outdated
py4 force-pushed the pr/eval-batch-size branch from 3871de1 to a88dfc1 on June 2, 2026 19:52

py4 commented Jun 2, 2026

Collaborator Author

Updated description to match the style of sibling RLDataset fields: "Batch size for RL evaluation." PTAL.

A9isha left a comment

Collaborator

Thanks a lot Pooya!

The "trainer-side KV cache" wording is a bit confusing; please consider updating the description, given that the KV cache belongs to rollout.

py4 force-pushed the pr/eval-batch-size branch from a88dfc1 to a5c61ef on June 2, 2026 23:13

py4 commented Jun 2, 2026

Collaborator Author

@A9isha thanks for the catch — you're right, KV cache lives on the rollout (vLLM) side, not the trainer. Updated both the PR body and commit message to say the trainer per-step batch is sized for backward-pass HBM (activations + grads), and noted that eval is greedy decode only so that constraint doesn't apply at eval time.

py4 force-pushed the pr/eval-batch-size branch 2 times, most recently from 367b1bd to 87c9f3f on June 8, 2026 18:12

Post-train RL evaluation is batched at trainer_config.batch_size, which
for GRPO is intentionally small (e.g. 4 prompts per training step ×
8 generations = 32 trajectories per step — set by the GRPO recipe to
keep trainer HBM workable for the backward pass). At eval time this
is wasteful: eval is greedy decode only (no backward), so the trainer
budget doesn't apply, and vLLM rollout has many DP replicas sitting
idle when only 4 prompts are dispatched per batch.

Add an `eval_batch_size` knob (default -1 = use batch_size, preserving
old behavior) that overrides the batch dimension during dataset
preparation for the test split. Setting it to e.g. 128 on a sampler
with 8 DP replicas gives a ~32x eval throughput improvement on TPU
without affecting training behavior.

Total eval examples = num_test_batches * eval_batch_size, so users
should adjust num_test_batches when increasing eval_batch_size to keep
total eval set size constant.

py4 force-pushed the pr/eval-batch-size branch from 87c9f3f to be6e0f4 on June 9, 2026 21:57
copybara-service Bot merged commit eded896 into main on Jun 9, 2026
29 checks passed
copybara-service Bot deleted the pr/eval-batch-size branch on June 9, 2026 22:51