Skip to content

Commit 42463de

Browse files
committed
[WIP] NNX: fix model and test compatibility issues
- Replace nn.Dropout with linears.Dropout in gpt_oss and olmo3 decoder layers - Add num_activations logical axis rule to base.yml - Fix integration and unit tests for NNX compatibility I will relocate these files accordingly once the work is done.
1 parent 1aa8fc5 commit 42463de

35 files changed

Lines changed: 3758 additions & 49 deletions

src/maxtext/common/gcloud_stub.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def is_decoupled() -> bool: # dynamic check so setting env after initial import
4343
return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"
4444

4545

46+
def is_pure_nnx() -> bool: # dynamic check so setting env after initial import still works
47+
"""Return True when running in pure NNX mode (PURE_NNX=TRUE env var).
48+
49+
Defaults to FALSE — Linen is the default test mode.
50+
Set PURE_NNX=TRUE to opt in to NNX mode (skips linen_only tests, runs nnx_only tests).
51+
"""
52+
return os.environ.get("PURE_NNX", "FALSE").upper() == "TRUE"
53+
54+
4655
T = TypeVar("T")
4756

4857

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,7 @@ logical_axis_rules: [
515515
['paged_kv_head_dim_size', []],
516516
['dense_layers', []],
517517
['moe_layers', []],
518+
['num_activations', []],
518519
['engram_dim', ['tensor']],
519520
['mhc', []],
520521
['diloco', 'diloco'],

src/maxtext/configs/decoupled_base_test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ eval_dataset_name: 'c4/en:3.1.0'
3030
# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
3131
attention: "dot_product"
3232

33+
# Default to Linen mode for tests; NNX is opt-in via PURE_NNX=TRUE.
34+
pure_nnx: False
35+
pure_nnx_decoder: False
36+
3337
# Avoid HLO dump overhead.
3438
dump_hlo: false
3539
jax_cache_dir: ""

src/maxtext/layers/nnx_decoders.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,16 @@ def pure_layer_fn(state_in, y_in):
470470
out = merged_layer(y_in, **kwargs)
471471
return out, nnx.state(merged_layer)
472472

473-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
474-
out, new_state = checkpointed_fn(state, y)
473+
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
474+
# mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
475+
# but the Linen scope retains JAX tracers from the first trace, causing
476+
# UnexpectedTracerError. Skip checkpoint for these quantization types.
477+
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
478+
if uses_linen_fp8_mutable_state:
479+
out, new_state = pure_layer_fn(state, y)
480+
else:
481+
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
482+
out, new_state = checkpointed_fn(state, y)
475483
nnx.update(layer, new_state)
476484

477485
return out
@@ -513,9 +521,24 @@ def layer_fn(carry, scanned_vars):
513521

514522
return new_carry, new_current_state
515523

516-
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
517-
518-
final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state))
524+
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
525+
# mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
526+
# intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
527+
# causing UnexpectedTracerError. Use a Python for loop instead for these types.
528+
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
529+
if uses_linen_fp8_mutable_state:
530+
carry = x_in
531+
per_layer_states = []
532+
for i in range(length):
533+
current_params = jax.tree.map(lambda x, i=i: x[i], params)
534+
current_state = jax.tree.map(lambda x, i=i: x[i], state)
535+
carry, new_state_i = layer_fn(carry, (current_params, current_state))
536+
per_layer_states.append(new_state_i)
537+
final_carry = carry
538+
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
539+
else:
540+
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
541+
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
519542

520543
if scan_axis != 0:
521544
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params)
@@ -525,7 +548,7 @@ def layer_fn(carry, scanned_vars):
525548
# scan-output params and keep the original params (correctly positioned at
526549
# scan_axis) to avoid a shape mismatch when _apply_scanned_chunk tries to
527550
# write them back via dynamic_update_slice_in_dim.
528-
_, non_param_scanned_state = scanned_other.split(nnx.Param, ...)
551+
_, non_param_scanned_state = scanned_state.split(nnx.Param, ...)
529552
scanned_state = nnx.State.merge(params, non_param_scanned_state)
530553
return final_carry, nnx.merge(graphdef, scanned_state)
531554

src/maxtext/layers/normalizations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
104104

105105
def Qwen3NextRMSNorm(
106106
num_features: int,
107-
epsilon: float,
108-
dtype: DType,
109-
weight_dtype: DType,
107+
epsilon: float = 1e-6,
108+
dtype: DType = jnp.float32,
109+
weight_dtype: DType = jnp.float32,
110110
shard_mode: ShardMode = ShardMode.AUTO,
111111
kernel_axes: tuple[None | str, ...] = (),
112112
parameter_memory_host_offload: bool = False,

src/maxtext/models/gpt_oss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.common.common_types import AttentionType, Config
2929
from maxtext.layers import attentions
3030
from maxtext.layers import initializers
31+
from maxtext.layers import linears
3132
from maxtext.layers import moe
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
@@ -130,6 +131,8 @@ def __init__(
130131
rngs=rngs,
131132
)
132133

134+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
135+
133136
def __call__(
134137
self,
135138
inputs,
@@ -181,7 +184,7 @@ def __call__(
181184
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
182185

183186
layer_output = mlp_lnx + intermediate_inputs
184-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
187+
layer_output = self.dropout(layer_output, deterministic=deterministic)
185188

186189
layer_output = nn.with_logical_constraint(
187190
layer_output,

src/maxtext/models/llama2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
shard_mode=config.shard_mode,
7171
kernel_axes=("norm",),
7272
epsilon=config.normalization_layer_epsilon,
73+
parameter_memory_host_offload=config.parameter_memory_host_offload,
7374
rngs=rngs,
7475
)
7576

src/maxtext/models/olmo3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from maxtext.common.common_types import AttentionType, Config
3030
from maxtext.layers import attentions
3131
from maxtext.layers import initializers
32+
from maxtext.layers import linears
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
3435
from maxtext.layers.attentions import Attention
@@ -140,6 +141,8 @@ def __init__(
140141
rngs=rngs,
141142
)
142143

144+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
145+
143146
def __call__(
144147
self,
145148
inputs,
@@ -193,7 +196,7 @@ def __call__(
193196
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
194197

195198
layer_output = mlp_lnx + intermediate_inputs
196-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
199+
layer_output = self.dropout(layer_output, deterministic=deterministic)
197200

198201
layer_output = nn.with_logical_constraint(
199202
layer_output,

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@
4747

4848
from orbax import checkpoint as ocp
4949

50-
from tunix.sft import metrics_logger, peft_trainer, profiler
51-
5250
from maxtext.configs import pyconfig
5351
from maxtext.trainers.pre_train.train import loss_fn
5452
from maxtext.common.goodput import (
@@ -77,6 +75,8 @@ def get_tunix_config(mt_config):
7775
Returns:
7876
A Tunix `TrainingConfig` object.
7977
"""
78+
from tunix.sft import metrics_logger, peft_trainer, profiler # pylint: disable=g-import-not-at-top,import-outside-toplevel
79+
8080
# Checkpointing configurations
8181
checkpointing_options = ocp.CheckpointManagerOptions(
8282
save_interval_steps=mt_config.checkpoint_period,
@@ -143,6 +143,8 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
143143

144144
def setup_trainer_state(mt_config, goodput_recorder=None):
145145
"""Set up prerequisites for training loop."""
146+
from tunix.sft import peft_trainer # pylint: disable=g-import-not-at-top,import-outside-toplevel
147+
146148
tunix_config = get_tunix_config(mt_config)
147149

148150
with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):

src/maxtext/trainers/pre_train/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,11 @@ def move(path, value):
508508
if config.use_dpo:
509509
new_state = _merge_dpo_state(new_state, reference_params)
510510
return new_state, metrics
511-
return nnx.state(new_state), metrics
511+
# Exclude Intermediate variables (e.g., sowed max_logits for QK-Clip) from the
512+
# returned state. Intermediates are transient forward-pass artifacts and must not
513+
# persist across steps: they're absent from the abstract state used to build
514+
# state_mesh_shardings, so including them would cause a leaf-count mismatch in JAX.
515+
return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics
512516

513517

514518
def eval_step(model, config, state, data, dropout_rng=None):

0 commit comments

Comments
 (0)