3333from maxtext .layers import mhc
3434from maxtext .layers import normalizations
3535from maxtext .layers import pipeline
36+ from maxtext .layers .nnx_decoders import NNXDecoderLayer , NNXSequentialPipelineStage , NNXScannedPipelineStage
3637from maxtext .layers import quantizations
3738from maxtext .layers .attentions import attention_as_linen
3839from maxtext .layers .embeddings import attend_on_embedding , embed_as_linen , positional_embedding_as_linen
@@ -260,7 +261,7 @@ def __call__(
260261 slot = slot ,
261262 )
262263 if self .config .scan_layers :
263- inputs = inputs [0 ] # When scan_layers is True the decoder layers return (outputs, None).
264+ inputs = inputs [0 ] # When scan_layers is True the decoder layers return (outputs, None).
264265 if self .config .scan_layers :
265266 return inputs , None # pytype: disable=bad-return-type
266267 else :
@@ -305,10 +306,19 @@ def setup(self):
305306 self .decoder_layer = self .get_decoder_layers ()
306307 self .norm_layer = self .get_norm_layer (num_features = self .config .emb_dim )
307308 if self .config .using_pipeline_parallelism :
308- pipeline_stage_module = self .get_pipeline_stage_module (self .decoder_layer )
309309 remat_policy = self .get_remat_policy ()
310+ nnx_blocks = self ._get_nnx_decoder_block_classes ()
311+
312+ # Per-stage builder handed to create_pipeline: the pipeline transform invokes it
313+ # once per pipeline stage, passing that stage's rngs, to construct the stage's
314+ # decoder block(s). nnx_blocks (the selected decoder block classes) is captured from
315+ # this setup scope; remat_policy is applied separately by create_pipeline.
316+ def build_pipeline_stage_layers (rngs ):
317+ """Builds one pipeline stage module from the selected NNX decoder block classes."""
318+ return self ._build_nnx_pipeline_stage (nnx_blocks , rngs )
319+
310320 self .pipeline_module = pipeline .create_pipeline (
311- config = self .config , mesh = self .mesh , layers = pipeline_stage_module , remat_policy = remat_policy
321+ config = self .config , layers = build_pipeline_stage_layers , mesh = self .mesh , remat_policy = remat_policy
312322 )
313323
314324 def minimal_policy (self , with_context = False , with_quantization = False ):
@@ -498,6 +508,42 @@ def get_decoder_layers(self):
498508 # Default case to handle any unknown decoder block types.
499509 raise ValueError (f"Incorrect decoder_block name { self .config .decoder_block .value = } " )
500510
511+ def _get_nnx_decoder_block_classes (self ):
512+ """Returns NNX decoder block classes for pipeline stage creation."""
513+ cfg = self .config
514+
515+ def get_scannable (normal_cls , scannable_cls ):
516+ return [scannable_cls ] if cfg .scan_layers else [normal_cls ]
517+
518+ layer_map = {
519+ DecoderBlockType .DEFAULT : [NNXDecoderLayer ],
520+ DecoderBlockType .LLAMA2 : [llama2 .LlamaDecoderLayer ],
521+ DecoderBlockType .MISTRAL : [mistral .MistralDecoderLayer ],
522+ DecoderBlockType .MIXTRAL : [mixtral .MixtralDecoderLayer ],
523+ DecoderBlockType .GEMMA : [gemma .GemmaDecoderLayer ],
524+ DecoderBlockType .GEMMA2 : [gemma2 .Gemma2DecoderLayer ],
525+ DecoderBlockType .GEMMA3 : [gemma3 .Gemma3DecoderLayer ],
526+ DecoderBlockType .GEMMA4 : get_scannable (gemma4 .Gemma4DecoderLayer , gemma4 .Gemma4ScannableBlock ),
527+ DecoderBlockType .GEMMA4_SMALL : [gemma4_small .Gemma4SmallDecoderLayer ],
528+ DecoderBlockType .GPT3 : [gpt3 .Gpt3DecoderLayer ],
529+ DecoderBlockType .GPT_OSS : get_scannable (gpt_oss .GptOssDecoderLayer , gpt_oss .GptOssScannableBlock ),
530+ DecoderBlockType .QWEN2 : [qwen2 .Qwen2DecoderLayer ],
531+ DecoderBlockType .QWEN3 : [qwen3 .Qwen3DecoderLayer ],
532+ DecoderBlockType .QWEN3_MOE : [qwen3 .Qwen3MoeDecoderLayer ],
533+ DecoderBlockType .QWEN3_CUSTOM_MOE : [qwen3_custom .Qwen3CustomMoeDecoderLayer ],
534+ DecoderBlockType .QWEN3_5 : get_scannable (qwen3_5 .Qwen3_5DecoderLayer , qwen3_5 .Qwen3_5ScannableBlock ),
535+ DecoderBlockType .QWEN3_NEXT : get_scannable (qwen3 .Qwen3NextDecoderLayer , qwen3 .Qwen3NextScannableBlock ),
536+ DecoderBlockType .SIMPLE : [simple_layer .SimpleDecoderLayer ],
537+ DecoderBlockType .SIMPLE_MLP : [simple_layer .SimpleMlpDecoderLayer ],
538+ DecoderBlockType .DEEPSEEK : [deepseek .DeepSeekDenseLayer , deepseek .DeepSeekMoELayer ],
539+ DecoderBlockType .LLAMA4 : get_scannable (llama4 .Llama4DecoderLayer , llama4 .Llama4ScannableBlock ),
540+ DecoderBlockType .OLMO3 : get_scannable (olmo3 .Olmo3DecoderLayer , olmo3 .Olmo3ScannableBlock ),
541+ }
542+
543+ if cfg .decoder_block not in layer_map :
544+ raise ValueError (f"Incorrect decoder_block name { cfg .decoder_block .value = } " )
545+ return layer_map [cfg .decoder_block ]
546+
501547 def set_remat_policy (self , block_layers , policy ):
502548 """Set remat policy"""
503549 RemattedBlockLayers = []
@@ -526,6 +572,58 @@ def map_fn(path, value):
526572 RemattedBlockLayers .append (layer )
527573 return RemattedBlockLayers
528574
575+ def _build_nnx_pipeline_stage (self , decoder_blocks , rngs ):
576+ """Creates a single NNX pipeline stage module."""
577+ cfg = self .config
578+ base_stage_cls = decoder_blocks [1 ] if cfg .decoder_block == DecoderBlockType .DEEPSEEK else decoder_blocks [0 ]
579+
580+ if cfg .num_layers_per_pipeline_stage == 1 :
581+ return base_stage_cls (config = cfg , mesh = self .mesh , quant = self .quant , model_mode = self .model_mode , rngs = rngs )
582+ elif cfg .scan_layers_per_stage :
583+ return NNXScannedPipelineStage (
584+ base_stage_cls , cfg .num_layers_per_pipeline_stage , cfg , self .mesh , self .quant , self .model_mode , rngs = rngs
585+ )
586+ return NNXSequentialPipelineStage (
587+ base_stage_cls , cfg .num_layers_per_pipeline_stage , cfg , self .mesh , self .quant , self .model_mode , rngs = rngs
588+ )
589+
590+ def get_pipeline_stage_module (self , decoder_blocks ):
591+ """get pipeline stage module"""
592+
593+ def get_layer_to_pipeline (blocks , cfg ):
594+ if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
595+ return blocks [1 ] # return the sparse block
596+ else :
597+ return blocks [0 ]
598+
599+ cfg = self .config
600+ base_stage = get_layer_to_pipeline (decoder_blocks , cfg )
601+ if cfg .set_remat_policy_on_layers_per_stage :
602+ policy = self .get_remat_policy ()
603+ base_stage = self .set_remat_policy ([base_stage ], policy )[0 ]
604+ if cfg .num_layers_per_pipeline_stage == 1 :
605+ stage_module = base_stage (config = cfg , mesh = self .mesh , quant = self .quant , model_mode = self .model_mode )
606+ elif cfg .scan_layers_per_stage :
607+ stage_module = self .scan_decoder_layers (
608+ cfg ,
609+ base_stage ,
610+ cfg .num_layers_per_pipeline_stage ,
611+ "layers_per_stage" ,
612+ self .mesh ,
613+ in_axes_tuple = (nn .broadcast ,) * 4 ,
614+ model_mode = self .model_mode ,
615+ )
616+ else :
617+ stage_module = SequentialBlockDecoderLayers (
618+ decoder_layer = base_stage ,
619+ num_decoder_layers = cfg .num_layers_per_pipeline_stage ,
620+ config = cfg ,
621+ mesh = self .mesh ,
622+ quant = self .quant ,
623+ model_mode = self .model_mode ,
624+ )
625+ return stage_module
626+
529627 def get_norm_layer (self , num_features : int ):
530628 """get normalization layer (return type inherits from nn.Module)"""
531629 if self .config .decoder_block in (
@@ -586,42 +684,6 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
586684 config = cfg , mesh = mesh , name = metadata_axis_name , quant = self .quant , ** kwargs # pytype: disable=wrong-keyword-args
587685 )
588686
589- def get_pipeline_stage_module (self , decoder_blocks ):
590- """get pipeline stage module"""
591-
592- def get_layer_to_pipeline (blocks , cfg ):
593- if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
594- return blocks [1 ] # return the sparse block
595- else :
596- return blocks [0 ]
597-
598- cfg = self .config
599- base_stage = get_layer_to_pipeline (decoder_blocks , cfg )
600- if cfg .set_remat_policy_on_layers_per_stage :
601- policy = self .get_remat_policy ()
602- base_stage = self .set_remat_policy ([base_stage ], policy )[0 ]
603- if cfg .num_layers_per_pipeline_stage == 1 :
604- stage_module = base_stage (config = cfg , mesh = self .mesh , quant = self .quant , model_mode = self .model_mode )
605- elif cfg .scan_layers_per_stage :
606- stage_module = self .scan_decoder_layers (
607- cfg ,
608- base_stage ,
609- cfg .num_layers_per_pipeline_stage ,
610- "layers_per_stage" ,
611- self .mesh ,
612- in_axes_tuple = (nn .broadcast ,) * 4 ,
613- )
614- else :
615- stage_module = SequentialBlockDecoderLayers (
616- decoder_layer = base_stage ,
617- num_decoder_layers = cfg .num_layers_per_pipeline_stage ,
618- config = cfg ,
619- mesh = self .mesh ,
620- quant = self .quant ,
621- model_mode = self .model_mode ,
622- )
623- return stage_module
624-
625687 @nn .compact
626688 def _apply_embedding (
627689 self ,
@@ -1301,11 +1363,17 @@ def _apply_gemma4_scanned_blocks(
13011363 decoder_positions ,
13021364 deterministic ,
13031365 model_mode ,
1304- slot ,
1305- previous_chunk ,
1306- bidirectional_mask ,
13071366 )
13081367
1368+ # Pass slot/previous_chunk/bidirectional_mask by keyword only (via layer_call_kwargs),
1369+ # never positionally in broadcast_args: Gemma4DecoderLayer and Gemma4ScannableBlock
1370+ # declare slot and previous_chunk in swapped order, so positional passing misroutes them.
1371+ layer_call_kwargs = {
1372+ "slot" : slot ,
1373+ "previous_chunk" : previous_chunk ,
1374+ "bidirectional_mask" : bidirectional_mask ,
1375+ }
1376+
13091377 if num_full_blocks > 0 :
13101378 ScannableBlockToLinen = gemma4 .Gemma4ScannableBlockToLinen
13111379 policy = self .get_remat_policy ()
@@ -1335,7 +1403,7 @@ def _apply_gemma4_scanned_blocks(
13351403 num_of_layers = block_pattern_len ,
13361404 name = "scanned_blocks" ,
13371405 )(
1338- y , * broadcast_args
1406+ y , * broadcast_args , ** layer_call_kwargs
13391407 )
13401408
13411409 # Process any remaining layers that don't fit into a full scanned block
@@ -1349,7 +1417,7 @@ def _apply_gemma4_scanned_blocks(
13491417 attention_type = attention_type ,
13501418 layer_idx = layer_id ,
13511419 )
1352- y = layer (y , * broadcast_args )
1420+ y = layer (y , * broadcast_args , ** layer_call_kwargs )
13531421 if cfg .scan_layers :
13541422 y = y [0 ]
13551423
0 commit comments