@@ -282,39 +282,18 @@ def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
282282 return rollout_kwargs
283283
284284
285- def rl_train (trainer_config , sampler_config , trainer_devices , sampler_devices ):
286- """
287- Run RL training with the provided configuration.
288-
289- Args:
290- trainer_config: MaxText configuration for the trainer.
291- sampler_config: MaxText configuration for the sampler.
292- trainer_devices: JAX devices for the trainer.
293- sampler_devices: JAX devices for the sampler.
294- """
295- if not trainer_config .debug .rl :
296- # Apply filter to suppress noisy logs
297- noise_filter = max_logging .NoisyLogFilter ()
298- logging .getLogger ().addFilter (noise_filter )
299- absl_logging .get_absl_logger ().addFilter (noise_filter )
300-
301- max_logging .log ("Starting RL Training" )
302- max_logging .log (f"Ensuring TensorBoard log directory exists: { trainer_config .tensorboard_dir } " )
303- if not epath .Path (trainer_config .tensorboard_dir ).exists ():
304- epath .Path (trainer_config .tensorboard_dir ).mkdir (parents = True , exist_ok = True )
305-
306- if not epath .Path (trainer_config .checkpoint_dir ).exists ():
307- epath .Path (trainer_config .checkpoint_dir ).mkdir (parents = True )
308-
309- # Number of training steps.
310- max_train_steps = int (
285+ def get_max_train_steps (trainer_config ):
286+ """Calculate the total number of training steps."""
287+ return int (
311288 trainer_config .num_batches
312289 * trainer_config .rl .num_iterations
313290 * trainer_config .train_fraction
314291 * trainer_config .num_epoch
315292 )
316- # ====== Data ======
317- # Setup data directories
293+
294+
295+ def prepare_datasets (trainer_config , model_tokenizer ):
296+ """Setup and return train and test datasets."""
318297 home = os .path .expanduser ("~" ) + "/"
319298 train_data_dir = f"{ home } /data/train"
320299 test_data_dir = f"{ home } /data/test"
@@ -323,9 +302,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
323302 if not os .path .exists (test_data_dir ):
324303 os .makedirs (test_data_dir )
325304
326- # Create model tokenizer
327- model_tokenizer = AutoTokenizer .from_pretrained (trainer_config .tokenizer_path )
328-
329305 # Load datasets
330306 if trainer_config .dataset_name == "huggingface:nvidia/OpenMathInstruct-2" :
331307 import datasets # pylint: disable=import-outside-toplevel
@@ -334,7 +310,6 @@ def prepare_openinstructmath2_dataset(
334310 split : str = "train_1M" ,
335311 seed : int = 42 ,
336312 test_size : float = 0.05 ,
337- output_key : str = "expected_answer" ,
338313 ):
339314 """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
340315 max_logging .log (
@@ -419,41 +394,16 @@ def _filter_long_prompts(x):
419394 test_dataset = test_dataset [: trainer_config .num_test_batches * trainer_config .batch_size ]
420395
421396 test_dataset = test_dataset .to_iter_dataset ().batch (trainer_config .batch_size )
397+ return train_dataset , test_dataset
422398
423- if trainer_config .debug .rl :
424- # Let's see how one batch of the dataset looks like!
425- if trainer_config .debug .rl :
426- for i , ele in enumerate (train_dataset ):
427- if i >= 5 :
428- break
429- pprint (ele )
430- if trainer_config .debug .rl :
431- for i , ele in enumerate (test_dataset ):
432- if i >= 5 :
433- break
434- pprint (ele )
435-
436- # Load reference model
399+
400+ def create_models_and_meshes (trainer_config , sampler_config , trainer_devices , sampler_devices ):
401+ """Create reference and actor models and their respective meshes."""
437402 max_logging .log ("Creating reference model and also meshes for reference and rollout" )
438403 reference_model , reference_mesh = get_maxtext_model (trainer_config , trainer_devices )
439404 devices_array = maxtext_utils .create_device_mesh (sampler_config , sampler_devices )
440- # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh
441- # else rollout_mesh uses sampler_devices
442405 rollout_mesh = Mesh (devices_array , sampler_config .mesh_axes )
443- if trainer_config .debug .rl :
444- max_logging .log ("Reference Model initialized successfully" )
445- nnx .display (reference_model )
446- max_logging .log (f"Reference mesh shape: { reference_mesh .shape } " )
447406
448- # Sanity check that weights are loaded correctly.
449- _maxtext_state_flatten = nnx .state (reference_model ).flat_state ()
450- maxtext_state_flatten = {"." .join (str (key ) for key in keys ): v for keys , v in _maxtext_state_flatten }
451- max_logging .log (
452- f"maxtext_state_flatten[base.token_embedder.embedding].value=\
453- { maxtext_state_flatten ['base.token_embedder.embedding' ][...]} "
454- )
455-
456- # TODO: @mazumdera: change this to use lora
457407 if trainer_config .load_checkpoint_only_once :
458408 max_logging .log ("Creating policy model by copying reference model instead of restoring from checkpoint again." )
459409 with reference_mesh :
@@ -466,11 +416,22 @@ def _filter_long_prompts(x):
466416 max_logging .log ("Creating policy model with same config as reference model on trainer mesh" )
467417 actor_model , actor_mesh = get_maxtext_model (trainer_config , trainer_devices )
468418
469- if trainer_config .debug .rl :
470- max_logging .log ("Policy Model initialized successfully" )
471- nnx .display (actor_model )
472- max_logging .log (f"Policy mesh shape: { actor_mesh .shape } " )
473-
419+ return reference_model , reference_mesh , actor_model , actor_mesh , rollout_mesh
420+
421+
422+ def create_rl_components (
423+ trainer_config ,
424+ sampler_config ,
425+ sampler_devices ,
426+ actor_model ,
427+ actor_mesh ,
428+ reference_model ,
429+ reference_mesh ,
430+ rollout_mesh ,
431+ model_tokenizer ,
432+ max_train_steps ,
433+ ):
434+ """Setup RL cluster, trainer, and optimizer."""
474435 # Setup optimizer
475436 optimizer = utils_rl .get_optimizer (trainer_config , max_train_steps )
476437
@@ -483,7 +444,6 @@ def _filter_long_prompts(x):
483444 micro_batch_size = None if trainer_config .micro_batch_size == - 1 else trainer_config .micro_batch_size
484445
485446 # Setup metrics logging
486- max_logging .log (f"Tensorboard logs directory: { trainer_config .tensorboard_dir } " )
487447 metrics_logging_options = metrics_logger .MetricsLoggerOptions (
488448 log_dir = trainer_config .tensorboard_dir , flush_every_n_steps = trainer_config .log_period
489449 )
@@ -501,25 +461,18 @@ def _filter_long_prompts(x):
501461 rollout_additional_config = None
502462 if trainer_config .vllm_additional_config :
503463 if isinstance (trainer_config .vllm_additional_config , dict ):
504- # It's already parsed into a dict
505464 rollout_additional_config = trainer_config .vllm_additional_config
506465 elif isinstance (trainer_config .vllm_additional_config , str ):
507- # It's a string, so we need to parse it
508466 try :
509467 rollout_additional_config = json .loads (trainer_config .vllm_additional_config )
510468 except json .JSONDecodeError as e :
511469 raise ValueError (f"Failed to parse additional_config JSON: { e } " ) from e
512470
513- max_logging .log (f"Parsed additional config: { rollout_additional_config } " )
514-
515471 # We need to parse vLLM config to get the logical axis rules for the sampler config.
516472 vllm_config_path = os .path .join (MAXTEXT_CONFIGS_DIR , "inference" , "vllm.yml" )
517473 argv_list = ["" , str (vllm_config_path ), "log_config=False" ]
518474 vllm_config = pyconfig .initialize (argv_list )
519475
520- # RL Cluster config
521- # Note that we use vLLM as the rollout engine.
522- # and we are using Tensor Parallelism for rollout
523476 cluster_config = rl_cluster_lib .ClusterConfig (
524477 role_to_mesh = {
525478 rl_cluster_lib .Role .ACTOR : actor_mesh ,
@@ -537,15 +490,11 @@ def _filter_long_prompts(x):
537490 actor_optimizer = optimizer ,
538491 eval_every_n_steps = trainer_config .eval_interval ,
539492 max_steps = max_train_steps ,
540- # Micro batching
541493 mini_batch_size = trainer_config .batch_size ,
542494 train_micro_batch_size = micro_batch_size ,
543495 rollout_micro_batch_size = micro_batch_size ,
544- # Metrics logging
545496 metrics_logging_options = metrics_logging_options ,
546- # Profiling
547497 profiler_options = profiler_options ,
548- # Checkpoint saving
549498 checkpoint_root_directory = trainer_config .checkpoint_dir ,
550499 checkpointing_options = checkpointing_options ,
551500 ),
@@ -579,6 +528,7 @@ def _filter_long_prompts(x):
579528 ** get_rollout_kwargs_for_parallelism (sampler_config , len (sampler_devices )),
580529 ),
581530 )
531+
582532 grpo_config = GrpoConfig (
583533 num_generations = trainer_config .rl .num_generations ,
584534 num_iterations = trainer_config .rl .num_iterations ,
@@ -595,9 +545,6 @@ def _filter_long_prompts(x):
595545 from tunix .perf import export as perf_export # pylint: disable=import-outside-toplevel
596546 from tunix .perf import metrics as perf_metrics # pylint: disable=import-outside-toplevel
597547
598- max_logging .log (
599- "enable_tunix_perf_metrics is True and tunix.perf modules are available, enabling Tunix-managed metrics."
600- )
601548 perf_config = perf_metrics .PerfMetricsConfig ()
602549 perf_config .custom_export_fn = perf_export .PerfMetricsExport .create_metrics_export_fn (cluster_config )
603550 rl_cluster_kwargs ["perf_config" ] = perf_config
@@ -627,9 +574,76 @@ def _filter_long_prompts(x):
627574 algo_config = grpo_config ,
628575 )
629576
577+ return rl_cluster , rl_trainer , optimizer
578+
579+
580+ def rl_train (trainer_config , sampler_config , trainer_devices , sampler_devices ):
581+ """
582+ Run RL training with the provided configuration.
583+
584+ Args:
585+ trainer_config: MaxText configuration for the trainer.
586+ sampler_config: MaxText configuration for the sampler.
587+ trainer_devices: JAX devices for the trainer.
588+ sampler_devices: JAX devices for the sampler.
589+ """
590+ if not trainer_config .debug .rl :
591+ # Apply filter to suppress noisy logs
592+ noise_filter = max_logging .NoisyLogFilter ()
593+ logging .getLogger ().addFilter (noise_filter )
594+ absl_logging .get_absl_logger ().addFilter (noise_filter )
595+
596+ max_logging .log ("Starting RL Training" )
597+ if not epath .Path (trainer_config .tensorboard_dir ).exists ():
598+ epath .Path (trainer_config .tensorboard_dir ).mkdir (parents = True , exist_ok = True )
599+
600+ if not epath .Path (trainer_config .checkpoint_dir ).exists ():
601+ epath .Path (trainer_config .checkpoint_dir ).mkdir (parents = True )
602+
603+ max_train_steps = get_max_train_steps (trainer_config )
604+
605+ # Create model tokenizer
606+ model_tokenizer = AutoTokenizer .from_pretrained (trainer_config .tokenizer_path )
607+
608+ train_dataset , test_dataset = prepare_datasets (trainer_config , model_tokenizer )
609+
610+ if trainer_config .debug .rl :
611+ for i , ele in enumerate (train_dataset ):
612+ if i >= 5 :
613+ break
614+ pprint (ele )
615+ for i , ele in enumerate (test_dataset ):
616+ if i >= 5 :
617+ break
618+ pprint (ele )
619+
620+ reference_model , reference_mesh , actor_model , actor_mesh , rollout_mesh = create_models_and_meshes (
621+ trainer_config , sampler_config , trainer_devices , sampler_devices
622+ )
623+
624+ if trainer_config .debug .rl :
625+ max_logging .log ("Reference Model initialized successfully" )
626+ nnx .display (reference_model )
627+ max_logging .log (f"Reference mesh shape: { reference_mesh .shape } " )
628+ max_logging .log ("Policy Model initialized successfully" )
629+ nnx .display (actor_model )
630+ max_logging .log (f"Policy mesh shape: { actor_mesh .shape } " )
631+
632+ rl_cluster , rl_trainer , _ = create_rl_components (
633+ trainer_config ,
634+ sampler_config ,
635+ sampler_devices ,
636+ actor_model ,
637+ actor_mesh ,
638+ reference_model ,
639+ reference_mesh ,
640+ rollout_mesh ,
641+ model_tokenizer ,
642+ max_train_steps ,
643+ )
644+
630645 # Before we train the model, let's evaluate the model on the test set so we can
631646 # see the improvement post training.
632- #
633647 (corr , total , accuracy , partial_accuracy , format_accuracy ), _ = evaluate (
634648 trainer_config ,
635649 test_dataset ,
@@ -638,11 +652,9 @@ def _filter_long_prompts(x):
638652 corr_lst = trainer_config .eval_corr_lst ,
639653 make_lst = trainer_config .eval_make_lst ,
640654 )
641- # TODO: @mazumdera: Change this to max_logging.log once b/473703277 is resolved
642655 max_logging .warning (f"Pre RL Training: { corr = } , { total = } , { accuracy = } %, { partial_accuracy = } %," f" { format_accuracy = } %" )
643656
644657 # Start training
645-
646658 if trainer_config .load_checkpoint_only_once :
647659 max_logging .log ("Capturing reference model state before training." )
648660 ref_state_before = nnx .to_pure_dict (nnx .state (reference_model .base , nnx .Param ))
0 commit comments