Skip to content

Commit de4ec11

Browse files
Migrate Decoder (Gemma3/Deepseek/Llama4) and utils to NNX
1 parent c24d321 commit de4ec11

6 files changed

Lines changed: 1121 additions & 35 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,8 @@ position_id_per_seconds: 25
10741074
subslice_shape: ""
10751075

10761076
# NNX
1077-
enable_nnx: false
1077+
enable_nnx: true
1078+
pure_nnx_decoder: true
10781079

10791080
################################## Qwen3-Next Specific Configs ##################################
10801081
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,7 @@ class HardwareAndMesh(BaseModel):
777777
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
778778
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
779779
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
780+
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
780781

781782

782783
class LayoutAndSharding(BaseModel):

src/maxtext/layers/multi_token_prediction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ def __init__(
113113
mesh=mesh,
114114
model_mode=MODEL_MODE_TRAIN,
115115
name=f"mtp_{k}_transformer_layer",
116+
rngs=rngs,
116117
)
118+
117119
self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
118120

119121
# ToNNX requires explicit initialization with sample inputs for proper parameter setup.

0 commit comments

Comments (0)