@@ -590,8 +590,13 @@ def map_to_pspec(data):
590590 )
591591 ocp .type_handlers .register_type_handler (jax .Array , array_handler , override = True )
592592
593- restore_args = jax .tree_util .tree_map (map_to_pspec , abstract_unboxed_pre_state )
594- checkpoint_args = ocp .args .PyTreeRestore (item = abstract_unboxed_pre_state , restore_args = restore_args )
593+ # Convert nnx.State to a pure dict to match how checkpoints are saved for NNX.
594+ restore_target = abstract_unboxed_pre_state
595+ if isinstance (abstract_unboxed_pre_state , nnx .State ):
596+ restore_target = abstract_unboxed_pre_state .to_pure_dict ()
597+
598+ restore_args = jax .tree_util .tree_map (map_to_pspec , restore_target )
599+ checkpoint_args = ocp .args .PyTreeRestore (item = restore_target , restore_args = restore_args )
595600
596601 match (checkpoint_manager , dataset_type , data_iterator ):
597602 # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
@@ -717,15 +722,35 @@ def save_params_to_path(checkpoint_dir, params, use_ocdbt=True, use_zarr3=True):
717722 print (f"Quantized params checkpoint saved at: { checkpoint_dir } " )
718723
719724
720- def maybe_save_checkpoint (checkpoint_manager , state , config , data_iterator , step = None ):
721- """Save checkpoint if checkpointing is enabled."""
725+ def maybe_save_checkpoint (checkpoint_manager , state , config , data_iterator , step = None , force = False ):
726+ """Save checkpoint if checkpointing is enabled.
727+
728+ Args:
729+ checkpoint_manager: The checkpoint manager.
730+ state: The training state to save.
731+ config: The config object.
732+ data_iterator: The data iterator.
733+ step: The step number. If None, extracts from state (for Linen TrainState).
734+ force: If True, force save the checkpoint regardless of checkpoint_period.
735+ """
722736 if checkpoint_manager is None :
723737 return
724738
725739 # Determine the effective step for saving a checkpoint.
726740 # If 'step' is not provided, this call is for a potential final checkpoint
727741 # and use the last completed step from the state.
728- actual_step = (int (state .step ) - 1 ) if step is None else int (step )
742+ if step is not None :
743+ actual_step = int (step )
744+ else :
745+ if config .pure_nnx :
746+ actual_step = int (state .optimizer .step ) - 1
747+ else :
748+ # Linen TrainState has .step attribute
749+ actual_step = int (state .step ) - 1
750+
751+ if config .pure_nnx :
752+ # Convert nnx.State to a pure dict via to_pure_dict(), matching the NNX save format.
753+ state = state .to_pure_dict ()
729754
730755 # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
731756 # This occurs if this function was called:
0 commit comments