Skip to content

Commit 485776c

Browse files
hsuan-lun-chiang authored and ecnal-cienet committed
Migrate Decoder (Gemma3/Deepseek/Llama4) and utils to NNX
1 parent 37ded59 commit 485776c

12 files changed

Lines changed: 507 additions & 28 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1086,8 +1086,8 @@ position_id_per_seconds: 25
10861086
subslice_shape: ""
10871087

10881088
# NNX
1089-
enable_nnx: false
1090-
pure_nnx_decoder: false
1089+
enable_nnx: True
1090+
pure_nnx_decoder: True
10911091

10921092
################################## Qwen3-Next Specific Configs ##################################
10931093
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/layers/attentions.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -533,14 +533,14 @@ def __init__(
533533
elif self.is_qwen3_next:
534534
self.query_norm = Qwen3NextRMSNorm(
535535
num_features=self.config.head_dim,
536-
eps=self.config.normalization_layer_epsilon,
536+
epsilon=self.config.normalization_layer_epsilon,
537537
dtype=self.config.dtype,
538538
weight_dtype=self.config.weight_dtype,
539539
rngs=self.rngs,
540540
)
541541
self.key_norm = Qwen3NextRMSNorm(
542542
num_features=self.config.head_dim,
543-
eps=self.config.normalization_layer_epsilon,
543+
epsilon=self.config.normalization_layer_epsilon,
544544
dtype=self.config.dtype,
545545
weight_dtype=self.config.weight_dtype,
546546
rngs=self.rngs,

src/maxtext/layers/nnx_decoders.py

Lines changed: 45 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2026 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -829,10 +829,19 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
829829
def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwargs):
830830
"""Applies a single, unscanned Engram layer by dynamically slicing the NNX state."""
831831
graphdef, state = nnx.split(layer_stack)
832+
params, rest = state.split(nnx.Param, ...)
833+
scan_axis = self.config.param_scan_axis
834+
835+
# Helper to generate N-dimensional basic slices (e.g., x[:, idx, :])
836+
def _extract_slice(x, idx, axis):
837+
slices = tuple(idx if i == axis else slice(None) for i in range(x.ndim))
838+
return x[slices]
832839

833-
# Slice the parameters for the current index (assuming scan axis is 0)
834-
sliced_state = jax.tree.map(lambda x: x[current_idx], state)
835-
single_layer = nnx.merge(graphdef, sliced_state)
840+
# Slice using native indexing instead of jnp.take
841+
sliced_params = jax.tree.map(lambda x: _extract_slice(x, current_idx, scan_axis), params)
842+
sliced_rest = jax.tree.map(lambda x: _extract_slice(x, current_idx, 0), rest)
843+
844+
single_layer = nnx.merge(graphdef, sliced_params, sliced_rest)
836845

837846
# Run the single layer
838847
out = single_layer(
@@ -841,37 +850,57 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg
841850
y = out[0] if isinstance(out, tuple) else out
842851

843852
# Re-merge the updated state back into the specific slice of the stack
844-
new_single_state = nnx.state(single_layer)
845-
updated_state = jax.tree.map(
853+
new_state = nnx.state(single_layer)
854+
new_params, new_rest = new_state.split(nnx.Param, ...)
855+
856+
updated_params = jax.tree.map(
857+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(
858+
s, jnp.expand_dims(new_s, axis=scan_axis), current_idx, axis=scan_axis
859+
),
860+
params,
861+
new_params,
862+
)
863+
updated_rest = jax.tree.map(
846864
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0),
847-
state,
848-
new_single_state,
865+
rest,
866+
new_rest,
849867
)
850-
nnx.update(layer_stack, updated_state)
851868

869+
nnx.update(layer_stack, updated_params, updated_rest)
852870
return y
853871

854872
def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args, **kwargs):
855873
"""Applies a contiguous chunk of layers using scan over a state slice."""
856874
scan_length = next_boundary - current_idx
857875
if scan_length > 0:
858876
graphdef, state = nnx.split(layer_stack)
877+
params, rest = state.split(nnx.Param, ...)
878+
scan_axis = self.config.param_scan_axis
859879

860-
# Slice the chunk state
861-
chunk_state = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), state)
862-
chunk_stack = nnx.merge(graphdef, chunk_state)
880+
# Slice the chunk state along the correct axes
881+
chunk_params = jax.tree.map(
882+
lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis), params
883+
)
884+
chunk_rest = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), rest)
885+
chunk_stack = nnx.merge(graphdef, chunk_params, chunk_rest)
863886

864887
# Apply sequentially
865888
y, chunk_stack = self._apply_layers_sequentially(
866889
chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
867890
)
868891

869892
# Update the original stack state
870-
new_chunk_state = nnx.state(chunk_stack)
871-
updated_state = jax.tree.map(
872-
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), state, new_chunk_state
893+
new_state = nnx.state(chunk_stack)
894+
new_params, new_rest = new_state.split(nnx.Param, ...)
895+
896+
updated_params = jax.tree.map(
897+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis), params, new_params
873898
)
874-
nnx.update(layer_stack, updated_state)
899+
updated_rest = jax.tree.map(
900+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), rest, new_rest
901+
)
902+
903+
nnx.update(layer_stack, updated_params, updated_rest)
875904

876905
return y
877906

src/maxtext/layers/normalizations.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
102102
return y_flat.reshape(input_shape)
103103

104104

105-
def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
105+
def Qwen3NextRMSNorm(
106+
num_features: int,
107+
epsilon: float = 1e-6,
108+
dtype: DType = jnp.float32,
109+
weight_dtype: DType = jnp.float32,
110+
shard_mode: ShardMode = ShardMode.AUTO,
111+
kernel_axes: tuple[None | str, ...] = (),
112+
parameter_memory_host_offload: bool = False,
113+
*,
114+
rngs: nnx.Rngs,
115+
):
106116
"""
107117
Used for input and post attention layernorms
108118
in Qwen3NextDecoderLayer.
@@ -115,10 +125,13 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype:
115125
return nnx.data(
116126
RMSNorm(
117127
num_features=num_features,
118-
epsilon=eps,
128+
epsilon=epsilon,
119129
dtype=dtype,
120130
weight_dtype=weight_dtype,
131+
shard_mode=shard_mode,
132+
kernel_axes=kernel_axes,
121133
scale_init=linen_initializers.zeros,
134+
parameter_memory_host_offload=parameter_memory_host_offload,
122135
scale_offset=1.0,
123136
rngs=rngs,
124137
)

src/maxtext/models/gpt_oss.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from maxtext.common.common_types import AttentionType, Config
2929
from maxtext.layers import attentions
3030
from maxtext.layers import initializers
31+
from maxtext.layers import linears
3132
from maxtext.layers import moe
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
@@ -130,6 +131,8 @@ def __init__(
130131
rngs=rngs,
131132
)
132133

134+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
135+
133136
def __call__(
134137
self,
135138
inputs,
@@ -181,7 +184,7 @@ def __call__(
181184
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
182185

183186
layer_output = mlp_lnx + intermediate_inputs
184-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
187+
layer_output = self.dropout(layer_output, deterministic=deterministic)
185188

186189
layer_output = nn.with_logical_constraint(
187190
layer_output,

src/maxtext/models/llama2.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@ def __init__(
7070
shard_mode=config.shard_mode,
7171
kernel_axes=("norm",),
7272
epsilon=config.normalization_layer_epsilon,
73+
parameter_memory_host_offload=config.parameter_memory_host_offload,
7374
rngs=rngs,
7475
)
7576

src/maxtext/models/olmo3.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
from maxtext.common.common_types import AttentionType, Config
3030
from maxtext.layers import attentions
3131
from maxtext.layers import initializers
32+
from maxtext.layers import linears
3233
from maxtext.layers import nnx_wrappers
3334
from maxtext.layers import quantizations
3435
from maxtext.layers.attentions import Attention
@@ -140,6 +141,8 @@ def __init__(
140141
rngs=rngs,
141142
)
142143

144+
self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
145+
143146
def __call__(
144147
self,
145148
inputs,
@@ -193,7 +196,7 @@ def __call__(
193196
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
194197

195198
layer_output = mlp_lnx + intermediate_inputs
196-
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
199+
layer_output = self.dropout(layer_output, deterministic=deterministic)
197200

198201
layer_output = nn.with_logical_constraint(
199202
layer_output,

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -962,7 +962,7 @@ def __init__(
962962
# First LayerNorm, applied before the attention block.
963963
self.input_layernorm = Qwen3NextRMSNorm(
964964
num_features=cfg.emb_dim,
965-
eps=cfg.normalization_layer_epsilon,
965+
epsilon=cfg.normalization_layer_epsilon,
966966
dtype=cfg.dtype,
967967
weight_dtype=cfg.weight_dtype,
968968
rngs=rngs,
@@ -987,7 +987,7 @@ def __init__(
987987
# Second LayerNorm, applied before the MoE block.
988988
self.post_attention_layernorm = Qwen3NextRMSNorm(
989989
num_features=cfg.emb_dim,
990-
eps=cfg.normalization_layer_epsilon,
990+
epsilon=cfg.normalization_layer_epsilon,
991991
dtype=cfg.dtype,
992992
weight_dtype=cfg.weight_dtype,
993993
rngs=rngs,

0 commit comments

Comments
 (0)