@@ -282,37 +282,11 @@ def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
282282 return rollout_kwargs
283283
284284
285- def rl_train (trainer_config , sampler_config , trainer_devices , sampler_devices ):
286- """
287- Run RL training with the provided configuration.
288-
289- Args:
290- trainer_config: MaxText configuration for the trainer.
291- sampler_config: MaxText configuration for the sampler.
292- trainer_devices: JAX devices for the trainer.
293- sampler_devices: JAX devices for the sampler.
294- """
295- if not trainer_config .debug .rl :
296- # Apply filter to suppress noisy logs
297- noise_filter = max_logging .NoisyLogFilter ()
298- logging .getLogger ().addFilter (noise_filter )
299- absl_logging .get_absl_logger ().addFilter (noise_filter )
300-
301- max_logging .log ("Starting RL Training" )
302- max_logging .log (f"Ensuring TensorBoard log directory exists: { trainer_config .tensorboard_dir } " )
303- if not epath .Path (trainer_config .tensorboard_dir ).exists ():
304- epath .Path (trainer_config .tensorboard_dir ).mkdir (parents = True , exist_ok = True )
305-
306- if not epath .Path (trainer_config .checkpoint_dir ).exists ():
307- epath .Path (trainer_config .checkpoint_dir ).mkdir (parents = True )
308-
309- # Number of training steps.
310- max_train_steps = int (
311- trainer_config .num_batches
312- * trainer_config .rl .num_iterations
313- * trainer_config .train_fraction
314- * trainer_config .num_epoch
315- )
285+ def get_datasets (
286+ model_tokenizer ,
287+ trainer_config ,
288+ ) -> tuple [grain .IterDataset , grain .IterDataset ]:
289+ """Handles loading, templating, filtering, and batching of train/test datasets."""
316290 # ====== Data ======
317291 # Setup data directories
318292 home = os .path .expanduser ("~" ) + "/"
@@ -323,9 +297,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
323297 if not os .path .exists (test_data_dir ):
324298 os .makedirs (test_data_dir )
325299
326- # Create model tokenizer
327- model_tokenizer = AutoTokenizer .from_pretrained (trainer_config .tokenizer_path )
328-
329300 # Load datasets
330301 if trainer_config .dataset_name == "huggingface:nvidia/OpenMathInstruct-2" :
331302 import datasets # pylint: disable=import-outside-toplevel
@@ -334,7 +305,6 @@ def prepare_openinstructmath2_dataset(
334305 split : str = "train_1M" ,
335306 seed : int = 42 ,
336307 test_size : float = 0.05 ,
337- output_key : str = "expected_answer" ,
338308 ):
339309 """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
340310 max_logging .log (
@@ -422,16 +392,54 @@ def _filter_long_prompts(x):
422392
423393 if trainer_config .debug .rl :
424394 # Let's see how one batch of the dataset looks like!
425- if trainer_config .debug .rl :
426- for i , ele in enumerate (train_dataset ):
427- if i >= 5 :
428- break
429- pprint (ele )
430- if trainer_config .debug .rl :
431- for i , ele in enumerate (test_dataset ):
432- if i >= 5 :
433- break
434- pprint (ele )
395+ for i , ele in enumerate (train_dataset ):
396+ if i >= 5 :
397+ break
398+ pprint (ele )
399+ for i , ele in enumerate (test_dataset ):
400+ if i >= 5 :
401+ break
402+ pprint (ele )
403+
404+ return train_dataset , test_dataset
405+
406+
407+ def rl_train (trainer_config , sampler_config , trainer_devices , sampler_devices ):
408+ """
409+ Run RL training with the provided configuration.
410+
411+ Args:
412+ trainer_config: MaxText configuration for the trainer.
413+ sampler_config: MaxText configuration for the sampler.
414+ trainer_devices: JAX devices for the trainer.
415+ sampler_devices: JAX devices for the sampler.
416+ """
417+ if not trainer_config .debug .rl :
418+ # Apply filter to suppress noisy logs
419+ noise_filter = max_logging .NoisyLogFilter ()
420+ logging .getLogger ().addFilter (noise_filter )
421+ absl_logging .get_absl_logger ().addFilter (noise_filter )
422+
423+ max_logging .log ("Starting RL Training" )
424+ max_logging .log (f"Ensuring TensorBoard log directory exists: { trainer_config .tensorboard_dir } " )
425+ if not epath .Path (trainer_config .tensorboard_dir ).exists ():
426+ epath .Path (trainer_config .tensorboard_dir ).mkdir (parents = True , exist_ok = True )
427+
428+ if not epath .Path (trainer_config .checkpoint_dir ).exists ():
429+ epath .Path (trainer_config .checkpoint_dir ).mkdir (parents = True )
430+
431+ # Number of training steps.
432+ max_train_steps = int (
433+ trainer_config .num_batches
434+ * trainer_config .rl .num_iterations
435+ * trainer_config .train_fraction
436+ * trainer_config .num_epoch
437+ )
438+ # ====== Data ======
439+ # Create model tokenizer
440+ model_tokenizer = AutoTokenizer .from_pretrained (trainer_config .tokenizer_path )
441+
442+ train_dataset , test_dataset = get_datasets (model_tokenizer , trainer_config )
435443
436444 # Load reference model
437445 max_logging .log ("Creating reference model and also meshes for reference and rollout" )
0 commit comments