Skip to content

Commit 454e55f

Browse files
Merge pull request #2885 from CIeNET-International:test/pipeline-scan-nnx
PiperOrigin-RevId: 933577624
2 parents a264ca0 + 83dfd53 commit 454e55f

6 files changed

Lines changed: 1904 additions & 877 deletions

File tree

src/maxtext/layers/decoders.py

Lines changed: 112 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
from maxtext.layers import mhc
3434
from maxtext.layers import normalizations
3535
from maxtext.layers import pipeline
36+
from maxtext.layers.nnx_decoders import NNXDecoderLayer, NNXSequentialPipelineStage, NNXScannedPipelineStage
3637
from maxtext.layers import quantizations
3738
from maxtext.layers.attentions import attention_as_linen
3839
from maxtext.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen
@@ -260,7 +261,7 @@ def __call__(
260261
slot=slot,
261262
)
262263
if self.config.scan_layers:
263-
inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None).
264+
inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None).
264265
if self.config.scan_layers:
265266
return inputs, None # pytype: disable=bad-return-type
266267
else:
@@ -305,10 +306,19 @@ def setup(self):
305306
self.decoder_layer = self.get_decoder_layers()
306307
self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim)
307308
if self.config.using_pipeline_parallelism:
308-
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
309309
remat_policy = self.get_remat_policy()
310+
nnx_blocks = self._get_nnx_decoder_block_classes()
311+
312+
# Per-stage builder handed to create_pipeline: the pipeline transform invokes it
313+
# once per pipeline stage, passing that stage's rngs, to construct the stage's
314+
# decoder block(s). nnx_blocks (the selected decoder block classes) is captured from
315+
# this setup scope; remat_policy is applied separately by create_pipeline.
316+
def build_pipeline_stage_layers(rngs):
317+
"""Builds one pipeline stage module from the selected NNX decoder block classes."""
318+
return self._build_nnx_pipeline_stage(nnx_blocks, rngs)
319+
310320
self.pipeline_module = pipeline.create_pipeline(
311-
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
321+
config=self.config, layers=build_pipeline_stage_layers, mesh=self.mesh, remat_policy=remat_policy
312322
)
313323

314324
def minimal_policy(self, with_context=False, with_quantization=False):
@@ -498,6 +508,42 @@ def get_decoder_layers(self):
498508
# Default case to handle any unknown decoder block types.
499509
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
500510

511+
def _get_nnx_decoder_block_classes(self):
  """Looks up the NNX decoder block classes used to build pipeline stages.

  Returns:
    The list of block classes registered for `self.config.decoder_block`. Block
    types that have a scannable variant yield a one-element list holding the
    scannable class when `config.scan_layers` is truthy and the regular layer
    class otherwise. DEEPSEEK yields two classes: the dense layer, then the MoE layer.

  Raises:
    ValueError: if `self.config.decoder_block` has no entry in the table.
  """
  # NOTE: the local must stay named `cfg`; the error f-string below uses the `=`
  # specifier, which embeds the expression text in the message.
  cfg = self.config

  def pick_variant(plain_cls, scanned_cls):
    # One-element list: the scanned variant only when layer scanning is enabled.
    return [scanned_cls if cfg.scan_layers else plain_cls]

  blocks_by_type = {
      DecoderBlockType.DEFAULT: [NNXDecoderLayer],
      DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer],
      DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer],
      DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer],
      DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer],
      DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer],
      DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer],
      DecoderBlockType.GEMMA4: pick_variant(gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock),
      DecoderBlockType.GEMMA4_SMALL: [gemma4_small.Gemma4SmallDecoderLayer],
      DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer],
      DecoderBlockType.GPT_OSS: pick_variant(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock),
      DecoderBlockType.QWEN2: [qwen2.Qwen2DecoderLayer],
      DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer],
      DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer],
      DecoderBlockType.QWEN3_CUSTOM_MOE: [qwen3_custom.Qwen3CustomMoeDecoderLayer],
      DecoderBlockType.QWEN3_5: pick_variant(qwen3_5.Qwen3_5DecoderLayer, qwen3_5.Qwen3_5ScannableBlock),
      DecoderBlockType.QWEN3_NEXT: pick_variant(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock),
      DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer],
      DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer],
      DecoderBlockType.DEEPSEEK: [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer],
      DecoderBlockType.LLAMA4: pick_variant(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock),
      DecoderBlockType.OLMO3: pick_variant(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock),
  }

  # Every table value is a list, so None unambiguously means "no entry".
  selected = blocks_by_type.get(cfg.decoder_block)
  if selected is None:
    raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}")
  return selected
546+
501547
def set_remat_policy(self, block_layers, policy):
502548
"""Set remat policy"""
503549
RemattedBlockLayers = []
@@ -526,6 +572,58 @@ def map_fn(path, value):
526572
RemattedBlockLayers.append(layer)
527573
return RemattedBlockLayers
528574

575+
def _build_nnx_pipeline_stage(self, decoder_blocks, rngs):
  """Instantiates the NNX module for one pipeline stage.

  Args:
    decoder_blocks: indexable collection of decoder block classes. Index 1 is used
      when the configured decoder block is DEEPSEEK, index 0 otherwise.
    rngs: forwarded as the `rngs` keyword to whichever module is constructed.

  Returns:
    An instance of the chosen block class when `num_layers_per_pipeline_stage` is 1.
    Otherwise an `NNXScannedPipelineStage` (if `scan_layers_per_stage` is truthy) or
    an `NNXSequentialPipelineStage`, constructed with the block class and the
    per-stage layer count.
  """
  cfg = self.config

  if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
    layer_cls = decoder_blocks[1]
  else:
    layer_cls = decoder_blocks[0]

  layers_per_stage = cfg.num_layers_per_pipeline_stage
  if layers_per_stage == 1:
    # A one-layer stage is just the block itself; no wrapper module is needed.
    return layer_cls(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)

  # Both multi-layer wrappers take the same constructor arguments; only the class differs.
  if cfg.scan_layers_per_stage:
    stage_cls = NNXScannedPipelineStage
  else:
    stage_cls = NNXSequentialPipelineStage
  return stage_cls(layer_cls, layers_per_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs)
589+
590+
def get_pipeline_stage_module(self, decoder_blocks):
  """Builds the module that a single pipeline stage executes.

  Args:
    decoder_blocks: indexable collection of decoder block classes. Index 1 (the
      sparse block) is used for DEEPSEEK, index 0 for every other block type.

  Returns:
    The stage module: one block instance when a stage holds exactly one layer, the
    result of `scan_decoder_layers` when `scan_layers_per_stage` is truthy, and a
    `SequentialBlockDecoderLayers` instance otherwise.
  """
  cfg = self.config

  if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
    stage_layer = decoder_blocks[1]  # the sparse block
  else:
    stage_layer = decoder_blocks[0]

  # Optionally wrap the per-stage layer with the configured remat policy.
  if cfg.set_remat_policy_on_layers_per_stage:
    remat_policy = self.get_remat_policy()
    stage_layer = self.set_remat_policy([stage_layer], remat_policy)[0]

  if cfg.num_layers_per_pipeline_stage == 1:
    return stage_layer(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)

  if cfg.scan_layers_per_stage:
    return self.scan_decoder_layers(
        cfg,
        stage_layer,
        cfg.num_layers_per_pipeline_stage,
        "layers_per_stage",
        self.mesh,
        in_axes_tuple=(nn.broadcast,) * 4,
        model_mode=self.model_mode,
    )

  return SequentialBlockDecoderLayers(
      decoder_layer=stage_layer,
      num_decoder_layers=cfg.num_layers_per_pipeline_stage,
      config=cfg,
      mesh=self.mesh,
      quant=self.quant,
      model_mode=self.model_mode,
  )
626+
529627
def get_norm_layer(self, num_features: int):
530628
"""get normalization layer (return type inherits from nn.Module)"""
531629
if self.config.decoder_block in (
@@ -586,42 +684,6 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
586684
config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args
587685
)
588686

589-
def get_pipeline_stage_module(self, decoder_blocks):
  """Creates the stage module run by each pipeline stage.

  Args:
    decoder_blocks: indexable collection of decoder block classes; DEEPSEEK uses
      index 1 (the sparse block), all other block types use index 0.

  Returns:
    A single block instance when `num_layers_per_pipeline_stage` is 1, the result of
    `scan_decoder_layers` when `scan_layers_per_stage` is truthy, or a
    `SequentialBlockDecoderLayers` instance otherwise.
  """
  cfg = self.config

  uses_sparse_block = cfg.decoder_block == DecoderBlockType.DEEPSEEK
  if uses_sparse_block:
    stage_layer = decoder_blocks[1]  # the sparse block
  else:
    stage_layer = decoder_blocks[0]

  # Apply the remat policy to the per-stage layer only when the config requests it.
  if cfg.set_remat_policy_on_layers_per_stage:
    remat_policy = self.get_remat_policy()
    stage_layer = self.set_remat_policy([stage_layer], remat_policy)[0]

  if cfg.num_layers_per_pipeline_stage == 1:
    return stage_layer(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)

  if cfg.scan_layers_per_stage:
    return self.scan_decoder_layers(
        cfg,
        stage_layer,
        cfg.num_layers_per_pipeline_stage,
        "layers_per_stage",
        self.mesh,
        in_axes_tuple=(nn.broadcast,) * 4,
    )

  return SequentialBlockDecoderLayers(
      decoder_layer=stage_layer,
      num_decoder_layers=cfg.num_layers_per_pipeline_stage,
      config=cfg,
      mesh=self.mesh,
      quant=self.quant,
      model_mode=self.model_mode,
  )
624-
625687
@nn.compact
626688
def _apply_embedding(
627689
self,
@@ -1301,11 +1363,17 @@ def _apply_gemma4_scanned_blocks(
13011363
decoder_positions,
13021364
deterministic,
13031365
model_mode,
1304-
slot,
1305-
previous_chunk,
1306-
bidirectional_mask,
13071366
)
13081367

1368+
# Pass slot/previous_chunk/bidirectional_mask by keyword only (via layer_call_kwargs),
1369+
# never positionally in broadcast_args: Gemma4DecoderLayer and Gemma4ScannableBlock
1370+
# declare slot and previous_chunk in swapped order, so positional passing misroutes them.
1371+
layer_call_kwargs = {
1372+
"slot": slot,
1373+
"previous_chunk": previous_chunk,
1374+
"bidirectional_mask": bidirectional_mask,
1375+
}
1376+
13091377
if num_full_blocks > 0:
13101378
ScannableBlockToLinen = gemma4.Gemma4ScannableBlockToLinen
13111379
policy = self.get_remat_policy()
@@ -1335,7 +1403,7 @@ def _apply_gemma4_scanned_blocks(
13351403
num_of_layers=block_pattern_len,
13361404
name="scanned_blocks",
13371405
)(
1338-
y, *broadcast_args
1406+
y, *broadcast_args, **layer_call_kwargs
13391407
)
13401408

13411409
# Process any remaining layers that don't fit into a full scanned block
@@ -1349,7 +1417,7 @@ def _apply_gemma4_scanned_blocks(
13491417
attention_type=attention_type,
13501418
layer_idx=layer_id,
13511419
)
1352-
y = layer(y, *broadcast_args)
1420+
y = layer(y, *broadcast_args, **layer_call_kwargs)
13531421
if cfg.scan_layers:
13541422
y = y[0]
13551423

0 commit comments

Comments
 (0)