Skip to content

Commit ffeefdd

Browse files
committed
Set NNX flags to true by default
1 parent dd7f895 commit ffeefdd

91 files changed

Lines changed: 48684 additions & 13468 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

src/maxtext/configs/base.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,13 @@ logical_axis_rules: [
542542
['paged_kv_head_dim_size', []],
543543
['dense_layers', []],
544544
['moe_layers', []],
545+
['layers_outside_pipeline', []],
546+
['layers_per_stage', []],
545547
['engram_dim', ['tensor']],
546548
['mhc', []],
547549
['diloco', 'diloco'],
550+
['num_activations', []],
551+
['circular_repeats', []],
548552
]
549553
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
550554
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
@@ -1133,9 +1137,9 @@ position_id_per_seconds: 25
11331137
subslice_shape: ""
11341138

11351139
# NNX
1136-
enable_nnx: False
1137-
pure_nnx_decoder: False
1138-
pure_nnx: False
1140+
enable_nnx: True
1141+
pure_nnx_decoder: True
1142+
pure_nnx: True
11391143

11401144
################################## Qwen3-Next Specific Configs ##################################
11411145
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,32 @@ logical_axis_rules: [
7070
['exp_with_fsdp', 'fsdp'],
7171
['paged_kv_heads', ['tensor']],
7272
['engram_dim', ['tensor']],
73+
# Axes unsharded: sequence/context/tensor_transpose/autoregressive do not exist in this mesh
74+
['activation_attn_length_no_exp', []],
75+
['activation_length_no_exp', []],
76+
['activation_norm_length', []],
77+
['activation_q_length_no_exp', []],
78+
['prefill_activation_length', []],
79+
['prefill_activation_norm_length', []],
80+
['activation_kv_length', []],
81+
['decode_length', []],
82+
['embed_tensor_transpose', []],
83+
['q_lora_up_proj', []],
84+
['kv_lora_up_proj', []],
85+
['kv', []],
86+
['qkv', []],
87+
['kv_head_dim', []],
88+
['cache_batch_prefill', []],
89+
['cache_batch', []],
90+
['cache_heads_none', []],
91+
['cache_kv', []],
92+
['cache_sequence', []],
93+
['num_pages', []],
94+
['tokens_per_page', []],
95+
['paged_kv_head_dim_size', []],
96+
['dense_layers', []],
97+
['moe_layers', []],
98+
['num_activations', []],
99+
['mhc', []],
100+
['diloco', []],
73101
]

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,57 @@ logical_axis_rules: [
3232
['q_lora', ['fsdp']],
3333
['kv_lora', ['fsdp']],
3434
['exp_with_fsdp', 'fsdp'],
35+
# All other axes are unsharded (tensor/sequence/expert axes do not exist in pure-fsdp)
36+
['activation_heads', []],
37+
['activation_kv_heads', []],
38+
['activation_length', []],
39+
['activation_attn_length', []],
40+
['activation_attn_length_no_exp', []],
41+
['activation_length_no_exp', []],
42+
['activation_norm_length', []],
43+
['activation_q_length', []],
44+
['activation_q_length_no_exp', []],
45+
['prefill_activation_length', []],
46+
['prefill_activation_norm_length', []],
47+
['activation_kv_length', []],
48+
['activation_attn_embed', []],
49+
['activation_embed', []],
50+
['activation_mlp', []],
51+
['activation_kv', []],
52+
['activation_kv_head_dim', []],
53+
['activation_vocab', []],
54+
['activation_stage', []],
55+
['activation_exp', []],
56+
['decode_length', []],
57+
['mlp', []],
58+
['mlp_no_fsdp', []],
59+
['vocab', []],
60+
['heads', []],
61+
['q_heads', []],
62+
['kv_heads', []],
63+
['embed_tensor_transpose', []],
64+
['q_lora_up_proj', []],
65+
['kv_lora_up_proj', []],
66+
['norm', []],
67+
['layers', []],
68+
['qkv', []],
69+
['kv', []],
70+
['kv_head_dim', []],
71+
['cache_batch_prefill', []],
72+
['cache_batch', []],
73+
['cache_heads_none', []],
74+
['cache_heads', []],
75+
['cache_kv', []],
76+
['cache_sequence', []],
77+
['exp', []],
78+
['paged_kv_heads', []],
79+
['num_pages', []],
80+
['tokens_per_page', []],
81+
['paged_kv_head_dim_size', []],
82+
['dense_layers', []],
83+
['moe_layers', []],
84+
['num_activations', []],
85+
['engram_dim', []],
86+
['mhc', []],
87+
['diloco', []],
3588
]

src/maxtext/configs/decoupled_base_test.yml

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml.
2-
# Inherit all model defaults (Pydantic already does this) but override any cloud-coupled paths and disable
3-
# optional cloud features.
2+
# Inherits from base.yml so that logical_axis_rules, mesh_axes, NNX flags, and all other
3+
# model defaults are kept in sync. Overrides only cloud-coupled paths and optional cloud features.
4+
base_config: base.yml
45

56
# Output goes to a local relative directory so tests do not require GCS.
67
base_output_directory: ./maxtext_local_output/gcloud_decoupled_test_logs
@@ -34,34 +35,9 @@ attention: "dot_product"
3435
dump_hlo: false
3536
jax_cache_dir: ""
3637

37-
# Neutral parallelism (single device) for local tests.
38-
ici_data_parallelism: 1
39-
ici_tensor_parallelism: 1
40-
ici_pipeline_parallelism: 1
41-
ici_expert_parallelism: 1
42-
ici_sequence_parallelism: 1
43-
ici_context_parallelism: 1
44-
ici_tensor_transpose_parallelism: 1
45-
ici_tensor_sequence_parallelism: 1
46-
ici_autoregressive_parallelism: 1
47-
ici_fsdp_parallelism: 1
48-
ici_fsdp_transpose_parallelism: 1
4938
# Allow higher unsharded parameter percentage for small device count
5039
sharding_tolerance: 0.3
5140

52-
# DCN dimensions to 1 (no multi-slice expectation locally).
53-
dcn_data_parallelism: 1
54-
dcn_tensor_parallelism: 1
55-
dcn_pipeline_parallelism: 1
56-
dcn_expert_parallelism: 1
57-
dcn_sequence_parallelism: 1
58-
dcn_context_parallelism: 1
59-
dcn_tensor_transpose_parallelism: 1
60-
dcn_tensor_sequence_parallelism: 1
61-
dcn_autoregressive_parallelism: 1
62-
dcn_fsdp_parallelism: 1
63-
dcn_fsdp_transpose_parallelism: 1
64-
6541
# Config logging off unless a test overrides.
6642
log_config: false
6743

src/maxtext/layers/attentions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,14 +525,14 @@ def __init__(
525525
elif self.is_qwen3_next:
526526
self.query_norm = Qwen3NextRMSNorm(
527527
num_features=self.config.head_dim,
528-
eps=self.config.normalization_layer_epsilon,
528+
epsilon=self.config.normalization_layer_epsilon,
529529
dtype=self.config.dtype,
530530
weight_dtype=self.config.weight_dtype,
531531
rngs=self.rngs,
532532
)
533533
self.key_norm = Qwen3NextRMSNorm(
534534
num_features=self.config.head_dim,
535-
eps=self.config.normalization_layer_epsilon,
535+
epsilon=self.config.normalization_layer_epsilon,
536536
dtype=self.config.dtype,
537537
weight_dtype=self.config.weight_dtype,
538538
rngs=self.rngs,

src/maxtext/layers/moe.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,15 +2041,15 @@ def __call__(
20412041
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
20422042
wo_kernel = jnp.asarray(self.wo[...], self.dtype)
20432043

2044-
if self.per_expert_scale is not None:
2045-
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
2044+
if self.per_expert_scale is not None:
2045+
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
20462046

2047-
if cfg.mlp_bias:
2048-
w0_bias = jnp.asarray(self.wi_0_bias[...], self.dtype)
2049-
w1_bias = jnp.asarray(self.wi_1_bias[...], self.dtype)
2050-
wo_bias = jnp.asarray(self.wo_bias[...], self.dtype)
2051-
else:
2052-
w0_bias, w1_bias, wo_bias = None, None, None
2047+
if cfg.mlp_bias:
2048+
w0_bias = jnp.asarray(self.wi_0_bias[...], self.dtype)
2049+
w1_bias = jnp.asarray(self.wi_1_bias[...], self.dtype)
2050+
wo_bias = jnp.asarray(self.wo_bias[...], self.dtype)
2051+
else:
2052+
w0_bias, w1_bias, wo_bias = None, None, None
20532053

20542054
if cfg.sparse_matmul:
20552055
if quantizations.in_serve_mode(self.quant):

src/maxtext/layers/nnx_decoders.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ def layer_fn(carry, scanned_vars):
303303
layer = nnx.merge(graphdef, current_params, current_state)
304304
layer_out = layer(carry, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs)
305305
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
306+
nnx.pop(layer, nnx.Intermediate)
306307
return new_carry, nnx.state(layer)
307308

308309
final_carry, scanned_state = jax.lax.scan(layer_fn, inputs, (params, state))
@@ -534,6 +535,8 @@ def _create_scanned_layers(
534535
self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs
535536
):
536537
"""Creates a VMapped stack of layers, forcing parameter init for Compact modules."""
538+
if length == 0:
539+
return nnx.List([])
537540

538541
def create_layer_fn(rng):
539542
return decoder_layer_class(
@@ -566,13 +569,17 @@ def pure_layer_fn(state_in, y_in):
566569
out = merged_layer(y_in, **kwargs)
567570
return out, nnx.state(merged_layer)
568571

569-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
570-
out, new_state = checkpointed_fn(state, y)
572+
if not self._uses_linen_fp8_ops():
573+
pure_layer_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
574+
out, new_state = pure_layer_fn(state, y)
571575
nnx.update(layer, new_state)
572576
return out
573577

574578
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
575579
"""Runs the layer stack using nnx.scan."""
580+
if length == 0:
581+
_, empty_state = nnx.split(layers)
582+
return x_in, empty_state
576583
policy = self.get_remat_policy()
577584
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
578585
graphdef, params, state = nnx.split(
@@ -608,7 +615,25 @@ def layer_fn(carry, scanned_vars):
608615
# Run the layer (Filter kwargs if using the solution from previous turn)
609616
layer_out = layer(carry, *args, **valid_kwargs)
610617
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
611-
return new_carry, nnx.state(layer)
618+
nnx.pop(layer, nnx.Intermediate)
619+
new_current_state = nnx.state(layer)
620+
return new_carry, new_current_state
621+
622+
if self._uses_linen_fp8_ops():
623+
# jax.lax.scan is incompatible with Linen fp8 ops: put_variable in setup() stores
624+
# scan-level tracers as Python attributes on the Linen module, causing a tracer leak
625+
# across the scan boundary. Fall back to a Python loop instead.
626+
x = x_in
627+
for i in range(length):
628+
params_i = jax.tree.map(lambda p, _i=i: p[_i], params)
629+
state_i = jax.tree.map(lambda s, _i=i: s[_i], state)
630+
layer = nnx.merge(graphdef, params_i, state_i)
631+
layer_out = layer(x, *args, **valid_kwargs)
632+
x = layer_out[0] if isinstance(layer_out, tuple) else layer_out
633+
nnx.pop(layer, nnx.Intermediate)
634+
if scan_axis != 0:
635+
params = jax.tree.map(lambda p: jnp.moveaxis(p, 0, scan_axis), params)
636+
return x, nnx.State.merge(params, state)
612637

613638
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
614639
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
@@ -672,7 +697,8 @@ def get_chunk(pytree, start, end):
672697
layer_out = layer(y, *layer_args, **valid_kwargs)
673698
y = layer_out[0] if isinstance(layer_out, tuple) else layer_out
674699

675-
_, new_eng_mutables = nnx.split(layer, nnx.Param, ...)
700+
nnx.pop(layer, nnx.Intermediate)
701+
_, _, new_eng_mutables = nnx.split(layer, nnx.Param, ...)
676702
new_eng_mutables = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), new_eng_mutables)
677703
updated_mutables_chunks.append(new_eng_mutables)
678704
current_idx += 1
@@ -698,10 +724,12 @@ def layer_fn(carry, scanned_vars):
698724
l = nnx.merge(graphdef, curr_p, curr_m)
699725
l_out = l(carry, *layer_args, **valid_kwargs)
700726
n_carry = l_out[0] if isinstance(l_out, tuple) else l_out
701-
_, n_mut = nnx.split(l, nnx.Param, ...)
727+
nnx.pop(l, nnx.Intermediate)
728+
_, _, n_mut = nnx.split(l, nnx.Param, ...)
702729
return n_carry, n_mut
703730

704-
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
731+
if not self._uses_linen_fp8_ops():
732+
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
705733
y, new_chunk_mutables = jax.lax.scan(layer_fn, y, (chunk_params, chunk_mutables))
706734
updated_mutables_chunks.append(new_chunk_mutables)
707735
current_idx = next_boundary
@@ -742,7 +770,11 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in, dynamic_kwargs):
742770
out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **dynamic_kwargs)
743771
return out_y, out_kv, nnx.state(merged_layer)
744772

745-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
773+
checkpointed_fn = (
774+
pure_layer_fn
775+
if self._uses_linen_fp8_ops()
776+
else jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
777+
)
746778

747779
for lyr in range(num_layers):
748780
attr_name = f"{base_name}_{lyr}"
@@ -921,6 +953,10 @@ def get_remat_policy(self):
921953
assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies"
922954
return policy
923955

956+
def _uses_linen_fp8_ops(self) -> bool:
957+
"""Returns True if the quantization mode uses Linen fp8 ops incompatible with jax.checkpoint."""
958+
return self.config.quantization in ("fp8_gpu", "fp8_nanoo")
959+
924960
def get_norm_layer(self, num_features: int, rngs: nnx.Rngs):
925961
"""Helper to retrieve the correct normalization layer class based on config, partially applied with common arguments."""
926962
if self.config.decoder_block in (
@@ -1072,10 +1108,18 @@ def __call__(
10721108
audio_embeddings: None | jnp.ndarray = None,
10731109
audio_masks: None | jnp.ndarray = None,
10741110
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
1111+
multimodal_input=None,
10751112
):
10761113
cfg = self.config
10771114
assert decoder_input_tokens.ndim == 2 # [batch, len]
10781115

1116+
if multimodal_input is not None:
1117+
image_embeddings = multimodal_input.image_embeddings
1118+
bidirectional_mask = multimodal_input.bidirectional_mask
1119+
image_masks = multimodal_input.image_masks
1120+
audio_embeddings = multimodal_input.audio_embeddings
1121+
audio_masks = multimodal_input.audio_masks
1122+
10791123
# [batch, length] -> [batch, length, emb_dim]
10801124
y = self._apply_embedding(
10811125
shared_embedding,
@@ -1223,7 +1267,6 @@ def __call__(
12231267
decoder_input_tokens=decoder_input_tokens,
12241268
)
12251269

1226-
12271270
else:
12281271
# Non-Pipeline Run
12291272
if cfg.scan_layers:

src/maxtext/layers/nnx_wrappers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from flax.core import FrozenDict
2727
from flax.core import meta
2828
from flax.nnx import graph
29+
from flax.nnx import tracers as nnx_tracers
2930
from flax.nnx import variablelib
3031
from flax.nnx.bridge import module as bdg_module
3132
from flax.nnx.module import Module
@@ -170,6 +171,23 @@ def current_linen_module() -> linen.Module | None:
170171
return None
171172

172173

174+
def _refresh_variable_trace_state(module: Module) -> None:
175+
"""Refresh _trace_state for Variables that have stale trace state.
176+
177+
When nnx.update() is called with tracer values from a JAX transformation
178+
(e.g. jax.grad's LinearizeTracer), it uses _unsafe_bypass_check=True which
179+
updates the raw value but not _trace_state. This leaves Variables with a
180+
stale _trace_state from the outer (Python) context, causing nnx.split() to
181+
fail with "Cannot extract graph node from different trace level" errors.
182+
183+
This function resets _trace_state on any Variables whose _can_update is False
184+
so that downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed.
185+
"""
186+
for _, v in nnx.graph.iter_graph(module):
187+
if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access
188+
object.__setattr__(v, "_trace_state", nnx_tracers.TraceState())
189+
190+
173191
class ToNNX(Module):
174192
"""A wrapper to turn any Linen module into an NNX module.
175193
@@ -467,6 +485,7 @@ def maybe_unbox(x):
467485
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")
468486

469487
nnx.update(module, new_state)
488+
_refresh_variable_trace_state(module)
470489

471490
_fix_for_qwix_quantization(module)
472491
method_fn = _get_module_method(module, nnx_method)

0 commit comments

Comments
 (0)