Skip to content

Commit 695aecb

Browse files
committed
[WIP] NNX: fix model and test compatibility issues
- Replace nn.Dropout with linears.Dropout in the gpt_oss and olmo3 decoder layers. - Add a num_activations logical axis rule to base.yml. - Fix integration and unit tests for NNX compatibility. I will relocate these files accordingly once the work is done.
1 parent 53f8304 commit 695aecb

33 files changed

Lines changed: 3745 additions & 51 deletions

src/maxtext/common/gcloud_stub.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def is_decoupled() -> bool: # dynamic check so setting env after initial import
4343
return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"
4444

4545

46+
def is_pure_nnx() -> bool: # dynamic check so setting env after initial import still works
47+
"""Return True when running in pure NNX mode (PURE_NNX=TRUE env var).
48+
49+
Defaults to FALSE — Linen is the default test mode.
50+
Set PURE_NNX=TRUE to opt in to NNX mode (skips linen_only tests, runs nnx_only tests).
51+
"""
52+
return os.environ.get("PURE_NNX", "FALSE").upper() == "TRUE"
53+
54+
4655
T = TypeVar("T")
4756

4857

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ logical_axis_rules: [
514514
['paged_kv_head_dim_size', []],
515515
['dense_layers', []],
516516
['moe_layers', []],
517+
['num_activations', []],
517518
['engram_dim', ['tensor']],
518519
['mhc', []],
519520
['diloco', 'diloco'],

src/maxtext/configs/decoupled_base_test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ eval_dataset_name: 'c4/en:3.1.0'
3030
# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
3131
attention: "dot_product"
3232

33+
# Default to Linen mode for tests; NNX is opt-in via PURE_NNX=TRUE.
34+
pure_nnx: False
35+
pure_nnx_decoder: False
36+
3337
# Avoid HLO dump overhead.
3438
dump_hlo: false
3539
jax_cache_dir: ""

src/maxtext/layers/nnx_decoders.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,16 @@ def pure_layer_fn(state_in, y_in):
432432
out = merged_layer(y_in, **kwargs)
433433
return out, nnx.state(merged_layer)
434434

435-
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
436-
out, new_state = checkpointed_fn(state, y)
435+
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
436+
# mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
437+
# but the Linen scope retains JAX tracers from the first trace, causing
438+
# UnexpectedTracerError. Skip checkpoint for these quantization types.
439+
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
440+
if uses_linen_fp8_mutable_state:
441+
out, new_state = pure_layer_fn(state, y)
442+
else:
443+
checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
444+
out, new_state = checkpointed_fn(state, y)
437445
nnx.update(layer, new_state)
438446

439447
return out
@@ -475,20 +483,29 @@ def layer_fn(carry, scanned_vars):
475483

476484
return new_carry, new_current_state
477485

478-
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
479-
480-
final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state))
486+
# Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
487+
# mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
488+
# intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
489+
# causing UnexpectedTracerError. Use a Python for loop instead for these types.
490+
uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
491+
if uses_linen_fp8_mutable_state:
492+
carry = x_in
493+
per_layer_states = []
494+
for i in range(length):
495+
current_params = jax.tree.map(lambda x, i=i: x[i], params)
496+
current_state = jax.tree.map(lambda x, i=i: x[i], state)
497+
carry, new_state_i = layer_fn(carry, (current_params, current_state))
498+
per_layer_states.append(new_state_i)
499+
final_carry = carry
500+
scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
501+
else:
502+
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
503+
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
481504

482505
if scan_axis != 0:
483506
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params)
484507

485-
# Params are read-only during the forward pass, so the scan output's copy of
486-
# params is at axis=0 (lax.scan default) rather than scan_axis. Discard the
487-
# scan-output params and keep the original params (correctly positioned at
488-
# scan_axis) to avoid a shape mismatch when _apply_scanned_chunk tries to
489-
# write them back via dynamic_update_slice_in_dim.
490-
_, non_param_scanned_state = scanned_state.split(nnx.Param, ...)
491-
scanned_state = nnx.State.merge(params, non_param_scanned_state)
508+
scanned_state = nnx.State.merge(params, scanned_state)
492509
return final_carry, nnx.merge(graphdef, scanned_state)
493510

494511
def get_decoder_layers(self):

src/maxtext/layers/normalizations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
104104

105105
def Qwen3NextRMSNorm(
106106
num_features: int,
107-
epsilon: float,
108-
dtype: DType,
109-
weight_dtype: DType,
107+
epsilon: float = 1e-6,
108+
dtype: DType = jnp.float32,
109+
weight_dtype: DType = jnp.float32,
110110
shard_mode: ShardMode = ShardMode.AUTO,
111111
kernel_axes: tuple[None | str, ...] = (),
112112
parameter_memory_host_offload: bool = False,

src/maxtext/models/gpt_oss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.common.common_types import AttentionType, Config
2929
from maxtext.layers import attentions
3030
from maxtext.layers import initializers
31+
from maxtext.layers import linears
3132
from maxtext.layers import moe
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
@@ -130,6 +131,8 @@ def __init__(
130131
rngs=rngs,
131132
)
132133

134+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
135+
133136
def __call__(
134137
self,
135138
inputs,
@@ -181,7 +184,7 @@ def __call__(
181184
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
182185

183186
layer_output = mlp_lnx + intermediate_inputs
184-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
187+
layer_output = self.dropout(layer_output, deterministic=deterministic)
185188

186189
layer_output = nn.with_logical_constraint(
187190
layer_output,

src/maxtext/models/llama2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
shard_mode=config.shard_mode,
7171
kernel_axes=("norm",),
7272
epsilon=config.normalization_layer_epsilon,
73+
parameter_memory_host_offload=config.parameter_memory_host_offload,
7374
rngs=rngs,
7475
)
7576

src/maxtext/models/olmo3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from maxtext.common.common_types import AttentionType, Config
3030
from maxtext.layers import attentions
3131
from maxtext.layers import initializers
32+
from maxtext.layers import linears
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
3435
from maxtext.layers.attentions import Attention
@@ -140,6 +141,8 @@ def __init__(
140141
rngs=rngs,
141142
)
142143

144+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
145+
143146
def __call__(
144147
self,
145148
inputs,
@@ -193,7 +196,7 @@ def __call__(
193196
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
194197

195198
layer_output = mlp_lnx + intermediate_inputs
196-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
199+
layer_output = self.dropout(layer_output, deterministic=deterministic)
197200

198201
layer_output = nn.with_logical_constraint(
199202
layer_output,

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@
4747

4848
from orbax import checkpoint as ocp
4949

50-
from tunix.sft import metrics_logger, peft_trainer, profiler
51-
5250
from maxtext.configs import pyconfig
5351
from maxtext.trainers.pre_train.train import loss_fn
5452
from maxtext.common.goodput import (
@@ -77,6 +75,8 @@ def get_tunix_config(mt_config):
7775
Returns:
7876
A Tunix `TrainingConfig` object.
7977
"""
78+
from tunix.sft import metrics_logger, peft_trainer, profiler # pylint: disable=g-import-not-at-top,import-outside-toplevel
79+
8080
# Checkpointing configurations
8181
checkpointing_options = ocp.CheckpointManagerOptions(
8282
save_interval_steps=mt_config.checkpoint_period,
@@ -143,6 +143,8 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
143143

144144
def setup_trainer_state(mt_config, goodput_recorder=None):
145145
"""Set up prerequisites for training loop."""
146+
from tunix.sft import peft_trainer # pylint: disable=g-import-not-at-top,import-outside-toplevel
147+
146148
tunix_config = get_tunix_config(mt_config)
147149

148150
with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):

src/maxtext/trainers/pre_train/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,11 @@ def move(path, value):
497497
if config.use_dpo:
498498
new_state = _merge_dpo_state(new_state, reference_params)
499499
return new_state, metrics
500-
return nnx.state(new_state), metrics
500+
# Exclude Intermediate variables (e.g., sowed max_logits for QK-Clip) from the
501+
# returned state. Intermediates are transient forward-pass artifacts and must not
502+
# persist across steps: they're absent from the abstract state used to build
503+
# state_mesh_shardings, so including them would cause a leaf-count mismatch in JAX.
504+
return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics
501505

502506

503507
def eval_step(model, config, state, data, dropout_rng=None):

0 commit comments

Comments (0)