Skip to content

Commit 68509ee

Browse files
committed
[WIP] NNX: fix model and test compatibility issues
- Replace nn.Dropout with linears.Dropout in gpt_oss and olmo3 decoder layers - Add num_activations logical axis rule to base.yml - Fix integration and unit tests for NNX compatibility I will relocate these files accordingly once the work is done.
1 parent f90a659 commit 68509ee

37 files changed

Lines changed: 3881 additions & 58 deletions

src/maxtext/common/gcloud_stub.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def is_decoupled() -> bool: # dynamic check so setting env after initial import
4343
return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"
4444

4545

46+
def is_pure_nnx() -> bool: # dynamic check so setting env after initial import still works
47+
"""Return True when running in pure NNX mode (PURE_NNX=TRUE env var).
48+
49+
Defaults to FALSE — Linen is the default test mode.
50+
Set PURE_NNX=TRUE to opt in to NNX mode (skips linen_only tests, runs nnx_only tests).
51+
"""
52+
return os.environ.get("PURE_NNX", "FALSE").upper() == "TRUE"
53+
54+
4655
T = TypeVar("T")
4756

4857

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ logical_axis_rules: [
534534
['paged_kv_head_dim_size', []],
535535
['dense_layers', []],
536536
['moe_layers', []],
537+
['num_activations', []],
537538
['engram_dim', ['tensor']],
538539
['mhc', []],
539540
['diloco', 'diloco'],

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# This logical rule is designed to optimize pipeline parallelism for large-scale jobs.
16-
# Key changes include removing expert weight sharding on the `q_lora` dimension, which
17-
# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when
18-
# EP x FSDP > 512.
15+
# This logical rule is designed to optimize pipeline parallelism for large-scale jobs.
16+
# Key changes include removing expert weight sharding on the `q_lora` dimension, which
17+
# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when
18+
# EP x FSDP > 512.
1919
#
20-
# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
21-
# data parallel (DP) domain externally, making the `data` axis a necessary reference;
22-
# second, it may be required for DCN communication.
20+
# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a
21+
# data parallel (DP) domain externally, making the `data` axis a necessary reference;
22+
# second, it may be required for DCN communication.
2323
#
24-
# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or
25-
# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to
24+
# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or
25+
# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to
2626
# store prefetched weights.
2727
mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert']
2828
data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']]
@@ -71,4 +71,32 @@ logical_axis_rules: [
7171
['exp_with_fsdp', 'fsdp'],
7272
['paged_kv_heads', ['tensor']],
7373
['engram_dim', ['tensor']],
74+
# Axes unsharded: sequence/context/tensor_transpose/autoregressive do not exist in this mesh
75+
['activation_attn_length_no_exp', []],
76+
['activation_length_no_exp', []],
77+
['activation_norm_length', []],
78+
['activation_q_length_no_exp', []],
79+
['prefill_activation_length', []],
80+
['prefill_activation_norm_length', []],
81+
['activation_kv_length', []],
82+
['decode_length', []],
83+
['embed_tensor_transpose', []],
84+
['q_lora_up_proj', []],
85+
['kv_lora_up_proj', []],
86+
['kv', []],
87+
['qkv', []],
88+
['kv_head_dim', []],
89+
['cache_batch_prefill', []],
90+
['cache_batch', []],
91+
['cache_heads_none', []],
92+
['cache_kv', []],
93+
['cache_sequence', []],
94+
['num_pages', []],
95+
['tokens_per_page', []],
96+
['paged_kv_head_dim_size', []],
97+
['dense_layers', []],
98+
['moe_layers', []],
99+
['num_activations', []],
100+
['mhc', []],
101+
['diloco', []],
74102
]

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# This rule only uses FSDP. Pure FSDP is the go-to sharding strategy
15+
# This rule only uses FSDP. Pure FSDP is the go-to sharding strategy
1616
# for small-scale training and this rule simplifies the overall configuration.
1717
mesh_axes: ['fsdp']
1818
data_sharding: [['fsdp']]
1919
logical_axis_rules: [
20+
# Batch/data dimensions sharded on fsdp
2021
['activation_batch', ['fsdp']],
2122
['activation_batch_no_exp', ['fsdp']],
2223
['activation_batch_moe', ['fsdp']],
@@ -27,11 +28,65 @@ logical_axis_rules: [
2728
['activation_kv_batch', ['fsdp']],
2829
['activation_kv_batch_no_exp', ['fsdp']],
2930
['decode_batch', ['fsdp']],
31+
# Weight dimensions sharded on fsdp
3032
['embed', ['fsdp']],
3133
['embed_no_exp', ['fsdp']],
3234
['embed_moe', ['fsdp']],
3335
['embed_no_exp_moe', ['fsdp']],
3436
['q_lora', ['fsdp']],
3537
['kv_lora', ['fsdp']],
3638
['exp_with_fsdp', 'fsdp'],
39+
# All other axes are unsharded (tensor/sequence/expert axes do not exist in pure-fsdp)
40+
['activation_heads', []],
41+
['activation_kv_heads', []],
42+
['activation_length', []],
43+
['activation_attn_length', []],
44+
['activation_attn_length_no_exp', []],
45+
['activation_length_no_exp', []],
46+
['activation_norm_length', []],
47+
['activation_q_length', []],
48+
['activation_q_length_no_exp', []],
49+
['prefill_activation_length', []],
50+
['prefill_activation_norm_length', []],
51+
['activation_kv_length', []],
52+
['activation_attn_embed', []],
53+
['activation_embed', []],
54+
['activation_mlp', []],
55+
['activation_kv', []],
56+
['activation_kv_head_dim', []],
57+
['activation_vocab', []],
58+
['activation_stage', []],
59+
['activation_exp', []],
60+
['decode_length', []],
61+
['mlp', []],
62+
['mlp_no_fsdp', []],
63+
['vocab', []],
64+
['heads', []],
65+
['q_heads', []],
66+
['kv_heads', []],
67+
['embed_tensor_transpose', []],
68+
['q_lora_up_proj', []],
69+
['kv_lora_up_proj', []],
70+
['norm', []],
71+
['layers', []],
72+
['qkv', []],
73+
['kv', []],
74+
['kv_head_dim', []],
75+
['cache_batch_prefill', []],
76+
['cache_batch', []],
77+
['cache_heads_none', []],
78+
['cache_heads', []],
79+
['cache_kv', []],
80+
['cache_sequence', []],
81+
['exp', []],
82+
['paged_kv_heads', []],
83+
['num_pages', []],
84+
['tokens_per_page', []],
85+
['paged_kv_head_dim_size', []],
86+
['dense_layers', []],
87+
['moe_layers', []],
88+
['num_activations', []],
89+
['engram_dim', []],
90+
['mhc', []],
91+
['diloco', []],
3792
]

src/maxtext/configs/decoupled_base_test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ eval_dataset_name: 'c4/en:3.1.0'
3030
# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
3131
attention: "dot_product"
3232

33+
# Default to Linen mode for tests; NNX is opt-in via PURE_NNX=TRUE.
34+
pure_nnx: False
35+
pure_nnx_decoder: False
36+
3337
# Avoid HLO dump overhead.
3438
dump_hlo: false
3539
jax_cache_dir: ""

src/maxtext/layers/nnx_decoders.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -486,8 +486,16 @@ def pure_layer_fn(state_in, y_in):
486486
out = merged_layer(y_in, **kwargs)
487487
return out, nnx.state(merged_layer)
488488

489-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
490-
out, new_state = checkpointed_fn(state, y)
489+
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
490+
# mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
491+
# but the Linen scope retains JAX tracers from the first trace, causing
492+
# UnexpectedTracerError. Skip checkpoint for these quantization types.
493+
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
494+
if uses_linen_fp8_mutable_state:
495+
out, new_state = pure_layer_fn(state, y)
496+
else:
497+
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
498+
out, new_state = checkpointed_fn(state, y)
491499
nnx.update(layer, new_state)
492500

493501
return out
@@ -529,9 +537,24 @@ def layer_fn(carry, scanned_vars):
529537
# ONLY return non-param state to prevent memory duplication of weights
530538
return new_carry, new_current_state
531539

532-
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
533-
534-
final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state))
540+
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
541+
# mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
542+
# intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
543+
# causing UnexpectedTracerError. Use a Python for loop instead for these types.
544+
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
545+
if uses_linen_fp8_mutable_state:
546+
carry = x_in
547+
per_layer_states = []
548+
for i in range(length):
549+
current_params = jax.tree.map(lambda x, i=i: x[i], params)
550+
current_state = jax.tree.map(lambda x, i=i: x[i], state)
551+
carry, new_state_i = layer_fn(carry, (current_params, current_state))
552+
per_layer_states.append(new_state_i)
553+
final_carry = carry
554+
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
555+
else:
556+
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
557+
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
535558

536559
if scan_axis != 0:
537560
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params)

src/maxtext/layers/normalizations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
104104

105105
def Qwen3NextRMSNorm(
106106
num_features: int,
107-
epsilon: float,
108-
dtype: DType,
109-
weight_dtype: DType,
107+
epsilon: float = 1e-6,
108+
dtype: DType = jnp.float32,
109+
weight_dtype: DType = jnp.float32,
110110
shard_mode: ShardMode = ShardMode.AUTO,
111111
kernel_axes: tuple[None | str, ...] = (),
112112
parameter_memory_host_offload: bool = False,

src/maxtext/models/gpt_oss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.common.common_types import AttentionType, Config
2929
from maxtext.layers import attentions
3030
from maxtext.layers import initializers
31+
from maxtext.layers import linears
3132
from maxtext.layers import moe
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
@@ -130,6 +131,8 @@ def __init__(
130131
rngs=rngs,
131132
)
132133

134+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
135+
133136
def __call__(
134137
self,
135138
inputs,
@@ -181,7 +184,7 @@ def __call__(
181184
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
182185

183186
layer_output = mlp_lnx + intermediate_inputs
184-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
187+
layer_output = self.dropout(layer_output, deterministic=deterministic)
185188

186189
layer_output = nn.with_logical_constraint(
187190
layer_output,

src/maxtext/models/llama2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
shard_mode=config.shard_mode,
7171
kernel_axes=("norm",),
7272
epsilon=config.normalization_layer_epsilon,
73+
parameter_memory_host_offload=config.parameter_memory_host_offload,
7374
rngs=rngs,
7475
)
7576

src/maxtext/models/olmo3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from maxtext.common.common_types import AttentionType, Config
3030
from maxtext.layers import attentions
3131
from maxtext.layers import initializers
32+
from maxtext.layers import linears
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
3435
from maxtext.layers.attentions import Attention
@@ -140,6 +141,8 @@ def __init__(
140141
rngs=rngs,
141142
)
142143

144+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
145+
143146
def __call__(
144147
self,
145148
inputs,
@@ -193,7 +196,7 @@ def __call__(
193196
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
194197

195198
layer_output = mlp_lnx + intermediate_inputs
196-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
199+
layer_output = self.dropout(layer_output, deterministic=deterministic)
197200

198201
layer_output = nn.with_logical_constraint(
199202
layer_output,

0 commit comments

Comments
 (0)