2020import os
2121from typing import Sequence
2222
23- from flax import linen as nn
23+ from flax import nnx , linen as nn
24+ from flax .core .spmd import composite_rules , from_sharding_rules , get_logical_axis_rules
2425from flax .linen import partitioning as nn_partitioning
25- from flax .training import train_state
26+ from flax .training . train_state import TrainState
2627
2728import numpy as np
2829
29- from jax .experimental import mesh_utils
30- from jax .experimental .serialize_executable import deserialize_and_load
31- from jax .sharding import AxisType , Mesh
32-
3330import jax
3431import jax .numpy as jnp
32+ from jax .sharding import AxisType , Mesh , NamedSharding , PartitionSpec
33+ from jax .experimental import mesh_utils
34+ from jax .experimental .serialize_executable import deserialize_and_load
3535
3636import optax
37-
3837import orbax .checkpoint .experimental .emergency .checkpoint_manager as emergency_checkpoint_manager
3938import orbax .checkpoint .experimental .emergency .replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
4039
4847from maxtext .utils import max_logging
4948from maxtext .utils import max_utils
5049from maxtext .utils import sharding
50+ from maxtext .utils import maxtext_utils_nnx
5151
# Key/name suffix used to mark parameters that are overwritten with their
# gradient rather than updated by the optimizer.
# NOTE(review): the consuming code is not visible in this chunk — confirm
# against the rest of the module before relying on this description.
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
5353
@@ -994,15 +994,15 @@ def _apply_update(path, param):
994994 return state .replace (params = new_params )
995995
996996
def init_decode_state(apply_fn, params) -> TrainState:
  """Build a TrainState suitable for decoding only.

  The optimizer is deliberately absent (``tx=None`` and an empty
  ``opt_state``) because decoding never applies gradient updates.

  Args:
    apply_fn: The model's apply function.
    params: Model parameters to serve.

  Returns:
    A TrainState at step 0 with no optimizer attached.
  """
  # TrainState's type hints expect a real optax transform for tx, hence the ignore.
  return TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={})  # type: ignore
10011001
10021002
def init_training_state(apply_fn, params, tx):
  """Init a full train state (with optimizer state) for training.

  Note: the previous docstring ("null opt state for decode") was copy-pasted
  from init_decode_state and was wrong — this function initializes a real
  optimizer state from ``tx``.

  Args:
    apply_fn: The model's apply function.
    params: Initial model parameters.
    tx: An optax gradient transformation; ``TrainState.create`` calls its
      ``init(params)`` to populate ``opt_state``.

  Returns:
    A TrainState at step 0 with optimizer state initialized from params.
  """
  state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
  return state
10071007
10081008
@@ -1124,7 +1124,7 @@ def setup_initial_state(
11241124 is_training: True to initialize training state, False for decode state
11251125
11261126 Returns:
1127- state : the initialized train state
1127+ train_state : the initialized train state. For NNX, this is a TrainStateNNX instance
11281128 state_mesh_annotations: the mesh annotations for the train state
11291129 """
11301130
@@ -1163,19 +1163,32 @@ def setup_initial_state(
11631163 else :
11641164 # The update of data_iterator state happens in place, no need to assign explicitly
11651165 state = restored ["items" ]
1166+
1167+ # TODO: For NNX, convert the pure dict to nnx.State.
11661168 else :
11671169 init_state_partial = init_state_fn
11681170 init_state_partial .__name__ = "initialize_state"
1169- # pylint: disable=not-callable
1170- state = jax .jit (
1171- init_state_partial ,
1172- in_shardings = None ,
1173- out_shardings = state_mesh_shardings ,
1174- )()
1171+ if config .pure_nnx :
1172+ state = jax .jit (
1173+ lambda : nnx .state (init_state_partial ()), # Get state only, mapping to out_sharding structure
1174+ in_shardings = None ,
1175+ out_shardings = state_mesh_shardings ,
1176+ )()
1177+ else :
1178+ # pylint: disable=not-callable
1179+ state = jax .jit (
1180+ init_state_partial ,
1181+ in_shardings = None ,
1182+ out_shardings = state_mesh_shardings ,
1183+ )()
11751184 if raw_params : # If we loaded a partial state, we need to merge it.
1176- state = state .replace (params = raw_params )
1177-
1178- state = max_utils .unbox_logicallypartioned (state )
1185+ if config .pure_nnx :
1186+ # raw_params should have the same sharding info as in the model
1187+ nnx .update (state .model , raw_params )
1188+ else :
1189+ state = state .replace (params = raw_params )
1190+ if not config .pure_nnx :
1191+ state = max_utils .unbox_logicallypartioned (state )
11791192
11801193 return state , state_mesh_annotations , state_mesh_shardings , data_iterator
11811194
@@ -1191,6 +1204,9 @@ def get_logical_annotations(config, mesh, init_state_fn):
11911204
11921205def get_abstract_state (config , mesh , init_state_fn , is_training = True ):
11931206 """Get a shaped abstraction of the state (including optimizer)"""
1207+ if config .pure_nnx :
1208+ return get_abstract_state_nnx (config , mesh , init_state_fn , is_training )
1209+
11941210 init_state_partial = init_state_fn
11951211
11961212 with nn_partitioning .axis_rules (config .logical_axis_rules ):
@@ -1234,6 +1250,148 @@ def move(path, x):
12341250 )
12351251
12361252
def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State:
  """Compute a NamedSharding per NNX variable, honoring the scan (stacked layers) axis.

  flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model) ignores the
  nnx.PARTITION_NAME / "param_scan_axis" metadata written by _create_scanned_layers,
  so scanned parameters would receive a 2D partition spec applied to a 3D tensor,
  sharding the stacked-layers dimension instead of the embedding dimension. This
  helper inserts the partition_name axis at the recorded scan-axis position.

  Args:
    abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)).
    mesh: JAX physical mesh.

  Returns:
    A tree with the same structure as ``abs_var_state`` in which every Variable's
    value is replaced by its NamedSharding.
  """

  def _to_named_sharding(var):
    value = var.get_value()
    if not hasattr(value, "shape"):
      # Non-tensor leaf (e.g. an optax MaskedNode for a frozen parameter).
      # Keep it untouched so the treedef still matches abs_var_state in the
      # downstream jax.tree.map.
      return var

    meta = var.get_metadata()
    axes = meta.get("out_sharding") or meta.get("sharding_names") or meta.get("sharding")
    if not axes:
      # No sharding annotation at all: fully replicated.
      return var.replace(NamedSharding(mesh, PartitionSpec()))

    # Parameters produced by _create_scanned_layers carry the stacked-layers
    # axis name under nnx.PARTITION_NAME and its position under
    # "param_scan_axis" (OptVariable inherits the model Param's value via
    # to_opt_state(); non-Param stacked variables get 0 explicitly).
    if nnx.PARTITION_NAME in meta:
      stacked_axis_name = meta[nnx.PARTITION_NAME]
      axes = [axes] if isinstance(axes, str) else list(axes)
      # Guard against double insertion: newer Flax metadata remapping may have
      # already placed the scan axis into the sharding annotation.
      if stacked_axis_name not in axes:
        axes.insert(meta.get("param_scan_axis", 0), stacked_axis_name)
      axes = tuple(axes)

    # Resolve logical axis names to physical mesh axes with the active rules.
    context_rules = get_logical_axis_rules()
    local_rules = meta.get("sharding_rules", ())
    if context_rules or local_rules:
      merged = composite_rules(context_rules, local_rules)
      spec = PartitionSpec(*from_sharding_rules(axes, merged))
    else:
      spec = PartitionSpec(*axes)
    return var.replace(NamedSharding(mesh, spec))

  return jax.tree.map(_to_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable))
1308+
1309+
def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True):
  """Calculates the abstract sharded state and memory placement for an NNX TrainState.

  This function performs an abstract trace of the NNX model and optimizer,
  resolves logical sharding annotations into physical JAX shardings, and applies
  memory placement optimizations such as optimizer sharding over data and host
  memory offloading (pinning to CPU RAM).

  Args:
    config: Configuration object containing sharding and offloading hyperparameters
      (e.g., shard_optimizer_over_data, optimizer_memory_host_offload).
    mesh: JAX physical mesh used to resolve logical axis names to physical devices.
    nnx_init_trainstate_fn: A zero-argument factory function that produces a
      TrainStateNNX instance during the abstract trace.
    is_training: Boolean indicating if the state is for training. If True,
      optimizer state is processed and memory offloading strategies are applied.

  Returns:
    A tuple (abstract_sharded_state, state_mesh_annotations, state_mesh_shardings):
      abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with
        fully resolved physical sharding and memory_kind metadata.
      state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec
        objects corresponding to each parameter/variable.
      state_mesh_shardings: An nnx.State tree consisting of the raw JAX
        Sharding objects corresponding to each parameter/variable.
  """
  assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given."

  with nn_partitioning.axis_rules(config.logical_axis_rules):
    # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply
    # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers
    # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally
    # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers,
    # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor.
    # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls
    # var.shape for every variable when a global mesh is active, but masked optimizer
    # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode()
    # which has no .shape and would raise AttributeError. We handle sharding
    # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not
    # needed here.
    abs_model = nnx.eval_shape(nnx_init_trainstate_fn)
    _, abs_var_state = nnx.split(abs_model)
    named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh)
    abstract_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
        abs_var_state,
        named_sharding_state,
    )

  state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)

  if is_training and config.shard_optimizer_over_data:
    # Add the data axis to the sharding of the optimizer state.
    optimizer_sharding = jax.tree_util.tree_map_with_path(
        functools.partial(sharding.add_data_to_sharding, mesh),
        abstract_state.optimizer,
        state_mesh_shardings.optimizer,
    )
    state_mesh_shardings.optimizer = optimizer_sharding
  if is_training and config.optimizer_memory_host_offload:
    # Pin optimizer state to host (CPU RAM) memory.
    optimizer_sharding = jax.tree_util.tree_map_with_path(
        maxtext_utils_nnx.move_memory_to_host,
        state_mesh_shardings.optimizer,
        is_leaf=lambda x: isinstance(x, NamedSharding),
    )
    state_mesh_shardings.optimizer = optimizer_sharding
  if is_training and config.parameter_memory_host_offload:
    assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading."
    # Select only the Param variables and move their memory kind to host.
    _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...)
    state_params = jax.tree_util.tree_map_with_path(
        maxtext_utils_nnx.move_memory_to_host,
        state_params,
        is_leaf=lambda x: isinstance(x, NamedSharding),
    )
    nnx.update(state_mesh_shardings, state_params)

  abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings)
  state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings)
  return (
      abstract_sharded_state,
      state_mesh_annotations,
      state_mesh_shardings,
  )
1393+
1394+
12371395def get_prefill_kv_cache_annotations (model , config , rng , mesh , page_state : None | PageState = None ):
12381396 """Get a shaped abstraction of the state (including optimizer)"""
12391397
0 commit comments