Skip to content

Commit 7a172ce

Browse files
committed
NNX: fix checkpointing in the training loop
- Convert nnx.State to a pure dict for checkpoint saving
- Restore the pure dict back to nnx.State after loading
1 parent 0a1f4e9 commit 7a172ce

7 files changed

Lines changed: 1038 additions & 12 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -590,8 +590,13 @@ def map_to_pspec(data):
590590
)
591591
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
592592

593-
restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
594-
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
593+
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
594+
restore_target = abstract_unboxed_pre_state
595+
if isinstance(abstract_unboxed_pre_state, nnx.State):
596+
restore_target = abstract_unboxed_pre_state.to_pure_dict()
597+
598+
restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
599+
checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args)
595600

596601
match (checkpoint_manager, dataset_type, data_iterator):
597602
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
@@ -717,15 +722,35 @@ def save_params_to_path(checkpoint_dir, params, use_ocdbt=True, use_zarr3=True):
717722
print(f"Quantized params checkpoint saved at: {checkpoint_dir}")
718723

719724

720-
def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step=None):
721-
"""Save checkpoint if checkpointing is enabled."""
725+
def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step=None, force=False):
726+
"""Save checkpoint if checkpointing is enabled.
727+
728+
Args:
729+
checkpoint_manager: The checkpoint manager.
730+
state: The training state to save.
731+
config: The config object.
732+
data_iterator: The data iterator.
733+
step: The step number. If None, extracts from state (for Linen TrainState).
734+
force: If True, force save the checkpoint regardless of checkpoint_period.
735+
"""
722736
if checkpoint_manager is None:
723737
return
724738

725739
# Determine the effective step for saving a checkpoint.
726740
# If 'step' is not provided, this call is for a potential final checkpoint
727741
# and use the last completed step from the state.
728-
actual_step = (int(state.step) - 1) if step is None else int(step)
742+
if step is not None:
743+
actual_step = int(step)
744+
else:
745+
if config.pure_nnx:
746+
actual_step = int(state.optimizer.step) - 1
747+
else:
748+
# Linen TrainState has .step attribute
749+
actual_step = int(state.step) - 1
750+
751+
if config.pure_nnx:
752+
# Convert nnx.State to dict.
753+
state = state.to_pure_dict()
729754

730755
# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
731756
# This occurs if this function was called:

src/maxtext/utils/maxtext_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,10 @@ def setup_initial_state(
11681168
# The update of data_iterator state happens in place, no need to assign explicitly
11691169
state = restored["items"]
11701170

1171-
# TODO: For NNX, convert the pure dict to nnx.State.
1171+
# For NNX, convert the pure dict to nnx.State using the abstract state as template
1172+
if config.pure_nnx:
1173+
nnx.replace_by_pure_dict(unboxed_abstract_state, state)
1174+
state = unboxed_abstract_state
11721175
else:
11731176
init_state_partial = init_state_fn
11741177
init_state_partial.__name__ = "initialize_state"

tests/integration/checkpointing_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
9393
f"dataset_type={dataset_type}",
9494
"async_checkpointing=False",
9595
f"attention={attention_type}",
96+
"profiler=''",
9697
]
9798
+ model_params
9899
+ pathways_command
@@ -135,19 +136,19 @@ def run_checkpointing(hardware, attention_type):
135136
# Determine dataset path/pattern depending on decoupled mode.
136137
gcsfuse_pattern = "/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*"
137138
local_decoupled_root = os.path.join(
138-
MAXTEXT_PKG_DIR, "..", "tests", "assets", "local_datasets", "c4_en_dataset_minimal", "c4", "en", "3.0.1"
139+
MAXTEXT_PKG_DIR, "..", "..", "tests", "assets", "local_datasets", "c4_en_dataset_minimal", "c4", "en", "3.0.1"
139140
)
140141
local_pattern = os.path.join(local_decoupled_root, "c4-train.array_record*")
141142
selected_pattern = gcsfuse_pattern
142143
dataset_path = "/tmp/gcsfuse"
143144

144-
if is_decoupled():
145+
if not glob.glob(gcsfuse_pattern):
145146
# Prefer local minimal dataset if gcsfuse data absent
146-
if not glob.glob(gcsfuse_pattern) and glob.glob(local_pattern):
147+
if glob.glob(local_pattern):
147148
selected_pattern = local_pattern
148-
dataset_path = os.path.join(MAXTEXT_PKG_DIR, "..", "tests", "assets", "local_datasets")
149-
elif not glob.glob(gcsfuse_pattern) and not glob.glob(local_pattern):
150-
pytest.skip("No grain ArrayRecord shards found for checkpointing test in decoupled mode.")
149+
dataset_path = os.path.join(MAXTEXT_PKG_DIR, "..", "..", "tests", "assets", "local_datasets")
150+
else:
151+
pytest.skip("No grain ArrayRecord shards found for checkpointing test.")
151152

152153
grain_command = [
153154
"grain_worker_count=0",

tests/integration/generate_param_only_checkpoint_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta
5454
f"attention={attention_type}",
5555
"max_target_length=128",
5656
"per_device_batch_size=1",
57+
"profiler=''",
58+
"pure_nnx=False",
5759
] + model_config
5860

5961
pathways_command = []
@@ -72,6 +74,7 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta
7274
dataset_type="tfds",
7375
dataset_path=dataset_path,
7476
)
77+
+ ["pure_nnx=False"]
7578
)
7679
state_path = f"{base_output_directory}/runner_{run_date}/checkpoints/0/items"
7780

tests/integration/standalone_dl_ckpt_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def test_standalone_checkpointer(self):
8989
"async_checkpointing=False",
9090
"enable_goodput_recording=False",
9191
"skip_jax_distributed_system=True",
92+
"pure_nnx=False",
93+
"enable_nnx=False",
9294
)
9395
)
9496
# restore at 50 and checkpoint at 100
@@ -110,6 +112,8 @@ def test_standalone_checkpointer(self):
110112
"async_checkpointing=False",
111113
"enable_goodput_recording=False",
112114
"skip_jax_distributed_system=True",
115+
"pure_nnx=False",
116+
"enable_nnx=False",
113117
)
114118
)
115119

0 commit comments

Comments (0)