Skip to content

Commit cbce870

Browse files
committed
NNX: add TrainState, model creation utilities, and training loop support
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
1 parent d7fc0f1 commit cbce870

12 files changed

Lines changed: 1442 additions & 297 deletions

src/maxtext/common/checkpointing.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from absl import flags
2121
import datetime
2222
from etils import epath
23+
from flax import nnx
2324
from flax.training import train_state
2425
import jax
2526
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -532,7 +533,7 @@ def load_state_if_possible(
532533
load_parameters_from_path: str,
533534
load_full_state_from_path: str,
534535
checkpoint_storage_concurrent_gb: int,
535-
abstract_unboxed_pre_state: train_state.TrainState,
536+
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
536537
enable_single_replica_ckpt_restoring: bool | None = False,
537538
dataset_type: str | None = "tfds",
538539
step: int = -1, # -1 means latest
@@ -600,8 +601,13 @@ def map_to_pspec(data):
600601
)
601602
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
602603

603-
restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
604-
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
604+
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
605+
restore_target = abstract_unboxed_pre_state
606+
if isinstance(abstract_unboxed_pre_state, nnx.State):
607+
restore_target = abstract_unboxed_pre_state.to_pure_dict()
608+
609+
restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
610+
checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args)
605611

606612
match (checkpoint_manager, dataset_type, data_iterator):
607613
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
@@ -636,9 +642,14 @@ def map_to_pspec(data):
636642
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
637643

638644
if load_parameters_from_path != "":
645+
if isinstance(abstract_unboxed_pre_state, nnx.State):
646+
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
647+
else:
648+
params = abstract_unboxed_pre_state.params
649+
639650
restored_params = load_params_from_path(
640651
load_parameters_from_path,
641-
abstract_unboxed_pre_state.params,
652+
params,
642653
checkpoint_storage_concurrent_gb,
643654
use_ocdbt=use_ocdbt,
644655
use_zarr3=use_zarr3,
@@ -730,7 +741,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
730741
# Determine the effective step for saving a checkpoint.
731742
# If 'step' is not provided, this call is for a potential final checkpoint
732743
# and use the last completed step from the state.
733-
actual_step = (int(state.step) - 1) if step is None else int(step)
744+
if step is not None:
745+
actual_step = int(step)
746+
else:
747+
if config.pure_nnx:
748+
actual_step = int(state.optimizer.step) - 1
749+
else:
750+
# Linen TrainState has .step attribute
751+
actual_step = int(state.step) - 1
752+
753+
if config.pure_nnx:
754+
# Convert nnx.State to dict.
755+
state = state.to_pure_dict()
734756

735757
# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
736758
# This occurs if this function was called:

0 commit comments

Comments (0)