Skip to content

Commit a726aac

Browse files
committed
NNX: add TrainState, model creation utilities, and training loop support
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
1 parent 57f4a9a commit a726aac

12 files changed

Lines changed: 1401 additions & 260 deletions

src/maxtext/common/checkpointing.py

Lines changed: 27 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from absl import flags
2121
import datetime
2222
from etils import epath
23+
from flax import nnx
2324
from flax.training import train_state
2425
import jax
2526
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -521,7 +522,7 @@ def load_state_if_possible(
521522
load_parameters_from_path: str,
522523
load_full_state_from_path: str,
523524
checkpoint_storage_concurrent_gb: int,
524-
abstract_unboxed_pre_state: train_state.TrainState,
525+
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
525526
enable_single_replica_ckpt_restoring: bool | None = False,
526527
dataset_type: str | None = "tfds",
527528
step: int = -1, # -1 means latest
@@ -589,8 +590,13 @@ def map_to_pspec(data):
589590
)
590591
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
591592

592-
restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
593-
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
593+
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
594+
restore_target = abstract_unboxed_pre_state
595+
if isinstance(abstract_unboxed_pre_state, nnx.State):
596+
restore_target = abstract_unboxed_pre_state.to_pure_dict()
597+
598+
restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
599+
checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args)
594600

595601
match (checkpoint_manager, dataset_type, data_iterator):
596602
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
@@ -625,9 +631,14 @@ def map_to_pspec(data):
625631
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
626632

627633
if load_parameters_from_path != "":
634+
if isinstance(abstract_unboxed_pre_state, nnx.State):
635+
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
636+
else:
637+
params = abstract_unboxed_pre_state.params
638+
628639
restored_params = load_params_from_path(
629640
load_parameters_from_path,
630-
abstract_unboxed_pre_state.params,
641+
params,
631642
checkpoint_storage_concurrent_gb,
632643
use_ocdbt=use_ocdbt,
633644
use_zarr3=use_zarr3,
@@ -719,7 +730,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
719730
# Determine the effective step for saving a checkpoint.
720731
# If 'step' is not provided, this call is for a potential final checkpoint
721732
# and use the last completed step from the state.
722-
actual_step = (int(state.step) - 1) if step is None else int(step)
733+
if step is not None:
734+
actual_step = int(step)
735+
else:
736+
if config.pure_nnx:
737+
actual_step = int(state.optimizer.step) - 1
738+
else:
739+
# Linen TrainState has .step attribute
740+
actual_step = int(state.step) - 1
741+
742+
if config.pure_nnx:
743+
# Convert nnx.State to dict.
744+
state = state.to_pure_dict()
723745

724746
# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
725747
# This occurs if this function was called:

0 commit comments

Comments (0)