Commit 2f30ac1

Migrate Decoder (Gemma3/Deepseek/Llama4) and utils to NNX
1 parent e67d913 commit 2f30ac1

6 files changed

Lines changed: 311 additions & 55 deletions

src/MaxText/layers/gemma3.py

Lines changed: 0 additions & 2 deletions
@@ -91,7 +91,6 @@ def __init__(
 
     batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
     dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
-
     self.pre_self_attention_norm = RMSNorm(
         num_features=config.emb_dim,
         dtype=config.dtype,
@@ -198,7 +197,6 @@ def __call__(
     inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
-
     lnx = self.pre_self_attention_norm(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
 
src/MaxText/layers/nnx_decoders.py

Lines changed: 99 additions & 37 deletions
@@ -29,18 +29,16 @@
 from flax import nnx
 from flax.nnx import wrappers as nnx_wrappers
 
-from MaxText.configs.types import PositionalEmbedding
 from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT
 from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 from MaxText.sharding import create_sharding
 from MaxText.layers import linears
 from MaxText.layers import initializers
 from MaxText.layers import quantizations
-from MaxText import multimodal_utils
 from MaxText import sharding
 from MaxText.layers.attentions import Attention
 from MaxText.layers.normalizations import RMSNorm
-from MaxText.layers.embeddings import Embed, attend_on_embedding
+from MaxText.layers.embeddings import Embed, attend_on_embedding, PositionalEmbedding
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.layers import (
     deepseek,
@@ -61,6 +59,7 @@
 from maxtext.inference import page_manager
 from maxtext.utils import max_logging
 from maxtext.utils import maxtext_utils
+from maxtext.multimodal import utils as mm_utils
 
 # ------------------------------------------------------------------------------
 # The network: Decoder Definitions
@@ -195,10 +194,10 @@ def __call__(
     layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names)
 
     if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
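Note: NNX replaces Linen's string-keyed collections with typed Variable classes, so sown metrics are retrieved by filtering on `nnx.Intermediate`. A minimal sketch of the pattern (toy module, not MaxText's decoder):

```python
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x):
    y = self.linear(x)
    # Sow into the typed nnx.Intermediate collection (was "intermediates" in Linen).
    self.sow(nnx.Intermediate, "activation_mean", jnp.mean(y))
    return y

block = Block(rngs=nnx.Rngs(0))
block(jnp.ones((2, 4)))
# Sown values live on the module; filter them out of its state by type.
print(nnx.state(block, nnx.Intermediate))
```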
@@ -284,19 +283,28 @@ def __init__(
       attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
       scan_length = config.num_decoder_layers // attention_pattern_length
       num_remaining_layers = config.num_decoder_layers % attention_pattern_length
+      layer_kwargs = {"num_of_layers": attention_pattern_length}
+
       rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
 
       RemattedGemma3Block = gemma3.Gemma3ScannableBlock
 
       if scan_length > 0:
-        self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs)
+        self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs)
         self.layers_remainder = RemattedGemma3Block(
             config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
         ) # pytype: disable=wrong-keyword-args
       else:
         layer_cls = decoder_block_classes[0]
-        num_layers = config.num_decoder_layers
-        self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs)
+        num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval)
+        layer_kwargs = {}
+        if config.decoder_block == DecoderBlockType.LLAMA4:
+          layer_kwargs = {
+              "nope_layer_interval": self.config.nope_layer_interval,
+              "interleave_moe_layer_step": self.config.interleave_moe_layer_step,
+          }
+
+        self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs)
     else:
       self.layers = nnx.List([])
       if self.is_deepseek:
@@ -309,6 +317,32 @@ def __init__(
         for i in range(config.num_decoder_layers):
           self._create_and_register_layer(layer_cls, rngs, "layers", i)
 
+      self.layers = nnx.List([])
+
+      if self.is_deepseek:
+        dense_cls, moe_cls = decoder_block_classes
+        for i in range(config.first_num_dense_layers):
+          self._create_and_register_layer(dense_cls, rngs, "dense_layer", i)
+        for i in range(config.num_decoder_layers - config.first_num_dense_layers):
+          self._create_and_register_layer(moe_cls, rngs, "moe_layer", i)
+      else:
+        layer_cls = decoder_block_classes[0]
+
+        for i in range(config.num_decoder_layers):
+          layer_kwargs = {}
+          if config.decoder_block == DecoderBlockType.GEMMA3:
+            layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=i)}
+          elif config.decoder_block == DecoderBlockType.LLAMA4:
+            layer_kwargs = {
+                "is_nope_layer": llama4.determine_is_nope_layer(i, self.config.nope_layer_interval),
+                "is_moe_layer": llama4.determine_is_moe_layer(i, self.config.interleave_moe_layer_step),
+            }
+          elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
+            layer_kwargs = {"layer_idx": i}
+          elif config.decoder_block == DecoderBlockType.GPT_OSS:
+            layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=i)}
+          self._create_and_register_layer(layer_cls, rngs, "layers", i, **layer_kwargs)
+
   def _create_and_register_layer(self, layer_cls, rngs, base_name, i):
     attr_name = f"{base_name}_{i}"
     layer = self._create_single_layer(layer_cls, rngs)
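Note: on the unscanned path each decoder family receives different per-layer construction kwargs. A purely illustrative sketch of that dispatch (the helper, family strings, and interval convention below are hypothetical, not MaxText's helpers):

```python
# Illustrative only (not MaxText's helpers): pick per-layer construction
# kwargs from the decoder family and layer index, as the branching above does.
def layer_kwargs_for(decoder_block: str, layer_id: int, nope_layer_interval: int = 4) -> dict:
  if decoder_block == "qwen3_next":
    return {"layer_idx": layer_id}
  if decoder_block == "llama4":
    # Assumed convention: every nope_layer_interval-th layer skips RoPE.
    return {"is_nope_layer": (layer_id + 1) % nope_layer_interval == 0}
  return {}

for i in range(8):
  print(i, layer_kwargs_for("llama4", i))
```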
@@ -346,12 +380,16 @@ def create_layer_fn(rng):
       except: # pylint: disable=bare-except
         pass
 
+    out_axes = nnx.StateAxes({
+        nnx.Param: self.config.param_scan_axis,
+        ...: 0
+    })
     layers_vmapped = nnx.vmap(
-        create_layer_fn,
-        in_axes=0,
-        out_axes=0,
-        axis_name="layers",
-        transform_metadata={nnx.PARTITION_NAME: "layers"},
+        create_layer_fn,
+        in_axes=0,
+        out_axes=out_axes,
+        axis_name="layers",
+        transform_metadata={nnx.PARTITION_NAME: "layers"},
     )(forked_rngs)
 
     return layers_vmapped
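Note: the new `nnx.StateAxes` mapping stacks parameters along `config.param_scan_axis` while all other per-layer state stays on axis 0. A self-contained sketch of the creation pattern, assuming a toy `Layer` and `param_scan_axis=1` (the real code also threads `axis_name` and partition metadata, omitted here):

```python
import jax
from flax import nnx

param_scan_axis = 1  # assumption: mirrors config.param_scan_axis

class Layer(nnx.Module):  # toy stand-in for a decoder block
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(8, 8, rngs=rngs)

@nnx.split_rngs(splits=4)
@nnx.vmap(in_axes=(0,), out_axes=nnx.StateAxes({nnx.Param: param_scan_axis, ...: 0}))
def create_layers(rngs: nnx.Rngs) -> Layer:
  return Layer(rngs)

layers = create_layers(nnx.Rngs(0))
# Each kernel is (8, 8); four stacked layers give (8, 4, 8), i.e. the
# layer axis sits at param_scan_axis=1 instead of axis 0.
print(jax.tree.map(lambda x: x.shape, nnx.state(layers, nnx.Param)))
```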
@@ -364,9 +402,17 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
         layers, nnx.Param, ...
     ) # state: the mutable state we carry (KV cache, RNGs, etc.)
 
-    layer_cls = layers.__class__ # Access the underlying class
+    scan_axis = self.config.param_scan_axis
+    if scan_axis != 0:
+      # Move scan_axis to 0 so scan can iterate over it
+      params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
+
+    layer_cls = layers.__class__
     sig = inspect.signature(layer_cls.__call__)
+    valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
 
+    layer_cls = layers.__class__ # Access the underlying class
+    sig = inspect.signature(layer_cls.__call__)
     # Filter kwargs to only include keys that exist in the layer's signature
     valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
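Note: the `valid_kwargs` comprehension keeps only keywords that a block's `__call__` declares (or catches via `**kwargs`), so heterogeneous decoder blocks can share one call site. In isolation, with a hypothetical `Layer`:

```python
import inspect

class Layer:  # hypothetical decoder block
  def __call__(self, x, *, bidirectional_mask=None):
    return x

kwargs = {"bidirectional_mask": None, "layer_idx": 3}
sig = inspect.signature(Layer.__call__)
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
print(valid_kwargs)  # {'bidirectional_mask': None}; 'layer_idx' is dropped
```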

@@ -391,6 +437,11 @@ def layer_fn(carry, scanned_vars):
 
     final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
 
+    if scan_axis != 0:
+      scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
+      scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
+      scanned_state = nnx.State.merge(scanned_params, scanned_other)
+
     return final_carry, nnx.merge(graphdef, scanned_state)
 
   def get_decoder_layers(self):
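Note: the two hunks above bracket `jax.lax.scan` with a transpose. Stacked parameters are moved so the layer axis leads, scanned over, then moved back before being merged into the module. A minimal pure-JAX sketch with toy shapes:

```python
import jax
import jax.numpy as jnp

param_scan_axis = 1
# Toy stacked weights: four (8, 8) kernels stored as (8, 4, 8).
stacked = {"kernel": jnp.full((8, 4, 8), 0.1)}

# Move the layer axis to the front so lax.scan can iterate over it.
params = jax.tree.map(lambda x: jnp.moveaxis(x, param_scan_axis, 0), stacked)

def layer_fn(carry, layer_params):
  # One "decoder layer": apply this layer's slice of the stacked weights.
  return carry @ layer_params["kernel"], None

y, _ = jax.lax.scan(layer_fn, jnp.ones((2, 8)), params)

# Restore the original layout before merging state back into the module.
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, param_scan_axis), params)
print(y.shape, params["kernel"].shape)  # (2, 8) (8, 4, 8)
```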
@@ -584,7 +635,7 @@ def _apply_embedding(
         "llama4-17b-128e",
         "qwen3-omni-30b-a3b",
     ]:
-      y = multimodal_utils.merge_mm_embeddings(
+      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=image_embeddings,
          mask=bidirectional_mask,
@@ -596,7 +647,7 @@ def _apply_embedding(
 
     if audio_embeddings is not None and cfg.use_audio:
       if cfg.model_name in ["qwen3-omni-30b-a3b"]:
-        y = multimodal_utils.merge_mm_embeddings(
+        y = mm_utils.merge_mm_embeddings(
            text_embeddings=y,
            multimodal_embeddings=audio_embeddings,
            mask=audio_masks,
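Note: both call sites now go through `mm_utils.merge_mm_embeddings`. As a toy picture of merging along a boolean mask (not the actual `mm_utils` implementation):

```python
import jax.numpy as jnp

def toy_merge_mm_embeddings(text_embeddings, multimodal_embeddings, mask):
  # mask: (batch, seq), True where a multimodal embedding replaces text.
  return jnp.where(mask[..., None], multimodal_embeddings, text_embeddings)

text = jnp.zeros((1, 4, 2))
image = jnp.ones((1, 4, 2))
mask = jnp.array([[False, True, True, False]])
print(toy_merge_mm_embeddings(text, image, mask)[0, :, 0])  # [0. 1. 1. 0.]
```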
@@ -609,7 +660,7 @@ def _apply_embedding(
     y = y.astype(cfg.dtype)
 
     if cfg.use_untrainable_positional_embedding:
-      y = self.positional_embedding(y, decoder_positions)
+      y += self.positional_embedding(y, decoder_positions)
 
     if cfg.trainable_position_size > 0 and self.position_embedder:
       y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode)
@@ -625,7 +676,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
     else:
       norm_out_sharding = None
 
-    y = self.decoder_norm(y, norm_out_sharding)
+    y = self.decoder_norm(y, out_sharding=norm_out_sharding)
     y = self.dropout(y, deterministic=deterministic) # NNX call
 
     if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
@@ -693,19 +744,18 @@ def __call__(
         audio_masks,
     )
     layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
-
-    layer_kwargs = {
-        "previous_chunk": previous_chunk,
-        "page_state": page_state,
-        "slot": slot,
-        "attention_metadata": attention_metadata,
-    }
-
+
+    layer_kwargs = {}
     if cfg.decoder_block == DecoderBlockType.GEMMA3:
       layer_kwargs["bidirectional_mask"] = bidirectional_mask
 
     if cfg.scan_layers:
       if self.is_deepseek:
+        layer_kwargs = {
+            "previous_chunk": previous_chunk,
+            "page_state": page_state,
+            "slot": slot,
+        }
         y, self.dense_layers = self._apply_layers_sequentially(
             self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs
         )
@@ -733,8 +783,24 @@ def __call__(
     else:
       for i, layer in enumerate(self.layers):
         kv_cache = kv_caches[i] if kv_caches is not None else None
+
+        layer_call_kwargs = {}
+        if cfg.decoder_block == DecoderBlockType.GEMMA3:
+          layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}
 
-        out = layer(y, *layer_args, kv_cache=kv_cache, **layer_kwargs)
+        out = layer(
+            y,
+            decoder_segment_ids,
+            decoder_positions,
+            deterministic,
+            model_mode,
+            previous_chunk=previous_chunk,
+            page_state=page_state,
+            slot=slot,
+            kv_cache=kv_cache,
+            attention_metadata=attention_metadata,
+            **layer_call_kwargs
+        )
 
         if isinstance(out, tuple):
           y, kv_cache_out = out
@@ -775,17 +841,12 @@ def _apply_gemma3_scanned_blocks(
     attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
     scan_length = cfg.num_decoder_layers // attention_pattern_length
 
-    layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}
+    layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
+    layer_kwargs = {"bidirectional_mask": bidirectional_mask}
 
     # Apply the main scan over the full blocks
     if scan_length > 0:
-      broadcast_args = (
-          decoder_segment_ids,
-          decoder_positions,
-          deterministic,
-          model_mode,
-      )
-      y, _ = self.layers(y, *broadcast_args, **layer_call_kwargs)
+      y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
 
     # Apply any remaining layers that did not fit into a full scanned block
     num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length
@@ -800,8 +861,9 @@ def _apply_gemma3_scanned_blocks(
           previous_chunk=previous_chunk,
           page_state=page_state,
          slot=slot,
-          **layer_call_kwargs,
+          **layer_kwargs,
       )
+
     return y
 
src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 13 additions & 5 deletions
@@ -385,14 +385,22 @@ def _build_single_axis_stacked_tensor(
       The final, assembled NumPy array for the MaxText parameter.
   """
   tensors_to_stack = []
+  # Heuristic to determine if we are stacking layers or experts.
+  # If the number of items to stack equals the number of layers, it's a standard
+  # scanned layer, and we use the configured param_scan_axis. Otherwise, it's
+  # an unscanned MoE layer, and we stack along the expert axis (0).
+  """
+  axis_to_stack = config.param_scan_axis if len(hf_source_keys) == config.base_num_decoder_layers else 0
+  """
 
-  if config.scan_layers:
-    # If it's a standard scanned layer, we use the configured param_scan_axis.
-    axis_to_stack = config.param_scan_axis
+  # Workaround to load the HF model due to mismatched tensor ordering
+  if len(hf_source_keys) == config.base_num_decoder_layers:
+    if getattr(config, "enable_nnx", False):
+      axis_to_stack = 0
+    else:
+      axis_to_stack = config.param_scan_axis
   else:
-    # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
     axis_to_stack = 0
-
   # The hook function needs the shape of an individual slice, not the full stacked tensor.
   # We calculate it by removing the stacking dimension from the final target shape.
   mt_slice_shape_list = list(target_shape)
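Note: the replacement branch keys the stacking axis off the tensor count and the `enable_nnx` flag instead of `config.scan_layers`. A NumPy sketch of the resulting layouts (helper name and shapes are hypothetical):

```python
import numpy as np

param_scan_axis = 1  # assumption: the configured scan axis
base_num_decoder_layers = 4

def stack_for_maxtext(slices, enable_nnx=False):
  if len(slices) == base_num_decoder_layers:
    # NNX keeps the layer axis at 0; the Linen path uses param_scan_axis.
    axis = 0 if enable_nnx else param_scan_axis
  else:
    axis = 0  # unscanned MoE experts always stack on axis 0
  return np.stack(slices, axis=axis)

layer_slices = [np.zeros((16, 32))] * base_num_decoder_layers
expert_slices = [np.zeros((16, 32))] * 8

print(stack_for_maxtext(layer_slices).shape)                   # (16, 4, 32)
print(stack_for_maxtext(layer_slices, enable_nnx=True).shape)  # (4, 16, 32)
print(stack_for_maxtext(expert_slices).shape)                  # (8, 16, 32)
```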

src/maxtext/configs/base.yml

Lines changed: 3 additions & 3 deletions
@@ -697,7 +697,7 @@ autoregressive_decode_assert: ""
 
 # For nsys profiler, pass the training command to nsys command
 # e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command}
-profiler: "" # Supported profiler: '', xplane, nsys
+profiler: "xplane" # Supported profiler: '', xplane, nsys
 # If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host.
 upload_all_profiler_results: False
 # Skip first n steps for profiling, to omit things like compilation and to give
@@ -1049,8 +1049,8 @@ position_id_per_seconds: 25
 subslice_shape: ""
 
 # NNX
-enable_nnx: false
-pure_nnx_decoder: false
+enable_nnx: True
+pure_nnx_decoder: True
 
 ################################## Qwen3-Next Specific Configs ##################################
 # Kernel size for the 1D convolution in the Gated Delta Net

0 commit comments
