@@ -211,33 +211,33 @@ def on_step_end(self, args, state, control, **kwargs):
211211 pass
212212
213213 # --- TRL config: use provided or build sensible defaults ---
214- # CRITICAL: per_device_train_batch_size must be <= len(dataset).
215- # TRL default is 8, but RL task datasets are typically 1-10 tasks.
216- # If batch_size > dataset_size, TRL computes 0 steps and exits
217- # with "There seems not to be a single sample in your epoch_iterator".
214+ # TRL constraints on batch sizing:
215+ # 1. per_device_train_batch_size must be <= len(dataset)
216+ # 2. generation_batch_size must be divisible by num_generations
217+ # 3. generation_batch_size defaults to per_device_train_batch_size
218218 #
219- # We set batch_size=1 (not n_tasks) because:
220- # - Each step already does num_generations rollouts per sample
221- # - batch_size=n_tasks with many tasks could OOM on GPU
222- # - batch_size=1 matches the standalone trainer (one task per step,
223- # rotating through tasks via epochs)
219+ # Therefore: per_device_train_batch_size must be a MULTIPLE of
220+ # num_generations AND <= len(dataset). The minimum valid value is
221+ # num_generations itself. If the dataset is smaller, we pad it
222+ # by repeating tasks to reach at least that size.
223+ num_gen = self . _config . num_rollouts_per_step
224224 n_tasks = len (task_configs )
225225
226226 if self ._trl_config is not None :
227227 trl_config = self ._trl_config
228- # Warn if user-provided config has batch_size > dataset
229228 bs = getattr (trl_config , "per_device_train_batch_size" , 8 )
230- if bs > n_tasks :
229+ ng = getattr (trl_config , "num_generations" , num_gen )
230+ if bs % ng != 0 :
231231 logger .warning (
232- "per_device_train_batch_size=%d > dataset size=%d. "
233- "TRL will compute 0 steps and exit immediately . "
234- "Set per_device_train_batch_size=1 or add more tasks ." ,
235- bs , n_tasks ,
232+ "per_device_train_batch_size=%d is not divisible by "
233+ "num_generations=%d. TRL will reject this . "
234+ "Set per_device_train_batch_size=%d ." ,
235+ bs , ng , ng ,
236236 )
237237 else :
238238 trl_config = GRPOConfig (
239239 output_dir = self ._config .output_dir ,
240- num_generations = self . _config . num_rollouts_per_step ,
240+ num_generations = num_gen ,
241241 max_completion_length = self ._config .max_new_tokens ,
242242 max_steps = self ._config .num_training_steps ,
243243 learning_rate = self ._config .learning_rate ,
@@ -246,12 +246,29 @@ def on_step_end(self, args, state, control, **kwargs):
246246 bf16 = True ,
247247 loss_type = "grpo" ,
248248 num_train_epochs = 1 ,
249- # batch_size=1: each step processes one task with
250- # num_generations rollouts. Tasks rotate via epochs.
251- # Default of 8 causes "0 steps" with small task sets.
252- per_device_train_batch_size = 1 ,
249+ # batch_size = num_generations: TRL requires
250+ # batch_size % num_generations == 0. This is the
251+ # minimum valid value. Each step processes
252+ # batch_size prompts × num_generations rollouts each.
253+ per_device_train_batch_size = num_gen ,
253254 )
254255
256+ # Pad dataset if needed: TRL needs len(dataset) >= batch_size.
257+ # With 1 task and batch_size=4, we repeat the task 4 times.
258+ # Each row triggers the same rollout_func, so repeats are fine
259+ # for RL (same task, many rollouts = more learning signal).
260+ bs = getattr (trl_config , "per_device_train_batch_size" , num_gen )
261+ if len (dataset ) < bs :
262+ import math
263+ repeats = math .ceil (bs / len (dataset ))
264+ logger .info (
265+ "Padding dataset from %d to %d rows (repeating tasks %dx) "
266+ "to meet per_device_train_batch_size=%d" ,
267+ len (dataset ), len (dataset ) * repeats , repeats , bs ,
268+ )
269+ padded = {k : v * repeats for k , v in dataset .to_dict ().items ()}
270+ dataset = Dataset .from_dict (padded )
271+
255272 # --- Train ---
256273 trainer = _TRLTrainer (
257274 model = model ,
0 commit comments