Skip to content

Commit bac289f

Browse files
committed
Set NNX flags to true by default
1 parent f27e4f9 commit bac289f

71 files changed

Lines changed: 48619 additions & 13525 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

src/maxtext/configs/base.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,9 +1125,9 @@ position_id_per_seconds: 25
11251125
subslice_shape: ""
11261126

11271127
# NNX
1128-
enable_nnx: False
1129-
pure_nnx_decoder: False
1130-
pure_nnx: False
1128+
enable_nnx: True
1129+
pure_nnx_decoder: True
1130+
pure_nnx: True
11311131

11321132
################################## Qwen3-Next Specific Configs ##################################
11331133
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,32 @@ logical_axis_rules: [
7272
['exp_with_fsdp', 'fsdp'],
7373
['paged_kv_heads', ['tensor']],
7474
['engram_dim', ['tensor']],
75+
# Axes unsharded: sequence/context/tensor_transpose/autoregressive do not exist in this mesh
76+
['activation_attn_length_no_exp', []],
77+
['activation_length_no_exp', []],
78+
['activation_norm_length', []],
79+
['activation_q_length_no_exp', []],
80+
['prefill_activation_length', []],
81+
['prefill_activation_norm_length', []],
82+
['activation_kv_length', []],
83+
['decode_length', []],
84+
['embed_tensor_transpose', []],
85+
['q_lora_up_proj', []],
86+
['kv_lora_up_proj', []],
87+
['kv', []],
88+
['qkv', []],
89+
['kv_head_dim', []],
90+
['cache_batch_prefill', []],
91+
['cache_batch', []],
92+
['cache_heads_none', []],
93+
['cache_kv', []],
94+
['cache_sequence', []],
95+
['num_pages', []],
96+
['tokens_per_page', []],
97+
['paged_kv_head_dim_size', []],
98+
['dense_layers', []],
99+
['moe_layers', []],
100+
['num_activations', []],
101+
['mhc', []],
102+
['diloco', []],
75103
]

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,57 @@ logical_axis_rules: [
3434
['q_lora', ['fsdp']],
3535
['kv_lora', ['fsdp']],
3636
['exp_with_fsdp', 'fsdp'],
37+
# All other axes are unsharded (tensor/sequence/expert axes do not exist in pure-fsdp)
38+
['activation_heads', []],
39+
['activation_kv_heads', []],
40+
['activation_length', []],
41+
['activation_attn_length', []],
42+
['activation_attn_length_no_exp', []],
43+
['activation_length_no_exp', []],
44+
['activation_norm_length', []],
45+
['activation_q_length', []],
46+
['activation_q_length_no_exp', []],
47+
['prefill_activation_length', []],
48+
['prefill_activation_norm_length', []],
49+
['activation_kv_length', []],
50+
['activation_attn_embed', []],
51+
['activation_embed', []],
52+
['activation_mlp', []],
53+
['activation_kv', []],
54+
['activation_kv_head_dim', []],
55+
['activation_vocab', []],
56+
['activation_stage', []],
57+
['activation_exp', []],
58+
['decode_length', []],
59+
['mlp', []],
60+
['mlp_no_fsdp', []],
61+
['vocab', []],
62+
['heads', []],
63+
['q_heads', []],
64+
['kv_heads', []],
65+
['embed_tensor_transpose', []],
66+
['q_lora_up_proj', []],
67+
['kv_lora_up_proj', []],
68+
['norm', []],
69+
['layers', []],
70+
['qkv', []],
71+
['kv', []],
72+
['kv_head_dim', []],
73+
['cache_batch_prefill', []],
74+
['cache_batch', []],
75+
['cache_heads_none', []],
76+
['cache_heads', []],
77+
['cache_kv', []],
78+
['cache_sequence', []],
79+
['exp', []],
80+
['paged_kv_heads', []],
81+
['num_pages', []],
82+
['tokens_per_page', []],
83+
['paged_kv_head_dim_size', []],
84+
['dense_layers', []],
85+
['moe_layers', []],
86+
['num_activations', []],
87+
['engram_dim', []],
88+
['mhc', []],
89+
['diloco', []],
3790
]

src/maxtext/layers/nnx_decoders.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def __init__(
311311

312312
num_moe = config.num_decoder_layers - config.first_num_dense_layers
313313

314-
self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
314+
self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
315315
elif self.is_gemma3:
316316
attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
317317
scan_length = config.num_decoder_layers // attention_pattern_length
@@ -346,7 +346,7 @@ def __init__(
346346
for i in range(config.first_num_dense_layers):
347347
self._create_and_register_layer(dense_cls, rngs, "dense_layer", i)
348348
for i in range(config.num_decoder_layers - config.first_num_dense_layers):
349-
self._create_and_register_layer(moe_cls, rngs, "moe_layer", i)
349+
self._create_and_register_layer(moe_cls, rngs, "moe_layers", i)
350350
else:
351351
layer_cls = decoder_block_classes[0]
352352

@@ -388,6 +388,8 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
388388

389389
def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs):
390390
"""Creates a VMapped stack of layers, forcing parameter init for Compact modules."""
391+
if length == 0:
392+
return nnx.List([])
391393

392394
def create_layer_fn(rng):
393395
layer = decoder_layer_class(
@@ -433,6 +435,8 @@ def pure_layer_fn(state_in, y_in):
433435

434436
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
435437
"""Runs the layer stack using nnx.scan."""
438+
if length == 0:
439+
return x_in, layers
436440
policy = self.get_remat_policy()
437441
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
438442
graphdef, params, state = nnx.split(
@@ -961,7 +965,7 @@ def __call__(
961965

962966
y = self._apply_interleaved_scanned_layers(
963967
y,
964-
self.moe_layer,
968+
self.moe_layers,
965969
0,
966970
(cfg.num_decoder_layers - cfg.first_num_dense_layers),
967971
[e - cfg.first_num_dense_layers for e in cfg.engram_layers],
@@ -978,7 +982,7 @@ def __call__(
978982
if cfg.use_batch_split_schedule:
979983
policy = self.get_remat_policy()
980984

981-
mock_params = self._build_linen_params(self.moe_layer)
985+
mock_params = self._build_linen_params(self.moe_layers)
982986

983987
y = deepseek_batchsplit.scan_batch_split_layers(
984988
y,
@@ -992,8 +996,8 @@ def __call__(
992996
policy=policy,
993997
)
994998
else:
995-
y, self.moe_layer = self._apply_layers_sequentially(
996-
self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs
999+
y, self.moe_layers = self._apply_layers_sequentially(
1000+
self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs
9971001
)
9981002
elif self.is_gemma3:
9991003
y = self._apply_gemma3_scanned_blocks(

src/maxtext/models/gpt_oss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.common.common_types import AttentionType, Config
2929
from maxtext.layers import attentions
3030
from maxtext.layers import initializers
31+
from maxtext.layers import linears
3132
from maxtext.layers import moe
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
@@ -130,6 +131,8 @@ def __init__(
130131
rngs=rngs,
131132
)
132133

134+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
135+
133136
def __call__(
134137
self,
135138
inputs,
@@ -181,7 +184,7 @@ def __call__(
181184
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
182185

183186
layer_output = mlp_lnx + intermediate_inputs
184-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
187+
layer_output = self.dropout(layer_output, deterministic=deterministic)
185188

186189
layer_output = nn.with_logical_constraint(
187190
layer_output,

src/maxtext/trainers/diloco/diloco.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import Any, Callable
2727

2828
import drjax
29+
from flax import nnx
2930
from flax import struct
3031
from flax.training import train_state
3132
import jax
@@ -153,15 +154,23 @@ def add_diloco_dim(x):
153154
momentum=config.diloco_outer_momentum,
154155
nesterov=True,
155156
)
156-
outer_opt_state = jax.eval_shape(outer_optimizer.init, abstract_state.params)
157+
# For NNX, model params (Param variables only) live under abstract_state.model;
158+
# for Linen under abstract_state.params.
159+
if config.pure_nnx:
160+
model_params = abstract_state.model.filter(nnx.Param)
161+
model_params_sharding = state_mesh_shardings.model.filter(nnx.Param)
162+
else:
163+
model_params = abstract_state.params
164+
model_params_sharding = state_mesh_shardings.params
165+
outer_opt_state = jax.eval_shape(outer_optimizer.init, model_params)
157166

158167
# Create abstract step
159168
abstract_step = jax.ShapeDtypeStruct((), jnp.int32)
160169

161170
# Build abstract DiLoCo state
162171
diloco_state = DiLoCoTrainState(
163172
inner_state=inner_state,
164-
params=abstract_state.params,
173+
params=model_params,
165174
outer_opt_state=outer_opt_state,
166175
step=abstract_step,
167176
)
@@ -171,12 +180,12 @@ def add_diloco_dim(x):
171180
# Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState())
172181
# We shard the momentum trace the same way as the parameters.
173182
outer_opt_state_sharding = (
174-
optax.TraceState(trace=state_mesh_shardings.params),
183+
optax.TraceState(trace=model_params_sharding),
175184
optax.EmptyState(),
176185
)
177186
diloco_state_shardings = DiLoCoTrainState(
178187
inner_state=inner_state_shardings,
179-
params=state_mesh_shardings.params,
188+
params=model_params_sharding,
180189
outer_opt_state=outer_opt_state_sharding,
181190
step=None,
182191
)
@@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
205214
# mesh automatically when jax.set_mesh is used.
206215
inner_state = drjax.broadcast(state, mesh=mesh)
207216
# Outer state retains a single copy of the model parameters and optimizer state.
208-
outer_params = state.params
217+
# For NNX, model params (Param variables only) live under state.model;
218+
# for Linen under state.params.
219+
outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params
209220
outer_opt_state = outer_optimizer.init(outer_params)
210221
outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state)
222+
# For NNX, the step counter lives at state.optimizer.step; for Linen at state.step.
223+
step = state.optimizer.step if config.pure_nnx else state.step
211224
return (
212-
DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=state.step),
225+
DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step),
213226
outer_opt_state_sharding,
214227
)
215228

@@ -244,7 +257,11 @@ def synchronize(state):
244257
# Calculate the delta between the current replica's state and the global
245258
# state (since last synchronization).
246259
broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh)
247-
model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
260+
# For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params.
261+
inner_model_params = (
262+
nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params
263+
)
264+
model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params)
248265
# Treat the average delta as the outer optimizer's gradient and apply to
249266
# the global (outer) model params.
250267
averaged_pseudo_grad = drjax.reduce_mean(model_delta)
@@ -253,7 +270,27 @@ def synchronize(state):
253270
# Replace inner model params with the new global model params.
254271
# NOTE: inner optimizer state is retained despite the change in parameters,
255272
# see section 6.1 in https://arxiv.org/pdf/2311.08105.
256-
new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh)
273+
if config.pure_nnx:
274+
# For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state).
275+
def replace_nnx_model_params(s, new_params):
276+
non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param))
277+
new_model = nnx.merge_state(non_param_model, new_params)
278+
# Build result via __setitem__ so nested States are stored as plain dicts
279+
# internally, matching the pytree structure produced by nnx.state().
280+
# (Passing State objects via the constructor dict literal stores them
281+
# as-is, causing jax.lax.cond to see mismatched pytree structures.)
282+
result = type(s)({})
283+
result["model"] = new_model
284+
result["optimizer"] = s["optimizer"]
285+
return result
286+
287+
new_inner_state = drjax.map_fn(
288+
lambda s: replace_nnx_model_params(s, new_outer_params),
289+
state.inner_state,
290+
mesh=mesh,
291+
)
292+
else:
293+
new_inner_state = drjax.map_fn(lambda s: s.replace(params=new_outer_params), state.inner_state, mesh=mesh)
257294
return state.replace(
258295
params=new_outer_params,
259296
outer_opt_state=new_opt_state,
@@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng):
271308
broadcast_rng = drjax.broadcast(prng, mesh=mesh)
272309
inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh)
273310
avg_metrics = typed_reduce_mean(metrics)
311+
# For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step.
312+
new_step = inner_state.optimizer.step[0] if config.pure_nnx else inner_state.step[0]
274313
state = state.replace(
275314
inner_state=inner_state,
276-
step=inner_state.step[0],
315+
step=new_step,
277316
)
278317
# Either synchronize the model, or no-op, depending on whether the current
279318
# step falls on the synchronization period.
280319
state = jax.lax.cond(
281-
inner_state.step[0] % config.diloco_sync_period == 0,
320+
new_step % config.diloco_sync_period == 0,
282321
synchronize,
283322
lambda x: x, # no-op
284323
state,

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from flax import nnx
3131
from flax.linen import partitioning as nn_partitioning
3232
import jax
33+
import jax.numpy as jnp
3334
from jax.experimental.serialize_executable import serialize
3435
from jax.experimental.topologies import get_topology_desc
3536
from jax.sharding import AxisType, Mesh
@@ -93,6 +94,27 @@ def get_topology_mesh(config):
9394
return topology_mesh
9495

9596

97+
def _collect_nnx_activation_shardings(create_model_fn, config, mesh):
98+
"""Run an NNX forward pass in abstract mode to populate _ACTIVATION_SHARDINGS_DUMP.
99+
100+
get_abstract_state_nnx uses nnx.eval_shape which only traces model initialization,
101+
not __call__. Activation shardings are only collected during a forward pass.
102+
"""
103+
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
104+
105+
def _nnx_forward():
106+
model_instance = create_model_fn()
107+
return model_instance(
108+
decoder_input_tokens=jnp.ones(input_shape, dtype=jnp.int32),
109+
decoder_positions=jnp.ones(input_shape, dtype=jnp.int32),
110+
decoder_segment_ids=jnp.ones(input_shape, dtype=jnp.int32),
111+
enable_dropout=False,
112+
)
113+
114+
with nn_partitioning.axis_rules(config.logical_axis_rules):
115+
jax.eval_shape(_nnx_forward)
116+
117+
96118
def get_shaped_inputs(topology_mesh, config):
97119
"""Get shaped abstractions of inputs to train_step: state, batch and rng"""
98120
# Construct the model and optimizer to get shaped versions of the state
@@ -140,10 +162,17 @@ def create_train_state_fn():
140162
shaped_batch = maxtext_utils.get_shaped_batch(config)
141163

142164
if config.pure_nnx:
143-
shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng
165+
shaped_train_args = (abstract_state, shaped_batch) # NNX doesn't use dropout_rng
144166
else:
145167
shaped_train_args = (abstract_state, shaped_batch, shaped_rng)
146168
shaped_train_kwargs = {}
169+
170+
# Collect activation shardings for NNX by running an abstract forward pass.
171+
# This must happen after get_abstract_state (which uses nnx.eval_shape and only
172+
# traces __init__, not __call__).
173+
if config.debug_sharding and config.pure_nnx:
174+
_collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh)
175+
147176
return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model
148177

149178

@@ -279,7 +308,9 @@ def main(argv: Sequence[str]) -> None:
279308
diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state(
280309
config, abstract_state, state_mesh_shardings, topology_mesh
281310
)
282-
shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2])
311+
# For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng.
312+
shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None
313+
shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg)
283314

284315
# Wrap train_step with diloco
285316
train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None)

0 commit comments

Comments
 (0)