Skip to content

Commit efce27d

Browse files
Migrate Decoder to NNX
1 parent 95ef3e1 commit efce27d

5 files changed

Lines changed: 865 additions & 16 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,7 @@ subslice_shape: ""
10351035

10361036
# NNX
10371037
enable_nnx: false
1038+
pure_nnx_decoder: false
10381039

10391040
################################## Qwen3-Next Specific Configs ##################################
10401041
# Kernel size for the 1D convolution in the Gated Delta Net

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,7 @@ class HardwareAndMesh(BaseModel):
753753
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
754754
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
755755
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
756+
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
756757

757758

758759
class LayoutAndSharding(BaseModel):

src/MaxText/layers/models.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from MaxText.common_types import Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR
2929
from MaxText.layers import nnx_wrappers
3030
from MaxText.layers.decoders import Decoder
31+
from MaxText.layers.nnx_decoders import NNXDecoder, decoder_as_linen
3132
from MaxText.layers.embeddings import Embed, embed_as_linen
3233
from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen, AudioEncoder, audio_encoder_as_linen
3334
from MaxText.layers.quantizations import AqtQuantization as Quant
@@ -86,7 +87,13 @@ def setup(self):
8687
)
8788
self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None
8889
self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None
89-
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
90+
if cfg.pure_nnx_decoder:
91+
self.decoder = decoder_as_linen(
92+
config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=nnx.Rngs(0)
93+
)
94+
else:
95+
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
96+
9097
# If MTP is enabled via config, set up the MTP block.
9198
if self.config.mtp_num_layers > 0:
9299
# Get the list of layer blueprints for the current model.
@@ -325,9 +332,11 @@ def __init__(
325332
)
326333
self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None
327334
self.audio_encoder = AudioEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_audio else None
328-
329-
decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
330-
self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
335+
if cfg.pure_nnx_decoder:
336+
self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
337+
else:
338+
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
339+
self.decoder = nnx_wrappers.ToNNX(self.decoder, rngs=rngs)
331340
self.hidden_states = None
332341

333342
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
@@ -353,12 +362,13 @@ def __init__(
353362
else:
354363
dummy_attention_metadata = None
355364

356-
self.decoder.lazy_init(
357-
shared_embedding=self.token_embedder,
358-
decoder_input_tokens=dummy_decoder_input_tokens,
359-
decoder_positions=dummy_decoder_positions,
360-
attention_metadata=dummy_attention_metadata,
361-
)
365+
if not cfg.pure_nnx_decoder:
366+
self.decoder.lazy_init(
367+
shared_embedding=self.token_embedder,
368+
decoder_input_tokens=dummy_decoder_input_tokens,
369+
decoder_positions=dummy_decoder_positions,
370+
attention_metadata=dummy_attention_metadata,
371+
)
362372

363373
# If MTP is enabled via config, set up the MTP block.
364374
if self.config.mtp_num_layers > 0:

src/MaxText/layers/multi_token_prediction.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,22 @@ def __init__(
109109
rngs=rngs,
110110
)
111111
# Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically.
112-
mtp_transformer_layer = transformer_layer_module(
113-
config=cfg,
114-
mesh=mesh,
115-
model_mode=MODEL_MODE_TRAIN,
116-
name=f"mtp_{k}_transformer_layer",
117-
)
112+
if cfg.pure_nnx_decoder:
113+
mtp_transformer_layer = transformer_layer_module(
114+
config=cfg,
115+
mesh=mesh,
116+
model_mode=MODEL_MODE_TRAIN,
117+
name=f"mtp_{k}_transformer_layer",
118+
rngs=rngs,
119+
)
120+
else:
121+
mtp_transformer_layer = transformer_layer_module(
122+
config=cfg,
123+
mesh=mesh,
124+
model_mode=MODEL_MODE_TRAIN,
125+
name=f"mtp_{k}_transformer_layer",
126+
)
127+
118128
self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
119129

120130
# ToNNX requires explicit initialization with sample inputs for proper parameter setup.

0 commit comments

Comments (0)