Skip to content

Commit eded896

Browse files
Merge pull request #4030 from AI-Hypercomputer:pr/eval-batch-size
PiperOrigin-RevId: 929448807
2 parents 6fe6bbd + be6e0f4 commit eded896

3 files changed

Lines changed: 21 additions & 2 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ rollout_micro_batch_size: -1
123123
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
124124
# increased to a max. of 330 (if batch size is 4).
125125
num_test_batches: 5 # 200
126+
# Optional override: batch size used during post-train RL evaluation. -1 (default)
127+
# = use `batch_size`. Set higher (e.g. 32-128) to feed vLLM bigger batches during
128+
# greedy eval — otherwise eval is bottlenecked by training batch_size, which is
129+
# small for GRPO (e.g. 4 prompts × 8 generations per step). Total eval examples
130+
# = num_test_batches * eval_batch_size, so adjust num_test_batches accordingly.
131+
eval_batch_size: -1
126132
test_batch_start_index: 0
127133
train_fraction: 1.0
128134

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,7 @@ class RLDataset(BaseModel):
20392039
batch_size: int = Field(1, description="Global batch size for the dataset loader in RL.")
20402040
num_batches: int = Field(4, description="Number of batches for RL training.")
20412041
num_test_batches: int = Field(5, description="Number of batches for RL evaluation.")
2042+
eval_batch_size: int = Field(-1, description="Batch size for RL evaluation.")
20422043
test_batch_start_index: int = Field(0, description="Start index for the test dataset")
20432044
train_fraction: float = Field(1.0, description="Fraction of the dataset to be used for training.")
20442045
train_micro_batch_size: int = Field(-1, description="Micro batch size for training.")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,23 @@ def _use_raw_prompt(x):
342342
train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)
343343

344344
if trainer_config.num_test_batches > 0:
345+
# eval_batch_size = -1 (default) → use trainer_config.batch_size (legacy
346+
# behavior). Otherwise use the override so vLLM rollout during greedy eval
347+
# can pack more prompts per call — important when training batch_size is
348+
# small (e.g. 4 for GRPO) but the sampler has enough DP replicas to absorb
349+
# a much larger eval batch. Total eval examples = num_test_batches *
350+
# eval_batch_size_for_eval; adjust num_test_batches when changing
351+
# eval_batch_size to keep total eval set size constant.
352+
eval_batch_size_for_eval = (
353+
trainer_config.batch_size
354+
if getattr(trainer_config, "eval_batch_size", -1) <= 0
355+
else trainer_config.eval_batch_size
356+
)
345357
test_dataset = test_dataset.filter(_filter_long_prompts)
346358
test_dataset = test_dataset[
347-
trainer_config.test_batch_start_index : trainer_config.num_test_batches * trainer_config.batch_size
359+
trainer_config.test_batch_start_index : trainer_config.num_test_batches * eval_batch_size_for_eval
348360
]
349-
test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
361+
test_dataset = test_dataset.to_iter_dataset().batch(eval_batch_size_for_eval)
350362

351363
return train_dataset, test_dataset
352364

0 commit comments

Comments
 (0)