|
34 | 34 | from maxtext.layers import mhc |
35 | 35 | from maxtext.layers import normalizations |
36 | 36 | from maxtext.layers import pipeline |
| 37 | +from maxtext.layers.nnx_decoders import NNXDecoderLayer, NNXSequentialPipelineStage, NNXScannedPipelineStage |
37 | 38 | from maxtext.layers import quantizations |
38 | 39 | from maxtext.layers.attentions import attention_as_linen |
39 | 40 | from maxtext.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen |
@@ -262,7 +263,7 @@ def __call__( |
262 | 263 | page_state=page_state, |
263 | 264 | ) |
264 | 265 | if self.config.scan_layers: |
265 | | - inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None). |
| 266 | + inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None). |
266 | 267 | if self.config.scan_layers: |
267 | 268 | return inputs, None # pytype: disable=bad-return-type |
268 | 269 | else: |
@@ -307,11 +308,21 @@ def setup(self): |
307 | 308 | self.decoder_layer = self.get_decoder_layers() |
308 | 309 | self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim) |
309 | 310 | if self.config.using_pipeline_parallelism: |
310 | | - pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) |
311 | 311 | remat_policy = self.get_remat_policy() |
312 | | - self.pipeline_module = pipeline.create_pipeline( |
313 | | - config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy |
314 | | - ) |
| 312 | + if self.config.use_nnx_pipeline: |
| 313 | + nnx_blocks = self._get_nnx_decoder_block_classes() |
| 314 | + |
| 315 | + def stage_factory(rngs): |
| 316 | + return self._build_nnx_pipeline_stage(nnx_blocks, rngs) |
| 317 | + |
| 318 | + self.pipeline_module = pipeline.create_pipeline( |
| 319 | + config=self.config, layers=stage_factory, mesh=self.mesh, remat_policy=remat_policy |
| 320 | + ) |
| 321 | + else: |
| 322 | + pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) |
| 323 | + self.pipeline_module = pipeline.create_pipeline( |
| 324 | + config=self.config, layers=pipeline_stage_module, mesh=self.mesh, remat_policy=remat_policy |
| 325 | + ) |
315 | 326 |
|
316 | 327 | def minimal_policy(self, with_context=False, with_quantization=False): |
317 | 328 | """Helper for creating minimal checkpoint policies.""" |
@@ -494,6 +505,44 @@ def get_decoder_layers(self): |
494 | 505 | # Default case to handle any unknown decoder block types. |
495 | 506 | raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") |
496 | 507 |
|
| 508 | + def _get_nnx_decoder_block_classes(self): |
| 509 | + """Returns NNX decoder block classes for pipeline stage creation.""" |
| 510 | + cfg = self.config |
| 511 | + |
| 512 | + def get_scannable(normal_cls, scannable_cls): |
| 513 | + return [scannable_cls] if cfg.scan_layers else [normal_cls] |
| 514 | + |
| 515 | + def get_deepseek(): |
| 516 | + if cfg.use_batch_split_schedule: |
| 517 | + return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] |
| 518 | + return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] |
| 519 | + |
| 520 | + layer_map = { |
| 521 | + DecoderBlockType.DEFAULT: [NNXDecoderLayer], |
| 522 | + DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer], |
| 523 | + DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer], |
| 524 | + DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer], |
| 525 | + DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer], |
| 526 | + DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer], |
| 527 | + DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer], |
| 528 | + DecoderBlockType.GEMMA4: get_scannable(gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock), |
| 529 | + DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer], |
| 530 | + DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock), |
| 531 | + DecoderBlockType.QWEN2: [qwen2.Qwen2DecoderLayer], |
| 532 | + DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer], |
| 533 | + DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer], |
| 534 | + DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock), |
| 535 | + DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer], |
| 536 | + DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer], |
| 537 | + DecoderBlockType.DEEPSEEK: get_deepseek(), |
| 538 | + DecoderBlockType.LLAMA4: get_scannable(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock), |
| 539 | + DecoderBlockType.OLMO3: get_scannable(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock), |
| 540 | + } |
| 541 | + |
| 542 | + if cfg.decoder_block not in layer_map: |
| 543 | + raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}") |
| 544 | + return layer_map[cfg.decoder_block] |
| 545 | + |
497 | 546 | def set_remat_policy(self, block_layers, policy): |
498 | 547 | """Set remat policy""" |
499 | 548 | RemattedBlockLayers = [] |
@@ -522,6 +571,58 @@ def map_fn(path, value): |
522 | 571 | RemattedBlockLayers.append(layer) |
523 | 572 | return RemattedBlockLayers |
524 | 573 |
|
| 574 | + def _build_nnx_pipeline_stage(self, decoder_blocks, rngs): |
| 575 | + """Creates a single NNX pipeline stage module.""" |
| 576 | + cfg = self.config |
| 577 | + base_stage_cls = decoder_blocks[1] if cfg.decoder_block == DecoderBlockType.DEEPSEEK else decoder_blocks[0] |
| 578 | + |
| 579 | + if cfg.num_layers_per_pipeline_stage == 1: |
| 580 | + return base_stage_cls(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs) |
| 581 | + elif cfg.scan_layers_per_stage: |
| 582 | + return NNXScannedPipelineStage( |
| 583 | + base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs |
| 584 | + ) |
| 585 | + return NNXSequentialPipelineStage( |
| 586 | + base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs |
| 587 | + ) |
| 588 | + |
| 589 | + def get_pipeline_stage_module(self, decoder_blocks): |
| 590 | + """get pipeline stage module""" |
| 591 | + |
| 592 | + def get_layer_to_pipeline(blocks, cfg): |
| 593 | + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
| 594 | + return blocks[1] # return the sparse block |
| 595 | + else: |
| 596 | + return blocks[0] |
| 597 | + |
| 598 | + cfg = self.config |
| 599 | + base_stage = get_layer_to_pipeline(decoder_blocks, cfg) |
| 600 | + if cfg.set_remat_policy_on_layers_per_stage: |
| 601 | + policy = self.get_remat_policy() |
| 602 | + base_stage = self.set_remat_policy([base_stage], policy)[0] |
| 603 | + if cfg.num_layers_per_pipeline_stage == 1: |
| 604 | + stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode) |
| 605 | + elif cfg.scan_layers_per_stage: |
| 606 | + stage_module = self.scan_decoder_layers( |
| 607 | + cfg, |
| 608 | + base_stage, |
| 609 | + cfg.num_layers_per_pipeline_stage, |
| 610 | + "layers_per_stage", |
| 611 | + self.mesh, |
| 612 | + in_axes_tuple=(nn.broadcast,) * 4, |
| 613 | + model_mode=self.model_mode, |
| 614 | + ) |
| 615 | + else: |
| 616 | + stage_module = SequentialBlockDecoderLayers( |
| 617 | + decoder_layer=base_stage, |
| 618 | + num_decoder_layers=cfg.num_layers_per_pipeline_stage, |
| 619 | + config=cfg, |
| 620 | + mesh=self.mesh, |
| 621 | + quant=self.quant, |
| 622 | + model_mode=self.model_mode, |
| 623 | + ) |
| 624 | + return stage_module |
| 625 | + |
525 | 626 | def get_norm_layer(self, num_features: int): |
526 | 627 | """get normalization layer (return type inherits from nn.Module)""" |
527 | 628 | if self.config.decoder_block in ( |
@@ -581,42 +682,6 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me |
581 | 682 | config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args |
582 | 683 | ) |
583 | 684 |
|
584 | | - def get_pipeline_stage_module(self, decoder_blocks): |
585 | | - """get pipeline stage module""" |
586 | | - |
587 | | - def get_layer_to_pipeline(blocks, cfg): |
588 | | - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: |
589 | | - return blocks[1] # return the sparse block |
590 | | - else: |
591 | | - return blocks[0] |
592 | | - |
593 | | - cfg = self.config |
594 | | - base_stage = get_layer_to_pipeline(decoder_blocks, cfg) |
595 | | - if cfg.set_remat_policy_on_layers_per_stage: |
596 | | - policy = self.get_remat_policy() |
597 | | - base_stage = self.set_remat_policy([base_stage], policy)[0] |
598 | | - if cfg.num_layers_per_pipeline_stage == 1: |
599 | | - stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode) |
600 | | - elif cfg.scan_layers_per_stage: |
601 | | - stage_module = self.scan_decoder_layers( |
602 | | - cfg, |
603 | | - base_stage, |
604 | | - cfg.num_layers_per_pipeline_stage, |
605 | | - "layers_per_stage", |
606 | | - self.mesh, |
607 | | - in_axes_tuple=(nn.broadcast,) * 4, |
608 | | - ) |
609 | | - else: |
610 | | - stage_module = SequentialBlockDecoderLayers( |
611 | | - decoder_layer=base_stage, |
612 | | - num_decoder_layers=cfg.num_layers_per_pipeline_stage, |
613 | | - config=cfg, |
614 | | - mesh=self.mesh, |
615 | | - quant=self.quant, |
616 | | - model_mode=self.model_mode, |
617 | | - ) |
618 | | - return stage_module |
619 | | - |
620 | 685 | @nn.compact |
621 | 686 | def _apply_embedding( |
622 | 687 | self, |
|
0 commit comments