4747from functools import wraps
4848from typing import Sequence
4949
50- import collections
5150import grain
5251import jax
5352import json
6059from absl import logging as absl_logging
6160from etils import epath
6261from flax import nnx
63- from jax .sharding import Mesh
6462from orbax import checkpoint as ocp
6563from pprint import pprint
6664from transformers import AutoTokenizer
7472
7573from maxtext .configs import pyconfig
7674from maxtext .utils .globals import MAXTEXT_CONFIGS_DIR
77- from maxtext .integration .tunix .tunix_adapter import TunixMaxTextAdapter
7875from maxtext .trainers .post_train .rl .evaluate_rl import evaluate
7976from maxtext .trainers .post_train .rl import utils_rl
8077from maxtext .input_pipeline .instruction_data_processing import load_template_from_file
81- from maxtext .utils import max_logging , max_utils , maxtext_utils , model_creation_utils
82-
83-
def get_maxtext_model(config, devices=None):
  """Load a MaxText model and wrap it in the Tunix adapter.

  Note: pass the path to your *scanned* checkpoint for 'load_parameters_path',
  and make sure it is the full path ending in `/0/items`, i.e.
  load_parameters_path=/path/to/your/output/directory/0/items.
  A scanned checkpoint can be created with
  /maxtext/src/maxtext/checkpoint_conversion/to_maxtext.py; if using Pathways,
  set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for the storage flags:
    export USE_PATHWAYS=1
    python src/maxtext/checkpoint_conversion/to_maxtext.py \
      --model_name="gemma2-2b" \
      --base_output_directory="/path/to/your/output/directory" \
      --scan_layers=True \
      --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
      --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))

  Args:
    config: MaxText config used to build the NNX model.
    devices: Optional JAX devices to build the model/mesh on.

  Returns:
    A (tunix_model, mesh) tuple: the adapter-wrapped model and its mesh.
  """
  base_model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
  # NOTE(review): adapter is built inside the mesh context, matching how the
  # model itself was created — presumably so its state inherits the sharding.
  with mesh:
    no_op_mappings = "maxtext_config" in config.vllm_additional_config
    wrapped = TunixMaxTextAdapter(base_model=base_model, use_no_op_mappings=no_op_mappings)
    # Clear the adapter's config attribute; downstream code expects None here.
    wrapped.config = None
  return wrapped, mesh
78+ from maxtext .utils import max_logging , max_utils , model_creation_utils
10679
10780
10881def get_dataset (
@@ -150,86 +123,6 @@ def get_dataset(
150123 return loaded_dataset
151124
152125
def setup_configs_and_devices(argv: list[str]):
  """Setup device allocation and configs for training and inference.

  Two placement modes are supported, selected by the slice-count settings:
    * single slice (both counts == -1): trainer and sampler share the devices,
      optionally split by fraction when running multi-VM under Pathways;
    * multislice (both counts > 0): whole slices are dedicated to the trainer
      and the sampler, and per-role configs are re-initialized with matching
      parallelism settings.

  Args:
    argv: Command-line arguments forwarded to pyconfig.

  Returns:
    A (trainer_config, sampler_config, trainer_devices, sampler_devices) tuple.

  Raises:
    ValueError: on inconsistent slice counts or parallelism settings.
  """
  config = pyconfig.initialize_pydantic(argv)
  devices = jax.devices()

  if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
    return _single_slice_setup(config, devices)
  if config.num_trainer_slices > 0 and config.num_samplers_slices > 0:
    return _multislice_setup(argv, config, devices)
  raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive")


def _single_slice_setup(config, devices):
  """Single-slice placement: trainer and sampler share (a fraction of) devices."""
  max_logging.log("Running RL on a single slice")
  trainer_devices, sampler_devices = devices, devices
  num_vms = len(devices) // config.chips_per_vm
  if num_vms >= 2 and config.use_pathways:
    # Multiple hosts with Pathways - potentially split devices for trainer and
    # sampler based on trainer_devices_fraction and sampler_devices_fraction.
    max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.")
    total = len(devices)
    trainer_devices = devices[: int(total * config.trainer_devices_fraction)]
    sampler_devices = devices[total - int(total * config.sampler_devices_fraction) :]
    if config.trainer_devices_fraction != 1.0:
      max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices")
    if config.sampler_devices_fraction != 1.0:
      max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices")
  # Both roles reuse the same config object in this mode.
  return config, config, trainer_devices, sampler_devices


def _multislice_setup(argv, config, devices):
  """Multislice placement: assign whole slices to trainer vs. sampler."""
  max_logging.log("Running RL with Multislice")
  devices_by_slice = collections.defaultdict(list)
  for device in devices:
    devices_by_slice[device.slice_index].append(device)
  slice_indices = sorted(devices_by_slice)

  n_trainer = config.num_trainer_slices
  n_sampler = config.num_samplers_slices
  if len(slice_indices) < n_trainer + n_sampler:
    raise ValueError("Not enough slices for trainer and samplers")

  # First n_trainer slices go to the trainer, the next n_sampler to the sampler.
  trainer_devices = [d for idx in slice_indices[:n_trainer] for d in devices_by_slice[idx]]
  sampler_devices = [d for idx in slice_indices[n_trainer : n_trainer + n_sampler] for d in devices_by_slice[idx]]

  per_slice = len(trainer_devices) // n_trainer
  tp = config.ici_tensor_parallelism
  trainer_fsdp = per_slice
  if tp > 1:
    if per_slice % tp != 0:
      raise ValueError(f"trainer_devices_per_slice ({per_slice}) must be divisible by tensor parallelism ({tp})")
    if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != per_slice:
      raise ValueError(
          f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal "
          f"devices_per_slice ({per_slice})"
      )
    trainer_fsdp = per_slice // tp

  # Re-initialize per-role configs so each role gets its own parallelism layout.
  trainer_config = pyconfig.initialize_pydantic(
      argv,
      num_slices=n_trainer,
      ici_fsdp_parallelism=trainer_fsdp,
      ici_tensor_parallelism=tp,
      dcn_data_parallelism=n_trainer,
  )
  sampler_config = pyconfig.initialize_pydantic(
      argv,
      num_slices=n_sampler,
      ici_fsdp_parallelism=len(sampler_devices) // n_sampler,
      ici_tensor_parallelism=-1,
      dcn_data_parallelism=n_sampler,
  )
  return trainer_config, sampler_config, trainer_devices, sampler_devices
231-
232-
233126def get_rollout_kwargs_for_parallelism (sampler_config , num_sampler_devices ):
234127 """Get rollout kwargs for vLLM rollout when using data parallelism."""
235128 dp = sampler_config .rollout_data_parallelism
@@ -411,28 +304,6 @@ def _use_raw_prompt(x):
411304 return train_dataset , test_dataset
412305
413306
def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
  """Create reference and actor models and their respective meshes.

  Args:
    trainer_config: Config for the trainer (reference and actor models).
    sampler_config: Config for the sampler (rollout mesh).
    trainer_devices: JAX devices allocated to the trainer.
    sampler_devices: JAX devices allocated to the sampler.

  Returns:
    A (reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh)
    tuple.
  """
  max_logging.log("Creating reference model and also meshes for reference and rollout")
  reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)

  # Rollout mesh is built from the sampler's own device allocation and axes.
  sampler_device_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
  rollout_mesh = Mesh(sampler_device_array, sampler_config.mesh_axes)

  if not trainer_config.load_checkpoint_only_once:
    max_logging.log("Creating policy model with same config as reference model on trainer mesh")
    actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
    return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh

  # Avoid restoring from the checkpoint twice: clone the reference weights.
  max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
  with reference_mesh:
    cloned_base = nnx.clone(reference_model.base)
    no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
    actor_model = TunixMaxTextAdapter(base_model=cloned_base, use_no_op_mappings=no_op_mappings)
    # Clear the adapter's config attribute; downstream code expects None here.
    actor_model.config = None
  # The cloned actor lives on the same mesh as the reference model.
  return reference_model, reference_mesh, actor_model, reference_mesh, rollout_mesh
434-
435-
436307def create_rl_components (
437308 trainer_config ,
438309 sampler_config ,
@@ -652,7 +523,7 @@ def _reward_fn(**kwargs):
652523 return rl_cluster , rl_trainer , optimizer
653524
654525
655- def rl_train (trainer_config , sampler_config , trainer_devices , sampler_devices ):
526+ def rl_train (argv : Sequence [ str ], kwargs : dict ):
656527 """
657528 Run RL training with the provided configuration.
658529
@@ -662,13 +533,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
662533 trainer_devices: JAX devices for the trainer.
663534 sampler_devices: JAX devices for the sampler.
664535 """
536+ trainer_config , sampler_config , trainer_devices , sampler_devices = model_creation_utils .setup_configs_and_devices (
537+ argv , kwargs
538+ )
539+
540+ reference_model , reference_mesh , actor_model , actor_mesh , rollout_mesh = model_creation_utils .create_models_and_meshes (
541+ trainer_config , sampler_config , trainer_devices , sampler_devices
542+ )
543+
665544 if not trainer_config .debug .rl :
666545 # Apply filter to suppress noisy logs
667546 noise_filter = max_logging .NoisyLogFilter ()
668547 logging .getLogger ().addFilter (noise_filter )
669548 absl_logging .get_absl_logger ().addFilter (noise_filter )
549+ os .environ ["VLLM_LOGGING_LEVEL" ] = "ERROR"
670550
671- max_logging .log ("Starting RL Training" )
672551 if not epath .Path (trainer_config .tensorboard_dir ).exists ():
673552 epath .Path (trainer_config .tensorboard_dir ).mkdir (parents = True , exist_ok = True )
674553
@@ -692,17 +571,14 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
692571 break
693572 pprint (ele )
694573
695- reference_model , reference_mesh , actor_model , actor_mesh , rollout_mesh = create_models_and_meshes (
696- trainer_config , sampler_config , trainer_devices , sampler_devices
697- )
698-
699574 if trainer_config .debug .rl :
700575 max_logging .log ("Reference Model initialized successfully" )
701576 nnx .display (reference_model )
702577 max_logging .log (f"Reference mesh shape: { reference_mesh .shape } " )
703578 max_logging .log ("Policy Model initialized successfully" )
704579 nnx .display (actor_model )
705580 max_logging .log (f"Policy mesh shape: { actor_mesh .shape } " )
581+ max_logging .log (f"Rollout_mesh shape: { rollout_mesh .shape } " )
706582
707583 rl_cluster , rl_trainer , _ = create_rl_components (
708584 trainer_config ,
@@ -759,18 +635,18 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
759635 max_logging .warning (f"Post RL Training: { corr = } , { total = } , { accuracy = } %, { partial_accuracy = } %," f" { format_accuracy = } %" )
760636
761637
def main(argv: Sequence[str], kwargs: dict | None = None) -> None:
  """Main function to run RL training.

  Args:
    argv: Command-line arguments.
    kwargs: Optional config overrides forwarded to `rl_train`; a missing or
      falsy value is normalized to an empty dict.
  """
  # Normalize the optional overrides so rl_train always receives a dict.
  kwargs = kwargs or {}
  pathwaysutils.initialize()
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  max_utils.print_system_information()
  rl_train(argv, kwargs)
774650
775651
776652if __name__ == "__main__" :
0 commit comments