@@ -546,23 +546,43 @@ def setup_train_loop(
546546 max_logging .log ("Training mesh used for the workload" )
547547 num_inference_devices = config .inference_devices_per_replica * config .inference_replicas
548548 training_devices = jax .devices ()[num_inference_devices :]
549- model = mt .from_config (config , devices = training_devices )
549+ if config .pure_nnx :
550+ raise NotImplementedError ("Pure NNX support has not been implemented yet." )
551+ else :
552+ model = mt .from_config (config , devices = training_devices )
550553 mesh = model .mesh
551554 max_logging .log ("Inference mesh used for the workload" )
552555 inference_devices = jax .devices ()[:num_inference_devices ]
553- inference_model = mt .from_config (config_inference , devices = inference_devices )
556+ if config_inference .pure_nnx :
557+ raise NotImplementedError ("Pure NNX support has not been implemented yet." )
558+ else :
559+ inference_model = mt .from_config (config_inference , devices = inference_devices )
554560 inference_mesh = inference_model .mesh
555- init_rng , checkpoint_manager , learning_rate_schedule , tx = train_utils .create_training_tools (config , model , mesh )
561+ init_rng = jax .random .PRNGKey (config .init_weights_seed )
562+ learning_rate_schedule , tx = train_utils .create_training_optimizer (config , model )
563+ if config .pure_nnx :
564+ # NNX has a different function to init the training state.
565+ raise NotImplementedError ("Pure NNX support has not been implemented yet." )
566+ else :
567+ init_state_fn = functools .partial (maxtext_utils .init_initial_state , model , tx , config , True , init_rng )
568+ checkpoint_manager = train_utils .create_checkpoint_manager (config , mesh , init_state_fn )
556569
557570 with maybe_record_goodput (recorder , GoodputEvent .TRAINING_PREPARATION ):
558571 data_iterator = grpo_input_pipeline .create_data_iterator (config_inference , inference_mesh )
559572 state , _ , state_mesh_shardings , data_iterator = maxtext_utils .setup_training_state (
560- model , data_iterator , tx , config , init_rng , mesh , checkpoint_manager
573+ data_iterator , config , mesh , checkpoint_manager , init_state_fn
561574 )
562575
563576 # create inference_state_mesh_shardings from inference_mesh
577+ if config_inference .pure_nnx :
578+ # NNX has a different function to init the training state.
579+ raise NotImplementedError ("Pure NNX support has not been implemented yet." )
580+ else :
581+ init_inference_state_fn = functools .partial (
582+ maxtext_utils .init_initial_state , inference_model , tx , config_inference , False , init_rng
583+ )
564584 inference_state_mesh_shardings = maxtext_utils .get_abstract_state (
565- inference_model , tx , config_inference , init_rng , inference_mesh , is_training = False
585+ config_inference , inference_mesh , init_inference_state_fn , is_training = False
566586 )[2 ]
567587 if not config .using_pipeline_parallelism :
568588 # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
0 commit comments