
Commit 0176830

Set NNX flags to true by default

1 parent 60f0e87 commit 0176830

55 files changed: 17851 additions & 5145 deletions


src/maxtext/configs/base.yml

Lines changed: 10 additions & 3 deletions

@@ -559,10 +559,17 @@ logical_axis_rules: [
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
   # ==========================================
+  # Pipeline Parallelism
+  # ==========================================
+  ['layers_outside_pipeline', []],
+  ['layers_per_stage', []],
+  ['num_activations', []],
+  ['circular_repeats', []],
+  # ==========================================
   # Deprecated / Scheduled for Removal
   # ==========================================
-  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed_tensor_transpose', ['tensor_transpose']],
+  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed_tensor_transpose', ['tensor_transpose']],
   ['exp_with_fsdp', 'fsdp'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
@@ -1160,7 +1167,7 @@ subslice_shape: ""
 # NNX
 enable_nnx: True
 pure_nnx_decoder: True
-pure_nnx: False
+pure_nnx: True
 use_nnx_pipeline: False # Set to False to use native Linen pipeline (with custom VJP)

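With this change, enable_nnx, pure_nnx_decoder, and pure_nnx all default to True, while use_nnx_pipeline stays False. A minimal sketch of how such flags typically gate module selection; the helper and class names below are hypothetical, not MaxText's actual dispatch:

# Hypothetical illustration; NNXDecoder/LinenDecoder stand in for the real classes.
def select_decoder(config):
  # After this commit the NNX path is the default; a config or CLI
  # override (e.g. pure_nnx=False) restores the previous behavior.
  if config.enable_nnx and config.pure_nnx:
    return NNXDecoder
  return LinenDecoder
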
src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 28 additions & 0 deletions

@@ -71,4 +71,32 @@ logical_axis_rules: [
   ['exp_with_fsdp', 'fsdp'],
   ['paged_kv_heads', ['tensor']],
   ['engram_dim', ['tensor']],
+  # Axes unsharded: sequence/context/tensor_transpose/autoregressive do not exist in this mesh
+  ['activation_attn_length_no_exp', []],
+  ['activation_length_no_exp', []],
+  ['activation_norm_length', []],
+  ['activation_q_length_no_exp', []],
+  ['prefill_activation_length', []],
+  ['prefill_activation_norm_length', []],
+  ['activation_kv_length', []],
+  ['decode_length', []],
+  ['embed_tensor_transpose', []],
+  ['q_lora_up_proj', []],
+  ['kv_lora_up_proj', []],
+  ['kv', []],
+  ['qkv', []],
+  ['kv_head_dim', []],
+  ['cache_batch_prefill', []],
+  ['cache_batch', []],
+  ['cache_heads_none', []],
+  ['cache_kv', []],
+  ['cache_sequence', []],
+  ['num_pages', []],
+  ['tokens_per_page', []],
+  ['paged_kv_head_dim_size', []],
+  ['dense_layers', []],
+  ['moe_layers', []],
+  ['num_activations', []],
+  ['mhc', []],
+  ['diloco', []],
 ]

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 53 additions & 0 deletions

@@ -31,4 +31,57 @@ logical_axis_rules: [
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['exp_with_fsdp', 'fsdp'],
+  # All other axes are unsharded (tensor/sequence/expert axes do not exist in pure-fsdp)
+  ['activation_heads', []],
+  ['activation_kv_heads', []],
+  ['activation_length', []],
+  ['activation_attn_length', []],
+  ['activation_attn_length_no_exp', []],
+  ['activation_length_no_exp', []],
+  ['activation_norm_length', []],
+  ['activation_q_length', []],
+  ['activation_q_length_no_exp', []],
+  ['prefill_activation_length', []],
+  ['prefill_activation_norm_length', []],
+  ['activation_kv_length', []],
+  ['activation_attn_embed', []],
+  ['activation_embed', []],
+  ['activation_mlp', []],
+  ['activation_kv', []],
+  ['activation_kv_head_dim', []],
+  ['activation_vocab', []],
+  ['activation_stage', []],
+  ['activation_exp', []],
+  ['decode_length', []],
+  ['mlp', []],
+  ['mlp_no_fsdp', []],
+  ['vocab', []],
+  ['heads', []],
+  ['q_heads', []],
+  ['kv_heads', []],
+  ['embed_tensor_transpose', []],
+  ['q_lora_up_proj', []],
+  ['kv_lora_up_proj', []],
+  ['norm', []],
+  ['layers', []],
+  ['qkv', []],
+  ['kv', []],
+  ['kv_head_dim', []],
+  ['cache_batch_prefill', []],
+  ['cache_batch', []],
+  ['cache_heads_none', []],
+  ['cache_heads', []],
+  ['cache_kv', []],
+  ['cache_sequence', []],
+  ['exp', []],
+  ['paged_kv_heads', []],
+  ['num_pages', []],
+  ['tokens_per_page', []],
+  ['paged_kv_head_dim_size', []],
+  ['dense_layers', []],
+  ['moe_layers', []],
+  ['num_activations', []],
+  ['engram_dim', []],
+  ['mhc', []],
+  ['diloco', []],
 ]
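
In both custom-mesh files, a rule of the form ['axis_name', []] maps a logical axis to no mesh axis, i.e. that dimension is replicated. A small self-contained sketch of how Flax resolves such rules (the mesh and the two rules here are illustrative, not this file's full set):

import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, PartitionSpec

# One 'fsdp' mesh axis spanning all local devices (illustrative).
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

# ('embed', 'fsdp') shards that dim over fsdp; ('mlp', None) is the
# equivalent of ['mlp', []] above: the dim stays replicated.
rules = (("embed", "fsdp"), ("mlp", None))

sharding = nn.logical_to_mesh_sharding(PartitionSpec("embed", "mlp"), mesh, rules)
print(sharding.spec)  # PartitionSpec('fsdp', None)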

src/maxtext/configs/decoupled_base_test.yml

Lines changed: 3 additions & 27 deletions

@@ -1,6 +1,7 @@
 # Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml.
-# Inherit all model defaults (PyDantic already does this) but override any cloud-coupled paths and disable
-# optional cloud features.
+# Inherits from base.yml so that logical_axis_rules, mesh_axes, NNX flags, and all other
+# model defaults are kept in sync. Overrides only cloud-coupled paths and optional cloud features.
+base_config: base.yml
 
 # Output goes to a local relative directory so tests do not require GCS.
 base_output_directory: ./maxtext_local_output/gcloud_decoupled_test_logs
@@ -34,34 +35,9 @@ attention: "dot_product"
 dump_hlo: false
 jax_cache_dir: ""
 
-# Neutral parallelism (single device) for local tests.
-ici_data_parallelism: 1
-ici_tensor_parallelism: 1
-ici_pipeline_parallelism: 1
-ici_expert_parallelism: 1
-ici_sequence_parallelism: 1
-ici_context_parallelism: 1
-ici_tensor_transpose_parallelism: 1
-ici_tensor_sequence_parallelism: 1
-ici_autoregressive_parallelism: 1
-ici_fsdp_parallelism: 1
-ici_fsdp_transpose_parallelism: 1
 # Allow higher unsharded parameter percentage for small device count
 sharding_tolerance: 0.3
 
-# DCN dimensions to 1 (no multi-slice expectation locally).
-dcn_data_parallelism: 1
-dcn_tensor_parallelism: 1
-dcn_pipeline_parallelism: 1
-dcn_expert_parallelism: 1
-dcn_sequence_parallelism: 1
-dcn_context_parallelism: 1
-dcn_tensor_transpose_parallelism: 1
-dcn_tensor_sequence_parallelism: 1
-dcn_autoregressive_parallelism: 1
-dcn_fsdp_parallelism: 1
-dcn_fsdp_transpose_parallelism: 1
-
 # Config logging off unless a test overrides.
 log_config: false

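The key change here is base_config: base.yml: the test config now inherits defaults instead of re-declaring the ici_*/dcn_* parallelism keys (which default to 1 in base.yml anyway). A rough sketch of single-parent config inheritance; this helper is hypothetical, and MaxText's own loader performs the real merge:

import os
import yaml

def load_config(path):
  # Hypothetical resolver: load the parent first, then let child keys win.
  with open(path, encoding="utf-8") as f:
    cfg = yaml.safe_load(f) or {}
  base = cfg.pop("base_config", None)
  if base:
    merged = load_config(os.path.join(os.path.dirname(path), base))
    merged.update(cfg)
    return merged
  return cfg
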
src/maxtext/layers/attentions.py

Lines changed: 2 additions & 2 deletions

@@ -525,14 +525,14 @@ def __init__(
     elif self.is_qwen3_next:
       self.query_norm = Qwen3NextRMSNorm(
           num_features=self.config.head_dim,
-          eps=self.config.normalization_layer_epsilon,
+          epsilon=self.config.normalization_layer_epsilon,
           dtype=self.config.dtype,
           weight_dtype=self.config.weight_dtype,
           rngs=self.rngs,
       )
       self.key_norm = Qwen3NextRMSNorm(
           num_features=self.config.head_dim,
-          eps=self.config.normalization_layer_epsilon,
+          epsilon=self.config.normalization_layer_epsilon,
           dtype=self.config.dtype,
           weight_dtype=self.config.weight_dtype,
           rngs=self.rngs,

src/maxtext/layers/moe.py

Lines changed: 8 additions & 22 deletions

@@ -403,16 +403,9 @@ def __init__(
     if rule is not None:
       if not isinstance(rule, qwix.QtRule):
         raise ValueError("Expect a QtRule for quantized training.")
-      if (
-          rule.additional_qt_config
-          and "sparsity_rule" in rule.additional_qt_config
-      ):
+      if rule.additional_qt_config and "sparsity_rule" in rule.additional_qt_config:
         q_s_rule = rule.additional_qt_config["sparsity_rule"]
-        if (
-            q_s_rule
-            and q_s_rule.weight_sparsity_n
-            and q_s_rule.weight_sparsity_m
-        ):
+        if q_s_rule and q_s_rule.weight_sparsity_n and q_s_rule.weight_sparsity_m:
           sparsity_rule = q_s_rule
 
     if sparsity_rule is not None:
@@ -1064,8 +1057,7 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments,
     def get_tokamax_group_sizes(group_sizes, inputs, kernel):
       # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
       if self.config.use_qwix_quantization or (
-          self.config.using_pipeline_parallelism
-          and self.config.pipeline_fsdp_ag_per_repeat
+          self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat
       ):
         return group_sizes
       elif self.config.attention == "vllm_rpa":
@@ -2190,19 +2182,13 @@ def __call__(
     w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
     w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
 
-      if self.per_expert_scale is not None:
-        wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
+    if self.per_expert_scale is not None:
+      wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
 
     if self.wi_0_sparsity_module is not None:
-      _, w0_kernel = self.wi_0_sparsity_module(
-          jnp.zeros_like(w0_kernel), w0_kernel
-      )
-      _, w1_kernel = self.wi_1_sparsity_module(
-          jnp.zeros_like(w1_kernel), w1_kernel
-      )
-      _, wo_kernel = self.wo_sparsity_module(
-          jnp.zeros_like(wo_kernel), wo_kernel
-      )
+      _, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel)
+      _, w1_kernel = self.wi_1_sparsity_module(jnp.zeros_like(w1_kernel), w1_kernel)
+      _, wo_kernel = self.wo_sparsity_module(jnp.zeros_like(wo_kernel), wo_kernel)
     if cfg.mlp_bias:
       w0_bias = jnp.asarray(self.wi_0_bias[...], self.dtype)
       w1_bias = jnp.asarray(self.wi_1_bias[...], self.dtype)

src/maxtext/layers/nnx_decoders.py

Lines changed: 76 additions & 1 deletion

@@ -241,6 +241,82 @@ def deepstack_process(hidden_states, bidirectional_mask, visual_embeds):
   return hidden_states
 
 
+class NNXSequentialPipelineStage(nnx.Module):
+  """Sequential unscanned series of decoder layers formatted for a single pipeline stage."""
+
+  def __init__(
+      self, layer_cls, num_layers: int, config: Config, mesh: Mesh, quant: Quant, model_mode: str, *, rngs: nnx.Rngs
+  ):
+    self.config = config
+    self.scan_layers = config.scan_layers
+    self.num_layers = num_layers
+    for i in range(num_layers):
+      layer = layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rngs)
+      setattr(self, f"layers_{i}", layer)
+
+  def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs):
+    for i in range(self.num_layers):
+      layer = getattr(self, f"layers_{i}")
+      out = layer(inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs)
+      inputs = out[0] if isinstance(out, tuple) else out
+    if self.scan_layers:
+      return inputs, None
+    return inputs
+
+
+class NNXScannedPipelineStage(nnx.Module):
+  """Scanned block of decoder layers formatted for a single pipeline stage."""
+
+  def __init__(
+      self, layer_cls, num_layers: int, config: Config, mesh: Mesh, quant: Quant, model_mode: str, *, rngs: nnx.Rngs
+  ):
+    self.config = config
+
+    def create_layer_fn(rng):
+      return layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rng)
+
+    try:
+      forked_rngs = rngs.fork(split=num_layers)
+    except:  # pylint: disable=bare-except
+      forked_rngs = rngs
+
+    out_axes = nnx.StateAxes({nnx.Param: config.param_scan_axis, ...: 0})
+    self.scanned_layers = nnx.vmap(
+        create_layer_fn,
+        in_axes=0,
+        out_axes=out_axes,
+        axis_name="layers_per_stage",
+        transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"},
+    )(forked_rngs)
+
+  def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs):
+    graphdef, params, state = nnx.split(self.scanned_layers, nnx.Param, ...)
+
+    scan_axis = self.config.param_scan_axis
+    if scan_axis != 0:
+      params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
+
+    def layer_fn(carry, scanned_vars):
+      current_params, current_state = scanned_vars
+      layer = nnx.merge(graphdef, current_params, current_state)
+      layer_out = layer(carry, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs)
+      new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
+      return new_carry, nnx.state(layer)
+
+    final_carry, scanned_state = jax.lax.scan(layer_fn, inputs, (params, state))
+
+    if scan_axis != 0:
+      scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
+      scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
+      scanned_state = nnx.State.merge(scanned_params, scanned_other)
+
+    nnx.update(self.scanned_layers, scanned_state)
+
+    if self.config.scan_layers:
+      return final_carry, None
+    return final_carry
+
+
 class NNXDecoder(nnx.Module):
   """A stack of decoder layers as a part of an encoder-decoder architecture, using NNX."""
 
@@ -992,7 +1068,6 @@ def __call__(
       previous_chunk=None,
      slot: None | int = None,
       page_state: None | page_manager.PageState = None,
-      multimodal_input: None | Any = None,
       kv_caches: list[jax.Array] | None = None,
       attention_metadata=None,
       deepstack_visual_embeds: None | list[jnp.ndarray] = None,

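NNXScannedPipelineStage above creates its per-layer parameters with nnx.vmap and then steps through them with jax.lax.scan, splitting the module into a static graphdef plus array state so the scan body can merge one layer's slice at a time. A minimal self-contained sketch of that same split/merge-inside-scan pattern, using a toy Block rather than MaxText's decoder layer:

import jax
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
  # Toy stand-in for a decoder layer.
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(8, 8, rngs=rngs)

  def __call__(self, x):
    return jax.nn.relu(self.linear(x))

num_layers = 4

# Build num_layers Blocks with parameters stacked along a leading axis.
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=0, out_axes=0)
def create_blocks(rngs: nnx.Rngs):
  return Block(rngs)

blocks = create_blocks(nnx.Rngs(0))
graphdef, state = nnx.split(blocks)

def layer_fn(carry, layer_state):
  layer = nnx.merge(graphdef, layer_state)  # rebuild one layer per scan step
  return layer(carry), None

y, _ = jax.lax.scan(layer_fn, jnp.ones((2, 8)), state)
print(y.shape)  # (2, 8)
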
src/maxtext/layers/nnx_wrappers.py

Lines changed: 19 additions & 0 deletions

@@ -26,6 +26,7 @@
 from flax.core import FrozenDict
 from flax.core import meta
 from flax.nnx import graph
+from flax.nnx import tracers as nnx_tracers
 from flax.nnx import variablelib
 from flax.nnx.bridge import module as bdg_module
 from flax.nnx.module import Module
@@ -186,6 +187,23 @@ def is_linen_initializing() -> bool:
   return False
 
 
+def _refresh_variable_trace_state(module: Module) -> None:
+  """Refresh _trace_state for Variables that have stale trace state.
+
+  When nnx.update() is called with tracer values from a JAX transformation
+  (e.g. jax.grad's LinearizeTracer), it uses _unsafe_bypass_check=True, which
+  updates the raw value but not _trace_state. This leaves Variables with a
+  stale _trace_state from the outer (Python) context, causing nnx.split() to
+  fail with "Cannot extract graph node from different trace level" errors.
+
+  This function resets _trace_state on any Variables whose _can_update is False
+  so that downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed.
+  """
+  for _, v in nnx.graph.iter_graph(module):
+    if isinstance(v, variablelib.Variable) and not v._can_update:  # pylint: disable=protected-access
+      object.__setattr__(v, "_trace_state", nnx_tracers.TraceState())
+
+
@@ -483,6 +501,7 @@ def maybe_unbox(x):
       warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")
 
     nnx.update(module, new_state)
+    _refresh_variable_trace_state(module)
 
     _fix_for_qwix_quantization(module)
     method_fn = _get_module_method(module, nnx_method)

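_refresh_variable_trace_state walks every node with nnx.graph.iter_graph and resets the private _trace_state on stale Variables. A toy illustration of that traversal (the module here is hypothetical; the real helper additionally checks the private _can_update flag):

import jax
from flax import nnx

class Tiny(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.normal(rngs.params(), (4, 4)))
    self.proj = nnx.Linear(4, 2, rngs=rngs)

m = Tiny(nnx.Rngs(params=0))
for path, node in nnx.graph.iter_graph(m):
  if isinstance(node, nnx.Variable):
    # The commit's helper visits Variables exactly like these.
    print(path, type(node).__name__, node.value.shape)
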
src/maxtext/layers/normalizations.py

Lines changed: 12 additions & 2 deletions

@@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
     return y_flat.reshape(input_shape)
 
 
-def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
+def Qwen3NextRMSNorm(
+    num_features: int,
+    epsilon: float = 1e-6,
+    dtype: DType = None,
+    weight_dtype: DType = None,
+    shard_mode=None,
+    kernel_axes=None,
+    parameter_memory_host_offload=None,
+    *,
+    rngs: nnx.Rngs,
+):
   """
   Used for input and post attention layernorms
   in Qwen3NextDecoderLayer.
@@ -127,7 +137,7 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype:
   return nnx.data(
       RMSNorm(
           num_features=num_features,
-          epsilon=eps,
+          epsilon=epsilon,
           dtype=dtype,
           weight_dtype=weight_dtype,
           scale_init=linen_initializers.zeros,

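After the rename, the factory's keyword matches the epsilon= call sites in attentions.py above, and the new defaults make most arguments optional. A minimal call under the new signature; the argument values and the import path (following this commit's src/ layout) are illustrative:

import jax.numpy as jnp
from flax import nnx
from maxtext.layers.normalizations import Qwen3NextRMSNorm

norm = Qwen3NextRMSNorm(
    num_features=128,
    epsilon=1e-6,              # keyword renamed from eps
    dtype=jnp.bfloat16,
    weight_dtype=jnp.float32,
    rngs=nnx.Rngs(params=0),
)
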