from flax import nnx
from flax.nnx import wrappers as nnx_wrappers

-from MaxText.configs.types import PositionalEmbedding
from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT
from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
from MaxText.sharding import create_sharding
from MaxText.layers import linears
from MaxText.layers import initializers
from MaxText.layers import quantizations
-from MaxText import multimodal_utils
from MaxText import sharding
from MaxText.layers.attentions import Attention
from MaxText.layers.normalizations import RMSNorm
-from MaxText.layers.embeddings import Embed, attend_on_embedding
+from MaxText.layers.embeddings import Embed, attend_on_embedding, PositionalEmbedding
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.layers import (
    deepseek,
from maxtext.inference import page_manager
from maxtext.utils import max_logging
from maxtext.utils import maxtext_utils
+from maxtext.multimodal import utils as mm_utils

# ------------------------------------------------------------------------------
# The network: Decoder Definitions
@@ -195,10 +194,10 @@ def __call__(
    layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names)

    if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
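+      # NNX sow() takes a Variable type (nnx.Intermediate) in place of Linen's "intermediates" collection string.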
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
      self.sow(
-          "intermediates",
+          nnx.Intermediate,
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )
@@ -284,19 +283,28 @@ def __init__(
      attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
      scan_length = config.num_decoder_layers // attention_pattern_length
      num_remaining_layers = config.num_decoder_layers % attention_pattern_length
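+      # Each scanned block spans one full Gemma3 attention pattern; leftover layers go to layers_remainder below.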
+      layer_kwargs = {"num_of_layers": attention_pattern_length}
+
      rem_layer_kwargs = {"num_of_layers": num_remaining_layers}

      RemattedGemma3Block = gemma3.Gemma3ScannableBlock

      if scan_length > 0:
-        self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs)
+        self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs)
        self.layers_remainder = RemattedGemma3Block(
            config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
        )  # pytype: disable=wrong-keyword-args
      else:
        layer_cls = decoder_block_classes[0]
-        num_layers = config.num_decoder_layers
-        self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs)
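+        # One scan step now covers a full cycle of (possibly heterogeneous) layers rather than a single layer.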
+        num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval)
+        layer_kwargs = {}
+        if config.decoder_block == DecoderBlockType.LLAMA4:
+          layer_kwargs = {
+              "nope_layer_interval": self.config.nope_layer_interval,
+              "interleave_moe_layer_step": self.config.interleave_moe_layer_step,
+          }
+
+        self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs)
    else:
      self.layers = nnx.List([])
      if self.is_deepseek:
@@ -309,6 +317,32 @@ def __init__(
        for i in range(config.num_decoder_layers):
          self._create_and_register_layer(layer_cls, rngs, "layers", i)

+      self.layers = nnx.List([])
+
+      if self.is_deepseek:
+        dense_cls, moe_cls = decoder_block_classes
+        for i in range(config.first_num_dense_layers):
+          self._create_and_register_layer(dense_cls, rngs, "dense_layer", i)
+        for i in range(config.num_decoder_layers - config.first_num_dense_layers):
+          self._create_and_register_layer(moe_cls, rngs, "moe_layer", i)
+      else:
+        layer_cls = decoder_block_classes[0]
+
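+        # kwargs vary with the layer index for heterogeneous stacks (attention type, NoPE/MoE placement).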
+        for i in range(config.num_decoder_layers):
+          layer_kwargs = {}
+          if config.decoder_block == DecoderBlockType.GEMMA3:
+            layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=i)}
+          elif config.decoder_block == DecoderBlockType.LLAMA4:
+            layer_kwargs = {
+                "is_nope_layer": llama4.determine_is_nope_layer(i, self.config.nope_layer_interval),
+                "is_moe_layer": llama4.determine_is_moe_layer(i, self.config.interleave_moe_layer_step),
+            }
+          elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
+            layer_kwargs = {"layer_idx": i}
+          elif config.decoder_block == DecoderBlockType.GPT_OSS:
+            layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=i)}
+          self._create_and_register_layer(layer_cls, rngs, "layers", i, **layer_kwargs)
+
  def _create_and_register_layer(self, layer_cls, rngs, base_name, i):
    attr_name = f"{base_name}_{i}"
    layer = self._create_single_layer(layer_cls, rngs)
@@ -346,12 +380,16 @@ def create_layer_fn(rng):
    except:  # pylint: disable=bare-except
      pass

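+    # nnx.StateAxes: stack Params along param_scan_axis; all other per-layer state stacks along axis 0.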
+    out_axes = nnx.StateAxes({
+        nnx.Param: self.config.param_scan_axis,
+        ...: 0,
+    })
    layers_vmapped = nnx.vmap(
        create_layer_fn,
        in_axes=0,
-        out_axes=0,
+        out_axes=out_axes,
        axis_name="layers",
        transform_metadata={nnx.PARTITION_NAME: "layers"},
    )(forked_rngs)

    return layers_vmapped
@@ -364,9 +402,17 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
        layers, nnx.Param, ...
    )  # state: the mutable state we carry (KV cache, RNGs, etc.)

-    layer_cls = layers.__class__  # Access the underlying class
+    scan_axis = self.config.param_scan_axis
+    if scan_axis != 0:
+      # Move scan_axis to 0 so scan can iterate over it
+      params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
+
+    layer_cls = layers.__class__  # Access the underlying class
    sig = inspect.signature(layer_cls.__call__)
    # Filter kwargs to only include keys that exist in the layer's signature
    valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}

@@ -391,6 +437,11 @@ def layer_fn(carry, scanned_vars):

    final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))

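+    # lax.scan stacked the per-layer outputs on axis 0; move Params back to param_scan_axis to match the stored layout.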
+    if scan_axis != 0:
+      scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
+      scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
+      scanned_state = nnx.State.merge(scanned_params, scanned_other)
+
    return final_carry, nnx.merge(graphdef, scanned_state)

  def get_decoder_layers(self):
@@ -584,7 +635,7 @@ def _apply_embedding(
        "llama4-17b-128e",
        "qwen3-omni-30b-a3b",
    ]:
-      y = multimodal_utils.merge_mm_embeddings(
+      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=image_embeddings,
          mask=bidirectional_mask,
@@ -596,7 +647,7 @@ def _apply_embedding(

    if audio_embeddings is not None and cfg.use_audio:
      if cfg.model_name in ["qwen3-omni-30b-a3b"]:
-        y = multimodal_utils.merge_mm_embeddings(
+        y = mm_utils.merge_mm_embeddings(
            text_embeddings=y,
            multimodal_embeddings=audio_embeddings,
            mask=audio_masks,
@@ -609,7 +660,7 @@ def _apply_embedding(
    y = y.astype(cfg.dtype)

    if cfg.use_untrainable_positional_embedding:
-      y = self.positional_embedding(y, decoder_positions)
+      y += self.positional_embedding(y, decoder_positions)

    if cfg.trainable_position_size > 0 and self.position_embedder:
      y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode)
@@ -625,7 +676,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
    else:
      norm_out_sharding = None

-    y = self.decoder_norm(y, norm_out_sharding)
+    y = self.decoder_norm(y, out_sharding=norm_out_sharding)
    y = self.dropout(y, deterministic=deterministic)  # NNX call

    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
@@ -693,19 +744,18 @@ def __call__(
        audio_masks,
    )
    layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)

-    layer_kwargs = {
-        "previous_chunk": previous_chunk,
-        "page_state": page_state,
-        "slot": slot,
-        "attention_metadata": attention_metadata,
-    }
-
+    layer_kwargs = {}
    if cfg.decoder_block == DecoderBlockType.GEMMA3:
      layer_kwargs["bidirectional_mask"] = bidirectional_mask

    if cfg.scan_layers:
      if self.is_deepseek:
+        layer_kwargs = {
+            "previous_chunk": previous_chunk,
+            "page_state": page_state,
+            "slot": slot,
+        }
        y, self.dense_layers = self._apply_layers_sequentially(
            self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs
        )
@@ -733,8 +783,24 @@ def __call__(
    else:
      for i, layer in enumerate(self.layers):
        kv_cache = kv_caches[i] if kv_caches is not None else None
+
+        layer_call_kwargs = {}
+        if cfg.decoder_block == DecoderBlockType.GEMMA3:
+          layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}

-        out = layer(y, *layer_args, kv_cache=kv_cache, **layer_kwargs)
+        out = layer(
+            y,
+            decoder_segment_ids,
+            decoder_positions,
+            deterministic,
+            model_mode,
+            previous_chunk=previous_chunk,
+            page_state=page_state,
+            slot=slot,
+            kv_cache=kv_cache,
+            attention_metadata=attention_metadata,
+            **layer_call_kwargs,
+        )

        if isinstance(out, tuple):
          y, kv_cache_out = out
@@ -775,17 +841,12 @@ def _apply_gemma3_scanned_blocks(
    attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
    scan_length = cfg.num_decoder_layers // attention_pattern_length

-    layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}
+    layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
+    layer_kwargs = {"bidirectional_mask": bidirectional_mask}

    # Apply the main scan over the full blocks
    if scan_length > 0:
-      broadcast_args = (
-          decoder_segment_ids,
-          decoder_positions,
-          deterministic,
-          model_mode,
-      )
-      y, _ = self.layers(y, *broadcast_args, **layer_call_kwargs)
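+      # _apply_layers_sequentially also returns the updated layers module so mutated per-layer state is written back.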
+      y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)

    # Apply any remaining layers that did not fit into a full scanned block
    num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length
@@ -800,8 +861,9 @@ def _apply_gemma3_scanned_blocks(
          previous_chunk=previous_chunk,
          page_state=page_state,
          slot=slot,
-          **layer_call_kwargs,
+          **layer_kwargs,
      )
+
    return y
