3535"""
3636
3737import argparse
38- import functools
3938import gc
4039import os
4140import sys
4746from maxtext .configs import pyconfig
4847from maxtext .utils .globals import MAXTEXT_PKG_DIR
4948from maxtext .common import checkpointing
50- from maxtext .common .common_types import MODEL_MODE_TRAIN
51- from maxtext .layers import quantizations
5249from maxtext .common import train_state_nnx
53- from maxtext .models .models import transformer_as_linen
54- from maxtext .optimizers import optimizers
5550from maxtext .utils import max_logging
5651from maxtext .utils import max_utils
5752from maxtext .utils import maxtext_utils
@@ -92,23 +87,15 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
9287 devices_array = maxtext_utils .create_device_mesh (cfg )
9388 mesh = Mesh (devices_array , cfg .mesh_axes )
9489
95- if cfg .pure_nnx :
96- rngs = maxtext_utils_nnx .create_nnx_rngs (cfg , rng_key = init_rng )
97- model = model_creation_utils .from_config (cfg , mesh = mesh , rngs = rngs )
98- _ , tx = train_utils .create_training_optimizer (cfg , model )
99- _create_model_partial , _ = model_creation_utils .create_nnx_abstract_model (cfg , mesh )
90+ rngs = maxtext_utils_nnx .create_nnx_rngs (cfg , rng_key = init_rng )
91+ model = model_creation_utils .from_config (cfg , mesh = mesh , rngs = rngs )
92+ _ , tx = train_utils .create_training_optimizer (cfg , model )
93+ _create_model_partial , _ = model_creation_utils .create_nnx_abstract_model (cfg , mesh )
10094
101- def init_state_fn ():
102- nnx_model = _create_model_partial ()
103- optimizer = nnx .Optimizer (nnx_model , tx , wrt = nnx .Param )
104- return train_state_nnx .TrainStateNNX (nnx_model , optimizer )
105-
106- else :
107- quant = quantizations .configure_quantization (cfg )
108- model = transformer_as_linen (cfg , mesh , quant = quant , model_mode = MODEL_MODE_TRAIN )
109- learning_rate_schedule = maxtext_utils .create_learning_rate_schedule (cfg )
110- tx = optimizers .get_optimizer (cfg , learning_rate_schedule )
111- init_state_fn = functools .partial (maxtext_utils .init_initial_state , model , tx , cfg , True , init_rng )
95+ def init_state_fn ():
96+ nnx_model = _create_model_partial ()
97+ optimizer = nnx .Optimizer (nnx_model , tx , wrt = nnx .Param )
98+ return train_state_nnx .TrainStateNNX (nnx_model , optimizer )
11299
113100 checkpoint_manager = checkpointing .create_orbax_checkpoint_manager (
114101 cfg .checkpoint_dir ,
@@ -201,24 +188,18 @@ def init_state_fn():
201188 "['decoder']['decoder_norm']['bias']" : (".params.lm.final_ln.bias" , None ),
202189 }
203190
204- if cfg .pure_nnx :
205- # NNX state-tree paths after `nnx.split(TrainStateNNX)`. The state is a
206- # nested `nnx.State` (dict-like Mapping) with `nnx.Variable` leaves, so
207- # `jax.tree_util.keystr` produces dict-style entries (`['key']`) plus
208- # `.value` for the Variable leaf, plus `[idx]` for the optax tuple:
209- # model params -> ['model']<rest>.value
210- # adam mu / nu -> ['optimizer']['opt_state'][0]['mu' | 'nu']<rest>.value
211- # step -> ['optimizer']['step'].value
212- # opt count -> ['optimizer']['opt_state'][0]['count'].value
213- state_map = {
214- "['optimizer']['step'].value" : ("step" , None ),
215- "['optimizer']['opt_state'][0]['count'].value" : ("opt_states_0.no_prefix_0.count" , None ),
216- }
217- else :
218- state_map = {
219- ".step" : ("step" , None ),
220- ".opt_state.count" : ("opt_states_0.no_prefix_0.count" , None ),
221- }
191+ # NNX state-tree paths after `nnx.split(TrainStateNNX)`. The state is a
192+ # nested `nnx.State` (dict-like Mapping) with `nnx.Variable` leaves, so
193+ # `jax.tree_util.keystr` produces dict-style entries (`['key']`) plus
194+ # `.value` for the Variable leaf, plus `[idx]` for the optax tuple:
195+ # model params -> ['model']<rest>.value
196+ # adam mu / nu -> ['optimizer']['opt_state'][0]['mu' | 'nu']<rest>.value
197+ # step -> ['optimizer']['step'].value
198+ # opt count -> ['optimizer']['opt_state'][0]['count'].value
199+ state_map = {
200+ "['optimizer']['step'].value" : ("step" , None ),
201+ "['optimizer']['opt_state'][0]['count'].value" : ("opt_states_0.no_prefix_0.count" , None ),
202+ }
222203
223204 def get_layer_prefix (keystr_pax ):
224205 # different path format between decoder_layer variable
@@ -231,26 +212,15 @@ def get_layer_prefix(keystr_pax):
231212
232213 for keystr_maxtext , (keystr_pax , transform_fn ) in keystr_map .items ():
233214 prefix_pax_opt_state = get_layer_prefix (keystr_pax )
234- if cfg .pure_nnx :
235- state_map [f"['model']{ keystr_maxtext } .value" ] = (f"mdl_vars{ keystr_pax } " , transform_fn )
236- state_map [f"['optimizer']['opt_state'][0]['mu']{ keystr_maxtext } .value" ] = (
237- f"opt_states_0.{ prefix_pax_opt_state } .m{ keystr_pax } " ,
238- transform_fn ,
239- )
240- state_map [f"['optimizer']['opt_state'][0]['nu']{ keystr_maxtext } .value" ] = (
241- f"opt_states_0.{ prefix_pax_opt_state } .v{ keystr_pax } " ,
242- transform_fn ,
243- )
244- else :
245- state_map [f".params['params']{ keystr_maxtext } " ] = (f"mdl_vars{ keystr_pax } " , transform_fn )
246- state_map [f".opt_state.mu['params']{ keystr_maxtext } " ] = (
247- f"opt_states_0.{ prefix_pax_opt_state } .m{ keystr_pax } " ,
248- transform_fn ,
249- )
250- state_map [f".opt_state.nu['params']{ keystr_maxtext } " ] = (
251- f"opt_states_0.{ prefix_pax_opt_state } .v{ keystr_pax } " ,
252- transform_fn ,
253- )
215+ state_map [f"['model']{ keystr_maxtext } .value" ] = (f"mdl_vars{ keystr_pax } " , transform_fn )
216+ state_map [f"['optimizer']['opt_state'][0]['mu']{ keystr_maxtext } .value" ] = (
217+ f"opt_states_0.{ prefix_pax_opt_state } .m{ keystr_pax } " ,
218+ transform_fn ,
219+ )
220+ state_map [f"['optimizer']['opt_state'][0]['nu']{ keystr_maxtext } .value" ] = (
221+ f"opt_states_0.{ prefix_pax_opt_state } .v{ keystr_pax } " ,
222+ transform_fn ,
223+ )
254224
255225 def verify_fn (key_path , _ ):
256226 keystr = jax .tree_util .keystr (key_path )
@@ -302,7 +272,7 @@ def map_fn(key_path, value):
302272 max_logging .log ("converted state finished" )
303273 max_utils .print_mem_stats ("converted state finished" )
304274
305- step_value = int (converted_state .optimizer .step .value ) if cfg . pure_nnx else converted_state . step
275+ step_value = int (converted_state .optimizer .step .value )
306276 if checkpointing .save_checkpoint (checkpoint_manager , step_value , converted_state ):
307277 max_logging .log (f"saved a checkpoint at step { step_value } " )
308278 # Upon preemption, exit when and only when all ongoing saves are complete.
0 commit comments