@@ -206,7 +206,9 @@ def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None
206206def create_models_and_meshes (trainer_config , sampler_config , trainer_devices , sampler_devices ):
207207 """Create reference and actor models and their respective meshes."""
208208 max_logging .log ("Creating reference model and also meshes for reference and rollout" )
209- reference_model , reference_mesh = get_maxtext_model (trainer_config , trainer_devices )
209+ reference_model , reference_mesh = from_pretrained (
210+ trainer_config , devices = trainer_devices , wrap_with_tunix_adapter = True
211+ )
210212 devices_array = maxtext_utils .create_device_mesh (sampler_config , sampler_devices )
211213 rollout_mesh = Mesh (devices_array , sampler_config .mesh_axes )
212214
@@ -220,35 +222,15 @@ def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sa
220222 actor_mesh = reference_mesh
221223 else :
222224 max_logging .log ("Creating policy model with same config as reference model on trainer mesh" )
223- actor_model , actor_mesh = get_maxtext_model (trainer_config , trainer_devices )
225+ actor_model , actor_mesh = from_pretrained (
226+ trainer_config , devices = trainer_devices , wrap_with_tunix_adapter = True
227+ )
224228
225229 return reference_model , reference_mesh , actor_model , actor_mesh , rollout_mesh
226230
227- def get_maxtext_model (config , devices = None ):
228- """
229- Load MaxText model with Tunix adapter.
230- # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
231- # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
232- # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags:
233- # export USE_PATHWAYS=1
234- # python src/MaxText/checkpoint_conversion/to_maxtext.py \
235- # --model_name="gemma2-2b" \
236- # --base_output_directory="/path/to/your/output/directory" \
237- # --scan_layers=True \
238- # --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
239- # --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
240- # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
241- # load_parameters_path=/path/to/your/output/directory/0/items
242- """
243- model , mesh = from_pretrained (config , devices = devices )
244- with mesh :
245- use_no_op_mappings = "maxtext_config" in config .vllm_additional_config
246- tunix_model = TunixMaxTextAdapter (base_model = model , use_no_op_mappings = use_no_op_mappings )
247- tunix_model .config = None
248- return tunix_model , mesh
249-
250-
251- def from_pretrained (config , original_mesh = None , devices = None , model_mode = MODEL_MODE_TRAIN , rng_key = None ):
231+ def from_pretrained (
232+ config , original_mesh = None , devices = None , model_mode = MODEL_MODE_TRAIN , rng_key = None , wrap_with_tunix_adapter = False
233+ ):
252234 """Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
253235 mesh = original_mesh
254236 if config .convert_checkpoint_if_possible :
@@ -411,6 +393,12 @@ def create_sharded_state():
411393 except Exception as e :
412394 raise ValueError (f"Checkpoint loading failed: { e } " ) from e
413395
396+ if wrap_with_tunix_adapter :
397+ with mesh :
398+ use_no_op_mappings = "maxtext_config" in config .vllm_additional_config
399+ model = TunixMaxTextAdapter (base_model = model , use_no_op_mappings = use_no_op_mappings )
400+ model .config = None
401+
414402 if original_mesh :
415403 return model
416404 else :
0 commit comments