Skip to content

Commit 60f0e87

Browse files
mesakhcienetecnal-cienet
authored andcommitted
feat: implement nnx-based pipeline
1 parent ee920f5 commit 60f0e87

5 files changed

Lines changed: 1755 additions & 469 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,8 @@ subslice_shape: ""
11611161
enable_nnx: True
11621162
pure_nnx_decoder: True
11631163
pure_nnx: False
1164+
use_nnx_pipeline: False # Set to False to use native Linen pipeline (with custom VJP)
1165+
11641166

11651167
################################## Qwen3-Next Specific Configs ##################################
11661168
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/configs/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,12 @@ class PipelineParallelism(BaseModel):
941941
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
942942
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
943943
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")
944+
use_nnx_pipeline: bool = Field(
945+
False,
946+
description="When True, create_pipeline returns NNX pipeline wrapped in ToLinen. "
947+
"When False, create_pipeline returns native Linen pipeline (PipelineLinen/CircularPipelineLinen). "
948+
"Pure NNX decoders use create_nnx_pipeline directly.",
949+
)
944950

945951

946952
class RematAndOffload(BaseModel):

src/maxtext/layers/decoders.py

Lines changed: 106 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from maxtext.layers import mhc
3535
from maxtext.layers import normalizations
3636
from maxtext.layers import pipeline
37+
from maxtext.layers.nnx_decoders import NNXDecoderLayer, NNXSequentialPipelineStage, NNXScannedPipelineStage
3738
from maxtext.layers import quantizations
3839
from maxtext.layers.attentions import attention_as_linen
3940
from maxtext.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen
@@ -262,7 +263,7 @@ def __call__(
262263
page_state=page_state,
263264
)
264265
if self.config.scan_layers:
265-
inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None).
266+
inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None).
266267
if self.config.scan_layers:
267268
return inputs, None # pytype: disable=bad-return-type
268269
else:
@@ -307,11 +308,21 @@ def setup(self):
307308
self.decoder_layer = self.get_decoder_layers()
308309
self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim)
309310
if self.config.using_pipeline_parallelism:
310-
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
311311
remat_policy = self.get_remat_policy()
312-
self.pipeline_module = pipeline.create_pipeline(
313-
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
314-
)
312+
if self.config.use_nnx_pipeline:
313+
nnx_blocks = self._get_nnx_decoder_block_classes()
314+
315+
def stage_factory(rngs):
316+
return self._build_nnx_pipeline_stage(nnx_blocks, rngs)
317+
318+
self.pipeline_module = pipeline.create_pipeline(
319+
config=self.config, layers=stage_factory, mesh=self.mesh, remat_policy=remat_policy
320+
)
321+
else:
322+
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
323+
self.pipeline_module = pipeline.create_pipeline(
324+
config=self.config, layers=pipeline_stage_module, mesh=self.mesh, remat_policy=remat_policy
325+
)
315326

316327
def minimal_policy(self, with_context=False, with_quantization=False):
317328
"""Helper for creating minimal checkpoint policies."""
@@ -494,6 +505,44 @@ def get_decoder_layers(self):
494505
# Default case to handle any unknown decoder block types.
495506
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
496507

508+
def _get_nnx_decoder_block_classes(self):
509+
"""Returns NNX decoder block classes for pipeline stage creation."""
510+
cfg = self.config
511+
512+
def get_scannable(normal_cls, scannable_cls):
513+
return [scannable_cls] if cfg.scan_layers else [normal_cls]
514+
515+
def get_deepseek():
516+
if cfg.use_batch_split_schedule:
517+
return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
518+
return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
519+
520+
layer_map = {
521+
DecoderBlockType.DEFAULT: [NNXDecoderLayer],
522+
DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer],
523+
DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer],
524+
DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer],
525+
DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer],
526+
DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer],
527+
DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer],
528+
DecoderBlockType.GEMMA4: get_scannable(gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock),
529+
DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer],
530+
DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock),
531+
DecoderBlockType.QWEN2: [qwen2.Qwen2DecoderLayer],
532+
DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer],
533+
DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer],
534+
DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock),
535+
DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer],
536+
DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer],
537+
DecoderBlockType.DEEPSEEK: get_deepseek(),
538+
DecoderBlockType.LLAMA4: get_scannable(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock),
539+
DecoderBlockType.OLMO3: get_scannable(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock),
540+
}
541+
542+
if cfg.decoder_block not in layer_map:
543+
raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}")
544+
return layer_map[cfg.decoder_block]
545+
497546
def set_remat_policy(self, block_layers, policy):
498547
"""Set remat policy"""
499548
RemattedBlockLayers = []
@@ -522,6 +571,58 @@ def map_fn(path, value):
522571
RemattedBlockLayers.append(layer)
523572
return RemattedBlockLayers
524573

574+
def _build_nnx_pipeline_stage(self, decoder_blocks, rngs):
575+
"""Creates a single NNX pipeline stage module."""
576+
cfg = self.config
577+
base_stage_cls = decoder_blocks[1] if cfg.decoder_block == DecoderBlockType.DEEPSEEK else decoder_blocks[0]
578+
579+
if cfg.num_layers_per_pipeline_stage == 1:
580+
return base_stage_cls(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
581+
elif cfg.scan_layers_per_stage:
582+
return NNXScannedPipelineStage(
583+
base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs
584+
)
585+
return NNXSequentialPipelineStage(
586+
base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs
587+
)
588+
589+
def get_pipeline_stage_module(self, decoder_blocks):
590+
"""get pipeline stage module"""
591+
592+
def get_layer_to_pipeline(blocks, cfg):
593+
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
594+
return blocks[1] # return the sparse block
595+
else:
596+
return blocks[0]
597+
598+
cfg = self.config
599+
base_stage = get_layer_to_pipeline(decoder_blocks, cfg)
600+
if cfg.set_remat_policy_on_layers_per_stage:
601+
policy = self.get_remat_policy()
602+
base_stage = self.set_remat_policy([base_stage], policy)[0]
603+
if cfg.num_layers_per_pipeline_stage == 1:
604+
stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)
605+
elif cfg.scan_layers_per_stage:
606+
stage_module = self.scan_decoder_layers(
607+
cfg,
608+
base_stage,
609+
cfg.num_layers_per_pipeline_stage,
610+
"layers_per_stage",
611+
self.mesh,
612+
in_axes_tuple=(nn.broadcast,) * 4,
613+
model_mode=self.model_mode,
614+
)
615+
else:
616+
stage_module = SequentialBlockDecoderLayers(
617+
decoder_layer=base_stage,
618+
num_decoder_layers=cfg.num_layers_per_pipeline_stage,
619+
config=cfg,
620+
mesh=self.mesh,
621+
quant=self.quant,
622+
model_mode=self.model_mode,
623+
)
624+
return stage_module
625+
525626
def get_norm_layer(self, num_features: int):
526627
"""get normalization layer (return type inherits from nn.Module)"""
527628
if self.config.decoder_block in (
@@ -581,42 +682,6 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
581682
config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args
582683
)
583684

584-
def get_pipeline_stage_module(self, decoder_blocks):
585-
"""get pipeline stage module"""
586-
587-
def get_layer_to_pipeline(blocks, cfg):
588-
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
589-
return blocks[1] # return the sparse block
590-
else:
591-
return blocks[0]
592-
593-
cfg = self.config
594-
base_stage = get_layer_to_pipeline(decoder_blocks, cfg)
595-
if cfg.set_remat_policy_on_layers_per_stage:
596-
policy = self.get_remat_policy()
597-
base_stage = self.set_remat_policy([base_stage], policy)[0]
598-
if cfg.num_layers_per_pipeline_stage == 1:
599-
stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)
600-
elif cfg.scan_layers_per_stage:
601-
stage_module = self.scan_decoder_layers(
602-
cfg,
603-
base_stage,
604-
cfg.num_layers_per_pipeline_stage,
605-
"layers_per_stage",
606-
self.mesh,
607-
in_axes_tuple=(nn.broadcast,) * 4,
608-
)
609-
else:
610-
stage_module = SequentialBlockDecoderLayers(
611-
decoder_layer=base_stage,
612-
num_decoder_layers=cfg.num_layers_per_pipeline_stage,
613-
config=cfg,
614-
mesh=self.mesh,
615-
quant=self.quant,
616-
model_mode=self.model_mode,
617-
)
618-
return stage_module
619-
620685
@nn.compact
621686
def _apply_embedding(
622687
self,

src/maxtext/layers/nnx_wrappers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,22 @@ def current_linen_module() -> linen.Module | None:
170170
return None
171171

172172

173+
def is_linen_initializing() -> bool:
174+
"""Check if the current execution context is inside a Linen init() call.
175+
176+
Returns True when called from within a ``to_linen_class`` wrapper's
177+
``init()`` path. Uses :func:`current_linen_module` to access the Linen
178+
module stack (private API already used by this module).
179+
180+
This is used by NNX pipeline modules to short-circuit the full scan
181+
during Linen init, where only the output shape/dtype is needed.
182+
"""
183+
module = current_linen_module()
184+
if module is not None and hasattr(module, "is_initializing") and callable(module.is_initializing):
185+
return module.is_initializing()
186+
return False
187+
188+
173189
class ToNNX(Module):
174190
"""A wrapper to turn any Linen module into an NNX module.
175191

0 commit comments

Comments
 (0)