Commit f26e160

Migrate Decoder (Gemma3/Deepseek/Llama4) and utils to NNX

1 parent: e67d913
6 files changed: 285 additions & 47 deletions

src/MaxText/layers/gemma3.py
Lines changed: 0 additions & 2 deletions

@@ -91,7 +91,6 @@ def __init__(
 
     batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
     dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
-
     self.pre_self_attention_norm = RMSNorm(
         num_features=config.emb_dim,
         dtype=config.dtype,
@@ -198,7 +197,6 @@ def __call__(
       inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
-
     lnx = self.pre_self_attention_norm(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
 
src/MaxText/layers/nnx_decoders.py
Lines changed: 73 additions & 29 deletions

@@ -29,18 +29,16 @@
 from flax import nnx
 from flax.nnx import wrappers as nnx_wrappers
 
-from MaxText.configs.types import PositionalEmbedding
 from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT
 from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 from MaxText.sharding import create_sharding
 from MaxText.layers import linears
 from MaxText.layers import initializers
 from MaxText.layers import quantizations
-from MaxText import multimodal_utils
 from MaxText import sharding
 from MaxText.layers.attentions import Attention
 from MaxText.layers.normalizations import RMSNorm
-from MaxText.layers.embeddings import Embed, attend_on_embedding
+from MaxText.layers.embeddings import Embed, attend_on_embedding, PositionalEmbedding
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.layers import (
     deepseek,
@@ -61,6 +59,7 @@
 from maxtext.inference import page_manager
 from maxtext.utils import max_logging
 from maxtext.utils import maxtext_utils
+from maxtext.multimodal import utils as mm_utils
 
 # ------------------------------------------------------------------------------
 # The network: Decoder Definitions
@@ -284,19 +283,28 @@ def __init__(
         attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
         scan_length = config.num_decoder_layers // attention_pattern_length
         num_remaining_layers = config.num_decoder_layers % attention_pattern_length
+        layer_kwargs = {"num_of_layers": attention_pattern_length}
+
         rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
 
         RemattedGemma3Block = gemma3.Gemma3ScannableBlock
 
         if scan_length > 0:
-          self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs)
+          self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs)
           self.layers_remainder = RemattedGemma3Block(
               config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
           ) # pytype: disable=wrong-keyword-args
       else:
         layer_cls = decoder_block_classes[0]
-        num_layers = config.num_decoder_layers
-        self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs)
+        num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval)
+        layer_kwargs = {}
+        if config.decoder_block == DecoderBlockType.LLAMA4:
+          layer_kwargs = {
+              "nope_layer_interval": self.config.nope_layer_interval,
+              "interleave_moe_layer_step": self.config.interleave_moe_layer_step,
+          }
+
+        self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs)
     else:
       self.layers = nnx.List([])
       if self.is_deepseek:
@@ -309,6 +317,32 @@ def __init__(
         for i in range(config.num_decoder_layers):
           self._create_and_register_layer(layer_cls, rngs, "layers", i)
 
+      self.layers = nnx.List([])
+
+      if self.is_deepseek:
+        dense_cls, moe_cls = decoder_block_classes
+        for i in range(config.first_num_dense_layers):
+          self._create_and_register_layer(dense_cls, rngs, "dense_layer", i)
+        for i in range(config.num_decoder_layers - config.first_num_dense_layers):
+          self._create_and_register_layer(moe_cls, rngs, "moe_layer", i)
+      else:
+        layer_cls = decoder_block_classes[0]
+
+        for i in range(config.num_decoder_layers):
+          layer_kwargs = {}
+          if config.decoder_block == DecoderBlockType.GEMMA3:
+            layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=i)}
+          elif config.decoder_block == DecoderBlockType.LLAMA4:
+            layer_kwargs = {
+                "is_nope_layer": llama4.determine_is_nope_layer(i, self.config.nope_layer_interval),
+                "is_moe_layer": llama4.determine_is_moe_layer(i, self.config.interleave_moe_layer_step),
+            }
+          elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
+            layer_kwargs = {"layer_idx": i}
+          elif config.decoder_block == DecoderBlockType.GPT_OSS:
+            layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=i)}
+          self._create_and_register_layer(layer_cls, rngs, "layers", i, **layer_kwargs)
+
   def _create_and_register_layer(self, layer_cls, rngs, base_name, i):
     attr_name = f"{base_name}_{i}"
     layer = self._create_single_layer(layer_cls, rngs)
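
Note on the unscanned branch above: each layer index gets block-specific constructor kwargs. The snippet below is a minimal, self-contained sketch of that dispatch pattern; the BlockType enum and make_layer_kwargs helper are hypothetical stand-ins for MaxText's DecoderBlockType and layer constructors, and the interleaving arithmetic is illustrative only.

from enum import Enum

class BlockType(Enum):
  # Simplified stand-in for MaxText's DecoderBlockType (illustrative only).
  GEMMA3 = "gemma3"
  LLAMA4 = "llama4"
  QWEN3_NEXT = "qwen3_next"
  GPT_OSS = "gpt_oss"

def make_layer_kwargs(block, layer_id, nope_layer_interval=4, interleave_moe_layer_step=2):
  """Return per-layer constructor kwargs keyed on the decoder block type (sketch only)."""
  if block is BlockType.LLAMA4:
    # Hypothetical interleaving rule: every Nth layer is a NoPE / MoE layer.
    return {
        "is_nope_layer": (layer_id + 1) % nope_layer_interval == 0,
        "is_moe_layer": (layer_id + 1) % interleave_moe_layer_step == 0,
    }
  if block is BlockType.QWEN3_NEXT:
    return {"layer_idx": layer_id}
  # GEMMA3 / GPT_OSS derive an attention type from layer_id; elided here.
  return {}

# Example: constructor kwargs for the first four Llama4-style layers.
print([make_layer_kwargs(BlockType.LLAMA4, i) for i in range(4)])
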
@@ -366,7 +400,6 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
 
     layer_cls = layers.__class__ # Access the underlying class
     sig = inspect.signature(layer_cls.__call__)
-
     # Filter kwargs to only include keys that exist in the layer's signature
     valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
 
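The valid_kwargs filter above keeps only the keyword arguments that the layer's __call__ actually declares (or can absorb through a catch-all parameter). Below is a self-contained sketch of the same idea with a toy layer class; it checks inspect.Parameter.VAR_KEYWORD instead of looking for a parameter literally named kwargs, which is a slightly stricter variant of the check in the diff.

import inspect

class ToyLayer:
  def __call__(self, x, *, deterministic=True, bidirectional_mask=None):
    return x

def filter_call_kwargs(layer_cls, kwargs):
  """Drop kwargs that layer_cls.__call__ does not accept, unless it takes **kwargs."""
  sig = inspect.signature(layer_cls.__call__)
  has_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
  return {k: v for k, v in kwargs.items() if k in sig.parameters or has_var_kwargs}

candidate = {"deterministic": False, "page_state": None, "bidirectional_mask": "mask"}
print(filter_call_kwargs(ToyLayer, candidate))
# -> {'deterministic': False, 'bidirectional_mask': 'mask'}
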
@@ -584,7 +617,7 @@ def _apply_embedding(
         "llama4-17b-128e",
         "qwen3-omni-30b-a3b",
     ]:
-      y = multimodal_utils.merge_mm_embeddings(
+      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=image_embeddings,
          mask=bidirectional_mask,
@@ -596,7 +629,7 @@ def _apply_embedding(
 
    if audio_embeddings is not None and cfg.use_audio:
      if cfg.model_name in ["qwen3-omni-30b-a3b"]:
-        y = multimodal_utils.merge_mm_embeddings(
+        y = mm_utils.merge_mm_embeddings(
            text_embeddings=y,
            multimodal_embeddings=audio_embeddings,
            mask=audio_masks,
@@ -609,7 +642,7 @@ def _apply_embedding(
    y = y.astype(cfg.dtype)
 
    if cfg.use_untrainable_positional_embedding:
-      y = self.positional_embedding(y, decoder_positions)
+      y += self.positional_embedding(y, decoder_positions)
 
    if cfg.trainable_position_size > 0 and self.position_embedder:
      y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode)
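
The change above switches the untrainable positional embedding from overwriting the activations to being added onto them, matching the additive convention of the trainable position_embedder on the following line. A hedged sketch of an additive fixed sin/cos table in jax.numpy (toy shapes and the textbook formula; not MaxText's PositionalEmbedding module):

import jax.numpy as jnp

def sinusoidal_table(positions, emb_dim):
  """Fixed sin/cos table: positions [B, T] -> embeddings [B, T, emb_dim]."""
  half = emb_dim // 2
  freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(half) / half)  # [half]
  angles = positions[..., None] * freqs                         # [B, T, half]
  return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)

batch, seq, emb_dim = 2, 8, 16
y = jnp.zeros((batch, seq, emb_dim))
decoder_positions = jnp.broadcast_to(jnp.arange(seq), (batch, seq))

# Additive application, mirroring `y += self.positional_embedding(...)` above.
y = y + sinusoidal_table(decoder_positions, emb_dim)
print(y.shape)  # (2, 8, 16)
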
@@ -625,7 +658,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
    else:
      norm_out_sharding = None
 
-    y = self.decoder_norm(y, norm_out_sharding)
+    y = self.decoder_norm(y, out_sharding=norm_out_sharding)
    y = self.dropout(y, deterministic=deterministic) # NNX call
 
    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
@@ -693,19 +726,18 @@ def __call__(
          audio_masks,
      )
    layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
-
-    layer_kwargs = {
-        "previous_chunk": previous_chunk,
-        "page_state": page_state,
-        "slot": slot,
-        "attention_metadata": attention_metadata,
-    }
-
+
+    layer_kwargs = {}
    if cfg.decoder_block == DecoderBlockType.GEMMA3:
      layer_kwargs["bidirectional_mask"] = bidirectional_mask
 
    if cfg.scan_layers:
      if self.is_deepseek:
+        layer_kwargs = {
+            "previous_chunk": previous_chunk,
+            "page_state": page_state,
+            "slot": slot,
+        }
        y, self.dense_layers = self._apply_layers_sequentially(
            self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs
        )
@@ -733,8 +765,24 @@ def __call__(
    else:
      for i, layer in enumerate(self.layers):
        kv_cache = kv_caches[i] if kv_caches is not None else None
+
+        layer_call_kwargs = {}
+        if cfg.decoder_block == DecoderBlockType.GEMMA3:
+          layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}
 
-        out = layer(y, *layer_args, kv_cache=kv_cache, **layer_kwargs)
+        out = layer(
+            y,
+            decoder_segment_ids,
+            decoder_positions,
+            deterministic,
+            model_mode,
+            previous_chunk=previous_chunk,
+            page_state=page_state,
+            slot=slot,
+            kv_cache=kv_cache,
+            attention_metadata=attention_metadata,
+            **layer_call_kwargs
+        )
 
        if isinstance(out, tuple):
          y, kv_cache_out = out
@@ -775,17 +823,12 @@ def _apply_gemma3_scanned_blocks(
    attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
    scan_length = cfg.num_decoder_layers // attention_pattern_length
 
-    layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}
+    layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
+    layer_kwargs = {"bidirectional_mask": bidirectional_mask}
 
    # Apply the main scan over the full blocks
    if scan_length > 0:
-      broadcast_args = (
-          decoder_segment_ids,
-          decoder_positions,
-          deterministic,
-          model_mode,
-      )
-      y, _ = self.layers(y, *broadcast_args, **layer_call_kwargs)
+      y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
 
    # Apply any remaining layers that did not fit into a full scanned block
    num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length
@@ -800,8 +843,9 @@
          previous_chunk=previous_chunk,
          page_state=page_state,
          slot=slot,
-          **layer_call_kwargs,
+          **layer_kwargs,
      )
+
    return y
 

src/MaxText/utils/ckpt_conversion/to_maxtext.py
Lines changed: 13 additions & 5 deletions

@@ -385,14 +385,22 @@ def _build_single_axis_stacked_tensor(
     The final, assembled NumPy array for the MaxText parameter.
   """
   tensors_to_stack = []
+  # Heuristic to determine if we are stacking layers or experts.
+  # If the number of items to stack equals the number of layers, it's a standard
+  # scanned layer, and we use the configured param_scan_axis. Otherwise, it's
+  # an unscanned MoE layer, and we stack along the expert axis (0).
+  """
+  axis_to_stack = config.param_scan_axis if len(hf_source_keys) == config.base_num_decoder_layers else 0
+  """
 
-  if config.scan_layers:
-    # If it's a standard scanned layer, we use the configured param_scan_axis.
-    axis_to_stack = config.param_scan_axis
+  # Workaround to load the HF model due to mismatched tensor ordering
+  if len(hf_source_keys) == config.base_num_decoder_layers:
+    if getattr(config, "enable_nnx", False):
+      axis_to_stack = 0
+    else:
+      axis_to_stack = config.param_scan_axis
   else:
-    # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
     axis_to_stack = 0
-
   # The hook function needs the shape of an individual slice, not the full stacked tensor.
   # We calculate it by removing the stacking dimension from the final target shape.
   mt_slice_shape_list = list(target_shape)
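
A small numpy sketch of the two stacking layouts the branch above chooses between: per-layer slices stacked along param_scan_axis for the legacy scanned-layer layout, versus stacking along axis 0 on the NNX path or for unscanned MoE experts (toy shapes; stack_slices is a hypothetical helper, not part of the conversion script):

import numpy as np

def stack_slices(slices, axis_to_stack):
  """Stack per-layer (or per-expert) source tensors into one target parameter."""
  return np.stack(slices, axis=axis_to_stack)

num_layers, emb, mlp = 4, 8, 32
layer_slices = [np.zeros((emb, mlp)) for _ in range(num_layers)]

# Legacy scanned-layer layout: stack along param_scan_axis (1 by default).
print(stack_slices(layer_slices, axis_to_stack=1).shape)  # (8, 4, 32)

# NNX path, or unscanned MoE experts: stack along axis 0.
print(stack_slices(layer_slices, axis_to_stack=0).shape)  # (4, 8, 32)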

src/maxtext/configs/base.yml
Lines changed: 3 additions & 3 deletions

@@ -315,7 +315,7 @@ attention_out: 'remat'
 
 optimizer_memory_host_offload: False
 parameter_memory_host_offload: False
-scan_layers: True # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations.
+scan_layers: False # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations.
 param_scan_axis: 1
 
 # The attention parameter dictates the specific algorithm/methodology used to compute the attention scores
@@ -1049,8 +1049,8 @@ position_id_per_seconds: 25
 subslice_shape: ""
 
 # NNX
-enable_nnx: false
-pure_nnx_decoder: false
+enable_nnx: True
+pure_nnx_decoder: True
 
 ################################## Qwen3-Next Specific Configs ##################################
 # Kernel size for the 1D convolution in the Gated Delta Net
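
For context on the scan_layers default flipped above: with scanning enabled the decoder layers live in one stacked parameter that is traversed with jax.lax.scan, while the unscanned path loops over per-layer modules in Python. The toy contrast below is a conceptual sketch of that difference only, not MaxText's implementation.

import jax
import jax.numpy as jnp

num_layers, dim = 4, 8
stacked_w = jnp.stack([jnp.eye(dim) * (i + 1) for i in range(num_layers)])  # [L, D, D]
x = jnp.ones((dim,))

def layer(carry, w):
  # One toy "decoder layer": a matrix multiply.
  return w @ carry, None

# scan_layers=True style: a single lax.scan over the stacked weights.
y_scanned, _ = jax.lax.scan(layer, x, stacked_w)

# scan_layers=False style: an explicit Python loop over individual layers.
y_loop = x
for w in stacked_w:
  y_loop = w @ y_loop

print(jnp.allclose(y_scanned, y_loop))  # True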

0 commit comments