Add eval_batch_size knob for faster post-train RL evaluation #4030
Merged
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Collaborator
Author
Updated description to match the style of sibling
A9isha
approved these changes
Jun 2, 2026
A9isha
left a comment
Collaborator
Thanks a lot Pooya!
trainer-side KV cache is a bit confusing, please consider updating the description given that KV cache is for rollout.
Collaborator
Author
@A9isha thanks for the catch. You're right, KV cache lives on the rollout (vLLM) side, not the trainer. Updated both the PR body and commit message to say the trainer per-step batch is sized for backward-pass HBM (activations + grads), and noted that eval is greedy decode only so that constraint doesn't apply at eval time.
SurbhiJainUSC
approved these changes
Jun 3, 2026
367b1bd to
87c9f3f
Post-train RL evaluation batched at trainer_config.batch_size, which for GRPO is intentionally small (e.g. 4 prompts per training step × 8 generations = 32 trajectories per step — set by the GRPO recipe to keep trainer HBM workable for the backward pass). At eval time this is wasteful: eval is greedy decode only (no backward), so the trainer budget doesn't apply, and vLLM rollout has many DP replicas sitting idle when only 4 prompts are dispatched per batch. Add an `eval_batch_size` knob (default -1 = use batch_size, preserving old behavior) that overrides the batch dimension during dataset preparation for the test split. Setting it to e.g. 128 on a sampler with 8 DP replicas gives a ~32x eval throughput improvement on TPU without affecting training behavior. Total eval examples = num_test_batches * eval_batch_size, so users should adjust num_test_batches when increasing eval_batch_size to keep total eval set size constant.
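The relationship `total eval examples = num_test_batches * eval_batch_size` means `num_test_batches` has to shrink by the same factor that the batch size grows. A minimal sketch of that arithmetic follows; the helper name `rescale_num_test_batches` is hypothetical and not part of MaxText, and the numbers are illustrative only.

```python
def rescale_num_test_batches(
    num_test_batches: int, batch_size: int, eval_batch_size: int
) -> int:
    """Return the num_test_batches that keeps the eval set size unchanged.

    Hypothetical helper for illustration; not part of the MaxText codebase.
    """
    if eval_batch_size <= 0:
        raise ValueError("eval_batch_size must be positive")
    total_examples = num_test_batches * batch_size
    # Refuse to round: an uneven split would silently change the eval set size.
    if total_examples % eval_batch_size != 0:
        raise ValueError(
            f"{total_examples} examples do not divide evenly into "
            f"batches of {eval_batch_size}"
        )
    return total_examples // eval_batch_size


# 352 batches of 4 prompts and 11 batches of 128 prompts cover the same 1408 examples.
new_num_test_batches = rescale_num_test_batches(352, 4, 128)
```

Raising on an uneven split is a design choice made for this sketch; rounding up or down would also work but changes the total number of evaluated examples.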
Post-train RL evaluation is currently batched at `trainer_config.batch_size`, which for GRPO is intentionally small (e.g. 4 prompts per step × 8 generations = 32 trajectories, set by the GRPO recipe to keep trainer HBM workable for the backward pass). At eval time this is wasteful: eval is greedy decode only (no backward), so the trainer's per-step memory budget doesn't apply, and vLLM rollout has many DP replicas sitting idle when only 4 prompts are dispatched per call.

This PR adds an `rl.eval_batch_size` knob (default `-1` = use `batch_size`, preserving old behavior) that overrides the batch dimension during dataset preparation for the test split. Setting it to e.g. 128 on a sampler with 8 DP replicas gives a ~32× eval throughput improvement on TPU without affecting training behavior.

Changes (3 files, +21/-2 lines):

- `src/maxtext/configs/types.py`: new Pydantic field `RLDataset.eval_batch_size: int = -1`
- `src/maxtext/configs/post_train/rl.yml`: default `eval_batch_size: -1` + comment
- `src/maxtext/trainers/post_train/rl/train_rl.py:prepare_datasets`: when set and positive, use `eval_batch_size` for the test split's slice + `.batch(...)` call

NOTE: total eval examples = `num_test_batches * eval_batch_size`, so users adjusting `eval_batch_size` should adjust `num_test_batches` to keep total eval set size constant.

Backward compatible: default `-1` falls back to `batch_size` (identical to old behavior). No effect on training path.

Checklist

- `eval_batch_size=128` (8 DP replicas), eval over 1408 examples completes in ~3 min vs ~30+ min at `batch_size=4`
- `-1` preserves existing behavior bit-for-bit
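The fallback rule in the description (use `eval_batch_size` when set and positive, otherwise `batch_size`) can be sketched roughly as below. This is not the code from `train_rl.py`: the function names are hypothetical, and a plain Python list stands in for the real dataset pipeline and its `.batch(...)` call.

```python
from typing import List, Sequence


def resolve_eval_batch_size(batch_size: int, eval_batch_size: int = -1) -> int:
    """Use eval_batch_size when positive; otherwise fall back to batch_size.

    Hypothetical helper mirroring the rule stated in the PR description.
    """
    return eval_batch_size if eval_batch_size > 0 else batch_size


def batch_test_split(
    examples: Sequence[int],
    num_test_batches: int,
    batch_size: int,
    eval_batch_size: int = -1,
) -> List[List[int]]:
    """Slice and batch a test split (plain lists stand in for the real dataset)."""
    bs = resolve_eval_batch_size(batch_size, eval_batch_size)
    # Total eval examples = num_test_batches * effective batch size.
    selected = list(examples[: num_test_batches * bs])
    return [selected[i : i + bs] for i in range(0, len(selected), bs)]


test_examples = list(range(1408))
default_batches = batch_test_split(test_examples, num_test_batches=352, batch_size=4)
large_batches = batch_test_split(
    test_examples, num_test_batches=11, batch_size=4, eval_batch_size=128
)
```

In this sketch both calls cover the same 1408 examples: the default path yields 352 batches of 4, while `eval_batch_size=128` yields 11 batches of 128, which is why `num_test_batches` has to be lowered alongside the larger batch size.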