Skip to content

Commit 6b778c3

Browse files
committed
NNX: add TrainState, model creation utilities, and training loop support
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
1 parent 963a188 commit 6b778c3

13 files changed

Lines changed: 1494 additions & 309 deletions

src/maxtext/common/checkpointing.py

Lines changed: 27 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from absl import flags
2121
import datetime
2222
from etils import epath
23+
from flax import nnx
2324
from flax.training import train_state
2425
import jax
2526
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -536,7 +537,7 @@ def load_state_if_possible(
536537
load_parameters_from_path: str,
537538
load_full_state_from_path: str,
538539
checkpoint_storage_concurrent_gb: int,
539-
abstract_unboxed_pre_state: train_state.TrainState,
540+
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
540541
enable_single_replica_ckpt_restoring: bool | None = False,
541542
dataset_type: str | None = "tfds",
542543
step: int = -1, # -1 means latest
@@ -604,8 +605,13 @@ def map_to_pspec(data):
604605
)
605606
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
606607

607-
restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
608-
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
608+
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
609+
restore_target = abstract_unboxed_pre_state
610+
if isinstance(abstract_unboxed_pre_state, nnx.State):
611+
restore_target = abstract_unboxed_pre_state.to_pure_dict()
612+
613+
restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
614+
checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args)
609615

610616
match (checkpoint_manager, dataset_type, data_iterator):
611617
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
@@ -640,9 +646,14 @@ def map_to_pspec(data):
640646
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
641647

642648
if load_parameters_from_path != "":
649+
if isinstance(abstract_unboxed_pre_state, nnx.State):
650+
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
651+
else:
652+
params = abstract_unboxed_pre_state.params
653+
643654
restored_params = load_params_from_path(
644655
load_parameters_from_path,
645-
abstract_unboxed_pre_state.params,
656+
params,
646657
checkpoint_storage_concurrent_gb,
647658
use_ocdbt=use_ocdbt,
648659
use_zarr3=use_zarr3,
@@ -734,7 +745,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
734745
# Determine the effective step for saving a checkpoint.
735746
# If 'step' is not provided, this call is for a potential final checkpoint
736747
# and use the last completed step from the state.
737-
actual_step = (int(state.step) - 1) if step is None else int(step)
748+
if step is not None:
749+
actual_step = int(step)
750+
else:
751+
if config.pure_nnx:
752+
actual_step = int(state.optimizer.step) - 1
753+
else:
754+
# Linen TrainState has .step attribute
755+
actual_step = int(state.step) - 1
756+
757+
if config.pure_nnx:
758+
# Convert nnx.State to dict.
759+
state = state.to_pure_dict()
738760

739761
# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
740762
# This occurs if this function was called:

src/maxtext/layers/nnx_decoders.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
MODEL_MODE_TRAIN,
3636
Config,
3737
DecoderBlockType,
38+
MultimodalInput,
3839
ShardMode,
3940
)
4041
from maxtext.inference import page_manager
@@ -904,7 +905,14 @@ def __call__(
904905
audio_embeddings: None | jnp.ndarray = None,
905906
audio_masks: None | jnp.ndarray = None,
906907
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
908+
multimodal_input: None | MultimodalInput = None,
907909
):
910+
if multimodal_input is not None:
911+
image_embeddings = multimodal_input.image_embeddings
912+
image_masks = multimodal_input.image_masks
913+
audio_embeddings = multimodal_input.audio_embeddings
914+
audio_masks = multimodal_input.audio_masks
915+
bidirectional_mask = multimodal_input.bidirectional_mask
908916
cfg = self.config
909917
assert decoder_input_tokens.ndim == 2 # [batch, len]
910918

0 commit comments

Comments (0)