Skip to content

Commit d9cbd74

Browse files
Charles Liecnal-cienet
authored and committed
Fix unit test errors
1 parent 37ded59 commit d9cbd74

7 files changed

Lines changed: 611 additions & 40 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,8 +1086,8 @@ position_id_per_seconds: 25
10861086
subslice_shape: ""
10871087

10881088
# NNX
1089-
enable_nnx: false
1090-
pure_nnx_decoder: false
1089+
enable_nnx: True
1090+
pure_nnx_decoder: True
10911091

10921092
################################## Qwen3-Next Specific Configs ##################################
10931093
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/layers/attentions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,14 +533,14 @@ def __init__(
533533
elif self.is_qwen3_next:
534534
self.query_norm = Qwen3NextRMSNorm(
535535
num_features=self.config.head_dim,
536-
eps=self.config.normalization_layer_epsilon,
536+
epsilon=self.config.normalization_layer_epsilon,
537537
dtype=self.config.dtype,
538538
weight_dtype=self.config.weight_dtype,
539539
rngs=self.rngs,
540540
)
541541
self.key_norm = Qwen3NextRMSNorm(
542542
num_features=self.config.head_dim,
543-
eps=self.config.normalization_layer_epsilon,
543+
epsilon=self.config.normalization_layer_epsilon,
544544
dtype=self.config.dtype,
545545
weight_dtype=self.config.weight_dtype,
546546
rngs=self.rngs,

src/maxtext/layers/nnx_decoders.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def __init__(
311311

312312
num_moe = config.num_decoder_layers - config.first_num_dense_layers
313313

314-
self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
314+
self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
315315
elif self.is_gemma3:
316316
attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
317317
scan_length = config.num_decoder_layers // attention_pattern_length
@@ -441,36 +441,27 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
441441

442442
scan_axis = self.config.param_scan_axis
443443
if scan_axis != 0:
444-
# Move scan_axis to 0 so scan can iterate over it
445444
params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
446445

447446
layer_cls = layers.__class__
448447
sig = inspect.signature(layer_cls.__call__)
449448
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
450449

451-
layer_cls = layers.__class__ # Access the underlying class
450+
layer_cls = layers.__class__
452451
sig = inspect.signature(layer_cls.__call__)
453-
# Filter kwargs to only include keys that exist in the layer's signature
454452
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
455453

456454
def layer_fn(carry, scanned_vars):
457-
# Unpack the sliced variables for THIS layer
458455
current_params, current_state = scanned_vars
459456

460457
if self.config.parameter_memory_host_offload:
461458
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
462459

463-
# Merge using the SLICED state
464460
layer = nnx.merge(graphdef, current_params, current_state)
465-
466-
# Run the layer (Filter kwargs if using the solution from previous turn)
467461
layer_out = layer(carry, *args, **valid_kwargs)
468-
469462
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
470-
471-
# Extract the updated state to return it
472-
# _, new_current_state = nnx.split(layer, nnx.Param, ...)
473463
new_current_state = nnx.state(layer)
464+
474465
return new_carry, new_current_state
475466

476467
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
@@ -829,10 +820,19 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices):
829820
def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwargs):
830821
"""Applies a single, unscanned Engram layer by dynamically slicing the NNX state."""
831822
graphdef, state = nnx.split(layer_stack)
823+
params, rest = state.split(nnx.Param, ...)
824+
scan_axis = self.config.param_scan_axis
825+
826+
# Helper to generate N-dimensional basic slices (e.g., x[:, idx, :])
827+
def _extract_slice(x, idx, axis):
828+
slices = tuple(idx if i == axis else slice(None) for i in range(x.ndim))
829+
return x[slices]
832830

833-
# Slice the parameters for the current index (assuming scan axis is 0)
834-
sliced_state = jax.tree.map(lambda x: x[current_idx], state)
835-
single_layer = nnx.merge(graphdef, sliced_state)
831+
# Slice using native indexing instead of jnp.take
832+
sliced_params = jax.tree.map(lambda x: _extract_slice(x, current_idx, scan_axis), params)
833+
sliced_rest = jax.tree.map(lambda x: _extract_slice(x, current_idx, 0), rest)
834+
835+
single_layer = nnx.merge(graphdef, sliced_params, sliced_rest)
836836

837837
# Run the single layer
838838
out = single_layer(
@@ -841,37 +841,57 @@ def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwarg
841841
y = out[0] if isinstance(out, tuple) else out
842842

843843
# Re-merge the updated state back into the specific slice of the stack
844-
new_single_state = nnx.state(single_layer)
845-
updated_state = jax.tree.map(
844+
new_state = nnx.state(single_layer)
845+
new_params, new_rest = new_state.split(nnx.Param, ...)
846+
847+
updated_params = jax.tree.map(
848+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(
849+
s, jnp.expand_dims(new_s, axis=scan_axis), current_idx, axis=scan_axis
850+
),
851+
params,
852+
new_params,
853+
)
854+
updated_rest = jax.tree.map(
846855
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0),
847-
state,
848-
new_single_state,
856+
rest,
857+
new_rest,
849858
)
850-
nnx.update(layer_stack, updated_state)
851859

860+
nnx.update(layer_stack, updated_params, updated_rest)
852861
return y
853862

854863
def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args, **kwargs):
855864
"""Applies a contiguous chunk of layers using scan over a state slice."""
856865
scan_length = next_boundary - current_idx
857866
if scan_length > 0:
858867
graphdef, state = nnx.split(layer_stack)
868+
params, rest = state.split(nnx.Param, ...)
869+
scan_axis = self.config.param_scan_axis
859870

860-
# Slice the chunk state
861-
chunk_state = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), state)
862-
chunk_stack = nnx.merge(graphdef, chunk_state)
871+
# Slice the chunk state along the correct axes
872+
chunk_params = jax.tree.map(
873+
lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis), params
874+
)
875+
chunk_rest = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), rest)
876+
chunk_stack = nnx.merge(graphdef, chunk_params, chunk_rest)
863877

864878
# Apply sequentially
865879
y, chunk_stack = self._apply_layers_sequentially(
866880
chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
867881
)
868882

869883
# Update the original stack state
870-
new_chunk_state = nnx.state(chunk_stack)
871-
updated_state = jax.tree.map(
872-
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), state, new_chunk_state
884+
new_state = nnx.state(chunk_stack)
885+
new_params, new_rest = new_state.split(nnx.Param, ...)
886+
887+
updated_params = jax.tree.map(
888+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis), params, new_params
889+
)
890+
updated_rest = jax.tree.map(
891+
lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), rest, new_rest
873892
)
874-
nnx.update(layer_stack, updated_state)
893+
894+
nnx.update(layer_stack, updated_params, updated_rest)
875895

876896
return y
877897

@@ -961,7 +981,7 @@ def __call__(
961981

962982
y = self._apply_interleaved_scanned_layers(
963983
y,
964-
self.moe_layer,
984+
self.moe_layers,
965985
0,
966986
(cfg.num_decoder_layers - cfg.first_num_dense_layers),
967987
[e - cfg.first_num_dense_layers for e in cfg.engram_layers],
@@ -978,7 +998,7 @@ def __call__(
978998
if cfg.use_batch_split_schedule:
979999
policy = self.get_remat_policy()
9801000

981-
mock_params = self._build_linen_params(self.moe_layer)
1001+
mock_params = self._build_linen_params(self.moe_layers)
9821002

9831003
y = deepseek_batchsplit.scan_batch_split_layers(
9841004
y,
@@ -992,8 +1012,8 @@ def __call__(
9921012
policy=policy,
9931013
)
9941014
else:
995-
y, self.moe_layer = self._apply_layers_sequentially(
996-
self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs
1015+
y, self.moe_layers = self._apply_layers_sequentially(
1016+
self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs
9971017
)
9981018
elif self.is_gemma3:
9991019
y = self._apply_gemma3_scanned_blocks(

src/maxtext/layers/normalizations.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
102102
return y_flat.reshape(input_shape)
103103

104104

105-
def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
105+
def Qwen3NextRMSNorm(
106+
num_features: int,
107+
epsilon: float,
108+
dtype: DType,
109+
weight_dtype: DType,
110+
shard_mode: ShardMode = ShardMode.AUTO,
111+
kernel_axes: tuple[None | str, ...] = (),
112+
parameter_memory_host_offload: bool = False,
113+
*,
114+
rngs: nnx.Rngs,
115+
):
106116
"""
107117
Used for input and post attention layernorms
108118
in Qwen3NextDecoderLayer.
@@ -115,10 +125,13 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype:
115125
return nnx.data(
116126
RMSNorm(
117127
num_features=num_features,
118-
epsilon=eps,
128+
epsilon=epsilon,
119129
dtype=dtype,
120130
weight_dtype=weight_dtype,
131+
shard_mode=shard_mode,
132+
kernel_axes=kernel_axes,
121133
scale_init=linen_initializers.zeros,
134+
parameter_memory_host_offload=parameter_memory_host_offload,
122135
scale_offset=1.0,
123136
rngs=rngs,
124137
)

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def __init__(
962962
# First LayerNorm, applied before the attention block.
963963
self.input_layernorm = Qwen3NextRMSNorm(
964964
num_features=cfg.emb_dim,
965-
eps=cfg.normalization_layer_epsilon,
965+
epsilon=cfg.normalization_layer_epsilon,
966966
dtype=cfg.dtype,
967967
weight_dtype=cfg.weight_dtype,
968968
rngs=rngs,
@@ -987,7 +987,7 @@ def __init__(
987987
# Second LayerNorm, applied before the MoE block.
988988
self.post_attention_layernorm = Qwen3NextRMSNorm(
989989
num_features=cfg.emb_dim,
990-
eps=cfg.normalization_layer_epsilon,
990+
epsilon=cfg.normalization_layer_epsilon,
991991
dtype=cfg.dtype,
992992
weight_dtype=cfg.weight_dtype,
993993
rngs=rngs,

0 commit comments

Comments
 (0)