Skip to content

Commit 754df44

Browse files
committed
NNX: add TrainState, model creation utilities, and training loop support
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
1 parent f5f6760 commit 754df44

10 files changed

Lines changed: 1208 additions & 241 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from absl import flags
2121
import datetime
2222
from etils import epath
23+
from flax import nnx
2324
from flax.training import train_state
2425
import jax
2526
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -521,7 +522,7 @@ def load_state_if_possible(
521522
load_parameters_from_path: str,
522523
load_full_state_from_path: str,
523524
checkpoint_storage_concurrent_gb: int,
524-
abstract_unboxed_pre_state: train_state.TrainState,
525+
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
525526
enable_single_replica_ckpt_restoring: bool | None = False,
526527
dataset_type: str | None = "tfds",
527528
step: int = -1, # -1 means latest
@@ -625,9 +626,14 @@ def map_to_pspec(data):
625626
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
626627

627628
if load_parameters_from_path != "":
629+
if isinstance(abstract_unboxed_pre_state, nnx.State):
630+
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
631+
else:
632+
params = abstract_unboxed_pre_state.params
633+
628634
restored_params = load_params_from_path(
629635
load_parameters_from_path,
630-
abstract_unboxed_pre_state.params,
636+
params,
631637
checkpoint_storage_concurrent_gb,
632638
use_ocdbt=use_ocdbt,
633639
use_zarr3=use_zarr3,

src/maxtext/trainers/pre_train/train.py

Lines changed: 248 additions & 154 deletions
Large diffs are not rendered by default.

src/maxtext/utils/maxtext_utils.py

Lines changed: 178 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,20 @@
2020
import os
2121
from typing import Sequence
2222

23-
from flax import linen as nn
23+
from flax import nnx, linen as nn
24+
from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules
2425
from flax.linen import partitioning as nn_partitioning
25-
from flax.training import train_state
26+
from flax.training.train_state import TrainState
2627

2728
import numpy as np
2829

29-
from jax.experimental import mesh_utils
30-
from jax.experimental.serialize_executable import deserialize_and_load
31-
from jax.sharding import AxisType, Mesh
32-
3330
import jax
3431
import jax.numpy as jnp
32+
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec
33+
from jax.experimental import mesh_utils
34+
from jax.experimental.serialize_executable import deserialize_and_load
3535

3636
import optax
37-
3837
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
3938
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
4039

@@ -48,6 +47,7 @@
4847
from maxtext.utils import max_logging
4948
from maxtext.utils import max_utils
5049
from maxtext.utils import sharding
50+
from maxtext.utils import maxtext_utils_nnx
5151

5252
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
5353

@@ -994,15 +994,15 @@ def _apply_update(path, param):
994994
return state.replace(params=new_params)
995995

996996

997-
def init_decode_state(apply_fn, params) -> train_state.TrainState:
997+
def init_decode_state(apply_fn, params) -> TrainState:
998998
"""Init train state with null opt state for decode."""
999-
state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore
999+
state = TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore
10001000
return state
10011001

10021002

10031003
def init_training_state(apply_fn, params, tx):
10041004
"""Init train state with null opt state for decode."""
1005-
state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
1005+
state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
10061006
return state
10071007

10081008

@@ -1124,7 +1124,7 @@ def setup_initial_state(
11241124
is_training: True to initialize training state, False for decode state
11251125
11261126
Returns:
1127-
state: the initialized train state
1127+
train_state: the initialized train state. For NNX, this is a TrainStateNNX instance
11281128
state_mesh_annotations: the mesh annotations for the train state
11291129
"""
11301130

@@ -1163,19 +1163,32 @@ def setup_initial_state(
11631163
else:
11641164
# The update of data_iterator state happens in place, no need to assign explicitly
11651165
state = restored["items"]
1166+
1167+
# TODO: For NNX, convert the pure dict to nnx.State.
11661168
else:
11671169
init_state_partial = init_state_fn
11681170
init_state_partial.__name__ = "initialize_state"
1169-
# pylint: disable=not-callable
1170-
state = jax.jit(
1171-
init_state_partial,
1172-
in_shardings=None,
1173-
out_shardings=state_mesh_shardings,
1174-
)()
1171+
if config.pure_nnx:
1172+
state = jax.jit(
1173+
lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure
1174+
in_shardings=None,
1175+
out_shardings=state_mesh_shardings,
1176+
)()
1177+
else:
1178+
# pylint: disable=not-callable
1179+
state = jax.jit(
1180+
init_state_partial,
1181+
in_shardings=None,
1182+
out_shardings=state_mesh_shardings,
1183+
)()
11751184
if raw_params: # If we loaded a partial state, we need to merge it.
1176-
state = state.replace(params=raw_params)
1177-
1178-
state = max_utils.unbox_logicallypartioned(state)
1185+
if config.pure_nnx:
1186+
# raw_params should have the same sharding info as in the model
1187+
nnx.update(state.model, raw_params)
1188+
else:
1189+
state = state.replace(params=raw_params)
1190+
if not config.pure_nnx:
1191+
state = max_utils.unbox_logicallypartioned(state)
11791192

11801193
return state, state_mesh_annotations, state_mesh_shardings, data_iterator
11811194

@@ -1191,6 +1204,9 @@ def get_logical_annotations(config, mesh, init_state_fn):
11911204

11921205
def get_abstract_state(config, mesh, init_state_fn, is_training=True):
11931206
"""Get a shaped abstraction of the state (including optimizer)"""
1207+
if config.pure_nnx:
1208+
return get_abstract_state_nnx(config, mesh, init_state_fn, is_training)
1209+
11941210
init_state_partial = init_state_fn
11951211

11961212
with nn_partitioning.axis_rules(config.logical_axis_rules):
@@ -1234,6 +1250,148 @@ def move(path, x):
12341250
)
12351251

12361252

1253+
def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State:
1254+
"""Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis.
1255+
1256+
Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also
1257+
inserts the partition_name axis at the correct scan_axis position for parameters created by
1258+
_create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a
1259+
3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension.
1260+
1261+
Args:
1262+
abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)).
1263+
mesh: JAX physical mesh.
1264+
1265+
Returns:
1266+
Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding.
1267+
"""
1268+
1269+
def _make_named_sharding(v):
1270+
val = v.get_value()
1271+
if not hasattr(val, "shape"):
1272+
# Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve
1273+
# as-is so the treedef matches abs_var_state in the downstream jax.tree.map.
1274+
return v
1275+
metadata = v.get_metadata()
1276+
out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding")
1277+
if not out_sharding:
1278+
pspec = PartitionSpec()
1279+
else:
1280+
# Insert the scan axis for parameters created by _create_scanned_layers.
1281+
# _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the
1282+
# axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these.
1283+
if nnx.PARTITION_NAME in metadata:
1284+
partition_name = metadata[nnx.PARTITION_NAME]
1285+
# Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits
1286+
# param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode
1287+
# scan_axis=0 for non-Param types. stacked_rest non-Param variables have
1288+
# param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct.
1289+
scan_axis = metadata.get("param_scan_axis", 0)
1290+
out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
1291+
# Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames
1292+
# 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted
1293+
# the scan axis. Only insert if not already present.
1294+
if partition_name not in out_sharding:
1295+
out_sharding.insert(scan_axis, partition_name)
1296+
out_sharding = tuple(out_sharding)
1297+
# Convert logical axis names to physical mesh axes using current context rules.
1298+
context_rules = get_logical_axis_rules()
1299+
local_rules = metadata.get("sharding_rules", ())
1300+
if context_rules or local_rules:
1301+
rules = composite_rules(context_rules, local_rules)
1302+
pspec = PartitionSpec(*from_sharding_rules(out_sharding, rules))
1303+
else:
1304+
pspec = PartitionSpec(*out_sharding)
1305+
return v.replace(NamedSharding(mesh, pspec))
1306+
1307+
return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable))
1308+
1309+
1310+
def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True):
1311+
"""Calculates the abstract sharded state and memory placement for an NNX TrainState.
1312+
1313+
This function performs an abstract trace of the NNX model and optimizer using
1314+
`nnx.get_abstract_model`. It resolves logical sharding annotations into physical
1315+
JAX shardings and applies memory placement optimizations such as optimizer
1316+
sharding and host memory offloading (pinning to CPU RAM).
1317+
1318+
Args:
1319+
config: Configuration object containing sharding and offloading hyperparameters
1320+
(e.g., shard_optimizer_over_data, optimizer_memory_host_offload).
1321+
mesh: JAX physical mesh used to resolve logical axis names to physical devices.
1322+
nnx_init_trainstate_fn: A zero-argument factory function that produces a
1323+
TrainStateNNX instance during the abstract trace.
1324+
is_training: Boolean indicating if the state is for training. If True,
1325+
optimizer state is processed and memory offloading strategies are applied.
1326+
1327+
Returns:
1328+
A tuple containing (abstract_sharded_state, None, state_mesh_shardings):
1329+
abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with
1330+
fully resolved physical sharding and memory_kind metadata.
1331+
state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec
1332+
objects corresponding to each parameter/variable.
1333+
state_mesh_shardings: An nnx.State tree consisting of the raw JAX
1334+
Sharding objects corresponding to each parameter/variable.
1335+
"""
1336+
assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given."
1337+
1338+
with nn_partitioning.axis_rules(config.logical_axis_rules):
1339+
# Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply
1340+
# get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers
1341+
# axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally
1342+
# which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers,
1343+
# causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor.
1344+
# Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls
1345+
# var.shape for every variable when a global mesh is active, but masked optimizer
1346+
# state variables (e.g. from trainable_parameters_mask) have value=MaskedNode()
1347+
# which has no .shape and would raise AttributeError. We handle sharding
1348+
# ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not
1349+
# needed here.
1350+
abs_model = nnx.eval_shape(nnx_init_trainstate_fn)
1351+
_, abs_var_state = nnx.split(abs_model)
1352+
named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh)
1353+
abstract_state = jax.tree.map(
1354+
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
1355+
abs_var_state,
1356+
named_sharding_state,
1357+
)
1358+
1359+
state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)
1360+
1361+
if is_training and config.shard_optimizer_over_data:
1362+
# Add data to sharding for optimizer state
1363+
optimizer_sharding = jax.tree_util.tree_map_with_path(
1364+
functools.partial(sharding.add_data_to_sharding, mesh),
1365+
abstract_state.optimizer,
1366+
state_mesh_shardings.optimizer,
1367+
)
1368+
state_mesh_shardings.optimizer = optimizer_sharding
1369+
if is_training and config.optimizer_memory_host_offload:
1370+
optimizer_sharding = jax.tree_util.tree_map_with_path(
1371+
maxtext_utils_nnx.move_memory_to_host,
1372+
state_mesh_shardings.optimizer,
1373+
is_leaf=lambda x: isinstance(x, NamedSharding),
1374+
)
1375+
state_mesh_shardings.optimizer = optimizer_sharding
1376+
if is_training and config.parameter_memory_host_offload:
1377+
assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading."
1378+
_, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...)
1379+
state_params = jax.tree_util.tree_map_with_path(
1380+
maxtext_utils_nnx.move_memory_to_host,
1381+
state_params,
1382+
is_leaf=lambda x: isinstance(x, NamedSharding),
1383+
)
1384+
nnx.update(state_mesh_shardings, state_params)
1385+
1386+
abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings)
1387+
state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings)
1388+
return (
1389+
abstract_sharded_state,
1390+
state_mesh_annotations,
1391+
state_mesh_shardings,
1392+
)
1393+
1394+
12371395
def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None):
12381396
"""Get a shaped abstraction of the state (including optimizer)"""
12391397

0 commit comments

Comments
 (0)