Skip to content

Commit bd6f280

Browse files
committed
NNX: fix checkpointing in the training loop
- Convert nnx.State to pure dict for checkpoint saving - Restore pure dict back to nnx.State after loading
1 parent a9d1875 commit bd6f280

2 files changed

Lines changed: 34 additions & 6 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -590,8 +590,13 @@ def map_to_pspec(data):
590590
)
591591
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
592592

593-
restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
594-
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
593+
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
594+
restore_target = abstract_unboxed_pre_state
595+
if isinstance(abstract_unboxed_pre_state, nnx.State):
596+
restore_target = abstract_unboxed_pre_state.to_pure_dict()
597+
598+
restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
599+
checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args)
595600

596601
match (checkpoint_manager, dataset_type, data_iterator):
597602
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
@@ -717,15 +722,35 @@ def save_params_to_path(checkpoint_dir, params, use_ocdbt=True, use_zarr3=True):
717722
print(f"Quantized params checkpoint saved at: {checkpoint_dir}")
718723

719724

720-
def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step=None):
721-
"""Save checkpoint if checkpointing is enabled."""
725+
def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step=None, force=False):
726+
"""Save checkpoint if checkpointing is enabled.
727+
728+
Args:
729+
checkpoint_manager: The checkpoint manager.
730+
state: The training state to save.
731+
config: The config object.
732+
data_iterator: The data iterator.
733+
step: The step number. If None, extracts from state (for Linen TrainState).
734+
force: If True, force save the checkpoint regardless of checkpoint_period.
735+
"""
722736
if checkpoint_manager is None:
723737
return
724738

725739
# Determine the effective step for saving a checkpoint.
726740
# If 'step' is not provided, this call is for a potential final checkpoint
727741
# and use the last completed step from the state.
728-
actual_step = (int(state.step) - 1) if step is None else int(step)
742+
if step is not None:
743+
actual_step = int(step)
744+
else:
745+
if config.pure_nnx:
746+
actual_step = int(state.optimizer.step) - 1
747+
else:
748+
# Linen TrainState has .step attribute
749+
actual_step = int(state.step) - 1
750+
751+
if config.pure_nnx:
752+
# Convert nnx.State to dict.
753+
state = state.to_pure_dict()
729754

730755
# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
731756
# This occurs if this function was called:

src/maxtext/utils/maxtext_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,10 @@ def setup_initial_state(
11681168
# The update of data_iterator state happens in place, no need to assign explicitly
11691169
state = restored["items"]
11701170

1171-
# TODO: For NNX, convert the pure dict to nnx.State.
1171+
# For NNX, convert the pure dict to nnx.State using the abstract state as template
1172+
if config.pure_nnx:
1173+
nnx.replace_by_pure_dict(unboxed_abstract_state, state)
1174+
state = unboxed_abstract_state
11721175
else:
11731176
init_state_partial = init_state_fn
11741177
init_state_partial.__name__ = "initialize_state"

0 commit comments

Comments
 (0)