Skip to content

Commit d6e1b5b

Browse files
abrichr and claude authored
fix: batch_size must be multiple of num_generations, pad dataset if needed (#244)
TRL requires generation_batch_size % num_generations == 0. With batch_size=1 and num_generations=4, TRL rejects the configuration.

Fix:
1. Set per_device_train_batch_size = num_generations (the minimum valid value).
2. Pad the dataset by repeating tasks if len(dataset) < batch_size.

Example: with 1 task and num_generations=4, the dataset is padded to 4 rows, batch_size=4, generation_batch_size=4, and 4 % 4 == 0 ✓

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 345f1a9 commit d6e1b5b

1 file changed

Lines changed: 37 additions & 20 deletions

File tree

openadapt_evals/training/trl_wrapper.py

Lines changed: 37 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -211,33 +211,33 @@ def on_step_end(self, args, state, control, **kwargs):
211211
pass
212212

213213
# --- TRL config: use provided or build sensible defaults ---
214-
# CRITICAL: per_device_train_batch_size must be <= len(dataset).
215-
# TRL default is 8, but RL task datasets are typically 1-10 tasks.
216-
# If batch_size > dataset_size, TRL computes 0 steps and exits
217-
# with "There seems not to be a single sample in your epoch_iterator".
214+
# TRL constraints on batch sizing:
215+
# 1. per_device_train_batch_size must be <= len(dataset)
216+
# 2. generation_batch_size must be divisible by num_generations
217+
# 3. generation_batch_size defaults to per_device_train_batch_size
218218
#
219-
# We set batch_size=1 (not n_tasks) because:
220-
# - Each step already does num_generations rollouts per sample
221-
# - batch_size=n_tasks with many tasks could OOM on GPU
222-
# - batch_size=1 matches the standalone trainer (one task per step,
223-
# rotating through tasks via epochs)
219+
# Therefore: per_device_train_batch_size must be a MULTIPLE of
220+
# num_generations AND <= len(dataset). The minimum valid value is
221+
# num_generations itself. If the dataset is smaller, we pad it
222+
# by repeating tasks to reach at least that size.
223+
num_gen = self._config.num_rollouts_per_step
224224
n_tasks = len(task_configs)
225225

226226
if self._trl_config is not None:
227227
trl_config = self._trl_config
228-
# Warn if user-provided config has batch_size > dataset
229228
bs = getattr(trl_config, "per_device_train_batch_size", 8)
230-
if bs > n_tasks:
229+
ng = getattr(trl_config, "num_generations", num_gen)
230+
if bs % ng != 0:
231231
logger.warning(
232-
"per_device_train_batch_size=%d > dataset size=%d. "
233-
"TRL will compute 0 steps and exit immediately. "
234-
"Set per_device_train_batch_size=1 or add more tasks.",
235-
bs, n_tasks,
232+
"per_device_train_batch_size=%d is not divisible by "
233+
"num_generations=%d. TRL will reject this. "
234+
"Set per_device_train_batch_size=%d.",
235+
bs, ng, ng,
236236
)
237237
else:
238238
trl_config = GRPOConfig(
239239
output_dir=self._config.output_dir,
240-
num_generations=self._config.num_rollouts_per_step,
240+
num_generations=num_gen,
241241
max_completion_length=self._config.max_new_tokens,
242242
max_steps=self._config.num_training_steps,
243243
learning_rate=self._config.learning_rate,
@@ -246,12 +246,29 @@ def on_step_end(self, args, state, control, **kwargs):
246246
bf16=True,
247247
loss_type="grpo",
248248
num_train_epochs=1,
249-
# batch_size=1: each step processes one task with
250-
# num_generations rollouts. Tasks rotate via epochs.
251-
# Default of 8 causes "0 steps" with small task sets.
252-
per_device_train_batch_size=1,
249+
# batch_size = num_generations: TRL requires
250+
# batch_size % num_generations == 0. This is the
251+
# minimum valid value. Each step processes
252+
# batch_size prompts × num_generations rollouts each.
253+
per_device_train_batch_size=num_gen,
253254
)
254255

256+
# Pad dataset if needed: TRL needs len(dataset) >= batch_size.
257+
# With 1 task and batch_size=4, we repeat the task 4 times.
258+
# Each row triggers the same rollout_func, so repeats are fine
259+
# for RL (same task, many rollouts = more learning signal).
260+
bs = getattr(trl_config, "per_device_train_batch_size", num_gen)
261+
if len(dataset) < bs:
262+
import math
263+
repeats = math.ceil(bs / len(dataset))
264+
logger.info(
265+
"Padding dataset from %d to %d rows (repeating tasks %dx) "
266+
"to meet per_device_train_batch_size=%d",
267+
len(dataset), len(dataset) * repeats, repeats, bs,
268+
)
269+
padded = {k: v * repeats for k, v in dataset.to_dict().items()}
270+
dataset = Dataset.from_dict(padded)
271+
255272
# --- Train ---
256273
trainer = _TRLTrainer(
257274
model=model,

0 commit comments

Comments
 (0)