Commit 87c9f3f

Author: Pooya Moradi
Add eval_batch_size knob for faster post-train RL evaluation

Post-train RL evaluation was batched at trainer_config.batch_size, which for GRPO is intentionally small (e.g. 4 prompts per training step × 8 generations = 32 trajectories per step; the GRPO recipe sets this to keep trainer HBM workable for the backward pass).

At eval time this is wasteful: eval is greedy decode only (no backward pass), so the trainer memory budget doesn't apply, and many of the vLLM rollout's DP replicas sit idle when only 4 prompts are dispatched per batch.

Add an `eval_batch_size` knob (default -1 = use batch_size, preserving the old behavior) that overrides the batch dimension during dataset preparation for the test split. Setting it to e.g. 128 on a sampler with 8 DP replicas gives a ~32x eval throughput improvement on TPU without affecting training behavior.

Total eval examples = num_test_batches * eval_batch_size, so users should adjust num_test_batches when increasing eval_batch_size to keep the total eval set size constant.
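The sizing arithmetic above can be sketched as follows. This is a minimal illustration, not code from this commit: `resolve_eval_batch_size` and `rescale_num_test_batches` are hypothetical helper names, and the fallback rule simply mirrors the "-1 = use batch_size" behavior described in the message.

```python
import math


def resolve_eval_batch_size(batch_size: int, eval_batch_size: int = -1) -> int:
    """Return the batch size used for eval; non-positive means 'use batch_size'."""
    return batch_size if eval_batch_size <= 0 else eval_batch_size


def rescale_num_test_batches(num_test_batches: int, batch_size: int, eval_batch_size: int) -> int:
    """Hypothetical helper: batches needed to cover the same number of eval examples.

    Rounds up, so the new total can slightly exceed the old one when the
    old total is not a multiple of the new eval batch size.
    """
    total_examples = num_test_batches * batch_size
    effective = resolve_eval_batch_size(batch_size, eval_batch_size)
    return math.ceil(total_examples / effective)


# 200 batches of 4 prompts = 800 examples; at 128 prompts per eval batch
# that is ceil(800 / 128) batches.
new_num_test_batches = rescale_num_test_batches(200, batch_size=4, eval_batch_size=128)
print(new_num_test_batches)
```

Because of the rounding, 7 batches of 128 would request 896 examples rather than exactly 800; whether that matters depends on how large the filtered test split is.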
1 parent c2d7758 commit 87c9f3f

3 files changed

Lines changed: 21 additions & 2 deletions

src/maxtext/configs/post_train/rl.yml

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ rollout_micro_batch_size: -1
 # Keep `num_test_batches` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
 num_test_batches: 5 # 200
+# Optional override: batch size used during post-train RL evaluation. -1 (default)
+# = use `batch_size`. Set higher (e.g. 32-128) to feed vLLM bigger batches during
+# greedy eval; otherwise eval is bottlenecked by training batch_size, which is
+# small for GRPO (e.g. 4 prompts × 8 generations per step). Total eval examples
+# = num_test_batches * eval_batch_size, so adjust num_test_batches accordingly.
+eval_batch_size: -1
 test_batch_start_index: 0
 train_fraction: 1.0

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -2020,6 +2020,7 @@ class RLDataset(BaseModel):
   batch_size: int = Field(1, description="Global batch size for the dataset loader in RL.")
   num_batches: int = Field(4, description="Number of batches for RL training.")
   num_test_batches: int = Field(5, description="Number of batches for RL evaluation.")
+  eval_batch_size: int = Field(-1, description="Batch size for RL evaluation.")
   test_batch_start_index: int = Field(0, description="Start index for the test dataset")
   train_fraction: float = Field(1.0, description="Fraction of the dataset to be used for training.")
   train_micro_batch_size: int = Field(-1, description="Micro batch size for training.")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 14 additions & 2 deletions
@@ -343,11 +343,23 @@ def _use_raw_prompt(x):
   train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)

   if trainer_config.num_test_batches > 0:
+    # eval_batch_size = -1 (default) → use trainer_config.batch_size (legacy
+    # behavior). Otherwise use the override so vLLM rollout during greedy eval
+    # can pack more prompts per call — important when training batch_size is
+    # small (e.g. 4 for GRPO) but the sampler has enough DP replicas to absorb
+    # a much larger eval batch. Total eval examples = num_test_batches *
+    # eval_batch_size_for_eval; adjust num_test_batches when changing
+    # eval_batch_size to keep total eval set size constant.
+    eval_batch_size_for_eval = (
+        trainer_config.batch_size
+        if getattr(trainer_config, "eval_batch_size", -1) <= 0
+        else trainer_config.eval_batch_size
+    )
     test_dataset = test_dataset.filter(_filter_long_prompts)
     test_dataset = test_dataset[
-        trainer_config.test_batch_start_index : trainer_config.num_test_batches * trainer_config.batch_size
+        trainer_config.test_batch_start_index : trainer_config.num_test_batches * eval_batch_size_for_eval
     ]
-    test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
+    test_dataset = test_dataset.to_iter_dataset().batch(eval_batch_size_for_eval)

   return train_dataset, test_dataset
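To make the effect of the hunk above concrete, here is a rough stand-in that reproduces the fallback, slice, and batch steps on a plain Python list. It is only a sketch under stated assumptions: `prepare_test_batches` is a hypothetical name, list slicing stands in for the dataset object's slicing and `.batch(...)` calls, and the real pipeline's handling of a trailing partial batch may differ.

```python
def prepare_test_batches(
    examples,
    batch_size,
    num_test_batches,
    eval_batch_size=-1,
    test_batch_start_index=0,
):
    """Hypothetical list-based imitation of the test-split preparation."""
    # Same fallback rule as the diff: non-positive override means "use batch_size".
    effective = batch_size if eval_batch_size <= 0 else eval_batch_size
    # Same slice expression as the diff: the end bound is an absolute index,
    # so a nonzero start index shrinks the number of selected examples.
    selected = examples[test_batch_start_index : num_test_batches * effective]
    # Plain-list stand-in for batching the selected examples.
    return [selected[i : i + effective] for i in range(0, len(selected), effective)]


examples = list(range(1000))

legacy = prepare_test_batches(examples, batch_size=4, num_test_batches=5)
larger = prepare_test_batches(examples, batch_size=4, num_test_batches=5, eval_batch_size=128)
offset = prepare_test_batches(examples, batch_size=4, num_test_batches=5, test_batch_start_index=8)

print(len(legacy), len(legacy[0]))  # number of batches, examples per batch
print(len(larger), len(larger[0]))
print(len(offset), len(offset[0]))
```

With `num_test_batches` left at 5, raising the override from the default to 128 grows the eval set from 20 to 640 examples in this sketch, which is why the commit message recommends lowering `num_test_batches` alongside it.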
