Skip to content

Commit 090b545

Browse files
committed
NNX: add TrainState, model creation utilities, and training loop support
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
1 parent 412902a commit 090b545

21 files changed

Lines changed: 2978 additions & 378 deletions

src/maxtext/common/checkpointing.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from absl import flags
2121
import datetime
2222
from etils import epath
23+
from flax import nnx
2324
from flax.training import train_state
2425
import jax
2526
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -536,7 +537,7 @@ def load_state_if_possible(
536537
load_parameters_from_path: str,
537538
load_full_state_from_path: str,
538539
checkpoint_storage_concurrent_gb: int,
539-
abstract_unboxed_pre_state: train_state.TrainState,
540+
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
540541
enable_single_replica_ckpt_restoring: bool | None = False,
541542
dataset_type: str | None = "tfds",
542543
step: int = -1, # -1 means latest
@@ -604,9 +605,14 @@ def map_to_pspec(data):
604605
)
605606
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
606607

607-
restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
608+
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
609+
restore_target = abstract_unboxed_pre_state
610+
if isinstance(abstract_unboxed_pre_state, nnx.State):
611+
restore_target = abstract_unboxed_pre_state.to_pure_dict()
612+
613+
restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
608614
checkpoint_args = ocp.args.PyTreeRestore(
609-
item=abstract_unboxed_pre_state,
615+
item=restore_target,
610616
restore_args=restore_args,
611617
partial_restore=True,
612618
)
@@ -647,9 +653,14 @@ def map_to_pspec(data):
647653
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
648654

649655
if load_parameters_from_path != "":
656+
if isinstance(abstract_unboxed_pre_state, nnx.State):
657+
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
658+
else:
659+
params = abstract_unboxed_pre_state.params
660+
650661
restored_params = load_params_from_path(
651662
load_parameters_from_path,
652-
abstract_unboxed_pre_state.params,
663+
params,
653664
checkpoint_storage_concurrent_gb,
654665
use_ocdbt=use_ocdbt,
655666
use_zarr3=use_zarr3,
@@ -741,7 +752,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
741752
# Determine the effective step for saving a checkpoint.
742753
# If 'step' is not provided, this call is for a potential final checkpoint
743754
# and use the last completed step from the state.
744-
actual_step = (int(state.step) - 1) if step is None else int(step)
755+
if step is not None:
756+
actual_step = int(step)
757+
else:
758+
if config.pure_nnx:
759+
actual_step = int(state.optimizer.step) - 1
760+
else:
761+
# Linen TrainState has .step attribute
762+
actual_step = int(state.step) - 1
763+
764+
if config.pure_nnx:
765+
# Convert nnx.State to dict.
766+
state = state.to_pure_dict()
745767

746768
# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
747769
# This occurs if this function was called:

src/maxtext/layers/nnx_decoders.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
MODEL_MODE_TRAIN,
3636
Config,
3737
DecoderBlockType,
38+
MultimodalInput,
3839
ShardMode,
3940
)
4041
from maxtext.inference import page_manager
@@ -904,7 +905,14 @@ def __call__(
904905
audio_embeddings: None | jnp.ndarray = None,
905906
audio_masks: None | jnp.ndarray = None,
906907
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
908+
multimodal_input: None | MultimodalInput = None,
907909
):
910+
if multimodal_input is not None:
911+
image_embeddings = multimodal_input.image_embeddings
912+
image_masks = multimodal_input.image_masks
913+
audio_embeddings = multimodal_input.audio_embeddings
914+
audio_masks = multimodal_input.audio_masks
915+
bidirectional_mask = multimodal_input.bidirectional_mask
908916
cfg = self.config
909917
assert decoder_input_tokens.ndim == 2 # [batch, len]
910918

0 commit comments

Comments
 (0)