Skip to content

Commit 00de2a3

Browse files
committed
refactor from_pretrained
1 parent 2af51a0 commit 00de2a3

1 file changed

Lines changed: 15 additions & 27 deletions

File tree

src/maxtext/utils/model_creation_utils.py

Lines changed: 15 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,9 @@ def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None
206206
def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
207207
"""Create reference and actor models and their respective meshes."""
208208
max_logging.log("Creating reference model and also meshes for reference and rollout")
209-
reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
209+
reference_model, reference_mesh = from_pretrained(
210+
trainer_config, devices=trainer_devices, wrap_with_tunix_adapter=True
211+
)
210212
devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
211213
rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)
212214

@@ -220,35 +222,15 @@ def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sa
220222
actor_mesh = reference_mesh
221223
else:
222224
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
223-
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
225+
actor_model, actor_mesh = from_pretrained(
226+
trainer_config, devices=trainer_devices, wrap_with_tunix_adapter=True
227+
)
224228

225229
return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh
226230

227-
def get_maxtext_model(config, devices=None):
228-
"""
229-
Load MaxText model with Tunix adapter.
230-
# Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
231-
# To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
232-
# using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags:
233-
# export USE_PATHWAYS=1
234-
# python src/MaxText/checkpoint_conversion/to_maxtext.py \
235-
# --model_name="gemma2-2b" \
236-
# --base_output_directory="/path/to/your/output/directory" \
237-
# --scan_layers=True \
238-
# --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
239-
# --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
240-
# Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
241-
# load_parameters_path=/path/to/your/output/directory/0/items
242-
"""
243-
model, mesh = from_pretrained(config, devices=devices)
244-
with mesh:
245-
use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
246-
tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
247-
tunix_model.config = None
248-
return tunix_model, mesh
249-
250-
251-
def from_pretrained(config, original_mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None):
231+
def from_pretrained(
232+
config, original_mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None, wrap_with_tunix_adapter=False
233+
):
252234
"""Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
253235
mesh = original_mesh
254236
if config.convert_checkpoint_if_possible:
@@ -411,6 +393,12 @@ def create_sharded_state():
411393
except Exception as e:
412394
raise ValueError(f"Checkpoint loading failed: {e}") from e
413395

396+
if wrap_with_tunix_adapter:
397+
with mesh:
398+
use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
399+
model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
400+
model.config = None
401+
414402
if original_mesh:
415403
return model
416404
else:

0 commit comments

Comments (0)