Skip to content

Commit 37ded59

Browse files
PR #2831: Migrate Decoder to NNX
Imported from GitHub PR #2831 # Description Migrate the Transformer decoder layer into NNX. Note: The following models are currently not supported: - DeepSeek - Gemma3 - Llama4 Support for these models will be added in a follow-up PR. Strategy: A `pure_nnx_decoder` flag is added to control whether the NNX or Linen decoder shall be used. The initial migration doesn't include pipeline NNX support. # Tests Conducted these tests. Details in the [GDoc file](https://docs.google.com/document/d/1NbUP3g5glgbC6bMyt44pwM_vQA1NR7U2rBUzfbTDwSs/edit?pli=1&resourcekey=0-9EUahtzL-hCycdu7l0grhQ&tab=t.htq5367h8au0) 1. Test with different models and compare with Linen training 2. Golden logits comparison 3. Inference 4. Checkpoint comparison (including TreeStructure comparison) 5. Sharding comparison TODOs: - NNX version of unit tests (future PRs) # Checklist Before submitting this PR, please make sure (put X in square brackets): - [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label. - [x] I have necessary comments in my code, particularly in hard-to-understand areas. - [x] I have run end-to-end tests and provided workload links above if applicable. - [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files). Copybara import of the project: -- 073e916 by hsuan-lun-chiang <hsuan-lun.chiang@cienet.com>: Migrate Decoder to NNX Adding nnx_decoders.py in parallel with decoders.py 1. Duplicate and modify decoders.py into the new file nnx_decoders.py 2. Add a new config option pure_nnx_decoder to control whether the model will use NNXDecoder; defaults to false for now 3. Modify related code to accommodate the change 4.
add/modify unit test Merging this change closes #2831 COPYBARA_INTEGRATE_REVIEW=#2831 from CIeNET-International:feat/Migrate-Decoder-to-NNX 073e916 PiperOrigin-RevId: 884170982
1 parent ca7e2df commit 37ded59

7 files changed

Lines changed: 1746 additions & 53 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,7 @@ subslice_shape: ""
10871087

10881088
# NNX
10891089
enable_nnx: false
1090+
pure_nnx_decoder: false
10901091

10911092
################################## Qwen3-Next Specific Configs ##################################
10921093
# Kernel size for the 1D convolution in the Gated Delta Net
@@ -1152,4 +1153,4 @@ distill_temperature: 1.0
11521153
# distill_beta is used for cosine similarity loss between intermediate activations of out_proj in teacher/student models.
11531154
# 0.0 value disables this feature.
11541155
distill_beta: 0.0
1155-
distill_layer_indices: None
1156+
distill_layer_indices: None

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,7 @@ class HardwareAndMesh(BaseModel):
783783
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
784784
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
785785
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
786+
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
786787

787788

788789
class LayoutAndSharding(BaseModel):

src/maxtext/layers/multi_token_prediction.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import jax.numpy as jnp
2323
from jax.sharding import Mesh
2424
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN
25+
from maxtext.layers.nnx_decoders import NNXDecoderLayer
2526
from maxtext.utils.globals import EPS
26-
from maxtext.layers import nnx_wrappers
2727
from maxtext.layers.decoders import DecoderLayer
2828
from maxtext.layers.initializers import variable_to_logically_partitioned
2929
from maxtext.layers.linears import DenseGeneral
@@ -70,7 +70,7 @@ def __init__(
7070
config: Config,
7171
mesh: Mesh,
7272
layer_number: int,
73-
transformer_layer_module: Type[DecoderLayer],
73+
transformer_layer_module: Type[NNXDecoderLayer],
7474
*,
7575
rngs: nnx.Rngs,
7676
):
@@ -108,22 +108,12 @@ def __init__(
108108
rngs=rngs,
109109
)
110110
# Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically.
111-
mtp_transformer_layer = transformer_layer_module(
111+
self.transformer_layer = transformer_layer_module(
112112
config=cfg,
113113
mesh=mesh,
114114
model_mode=MODEL_MODE_TRAIN,
115115
name=f"mtp_{k}_transformer_layer",
116-
)
117-
self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
118-
119-
# ToNNX requires explicit initialization with sample inputs for proper parameter setup.
120-
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN)
121-
self.transformer_layer.lazy_init(
122-
inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype),
123-
decoder_segment_ids=None,
124-
decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),
125-
deterministic=True,
126-
model_mode=MODEL_MODE_TRAIN,
116+
rngs=rngs,
127117
)
128118

129119
@property
@@ -212,7 +202,7 @@ def __init__(
212202
self,
213203
config: Config,
214204
mesh: Mesh,
215-
transformer_layer_module: Type[DecoderLayer],
205+
transformer_layer_module: Type[NNXDecoderLayer],
216206
decoder: nnx.Module,
217207
rngs: nnx.Rngs,
218208
):

0 commit comments

Comments
 (0)