Skip to content

Commit d16085c

Browse files
committed
Merge branch 'main' of github.com:AI-Hypercomputer/maxtext into shuningjin-qwix1
2 parents cd2d86f + 37ded59 commit d16085c

7 files changed

Lines changed: 1746 additions & 53 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,7 @@ subslice_shape: ""
10871087

10881088
# NNX
10891089
enable_nnx: false
1090+
pure_nnx_decoder: false
10901091

10911092
################################## Qwen3-Next Specific Configs ##################################
10921093
# Kernel size for the 1D convolution in the Gated Delta Net
@@ -1152,4 +1153,4 @@ distill_temperature: 1.0
11521153
# distill_beta is used for cosine similarity loss between intermediate activations of out_proj in teacher/student models.
11531154
# 0.0 value disables this feature.
11541155
distill_beta: 0.0
1155-
distill_layer_indices: None
1156+
distill_layer_indices: None

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,7 @@ class HardwareAndMesh(BaseModel):
783783
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
784784
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
785785
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
786+
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
786787

787788

788789
class LayoutAndSharding(BaseModel):

src/maxtext/layers/multi_token_prediction.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import jax.numpy as jnp
2323
from jax.sharding import Mesh
2424
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN
25+
from maxtext.layers.nnx_decoders import NNXDecoderLayer
2526
from maxtext.utils.globals import EPS
26-
from maxtext.layers import nnx_wrappers
2727
from maxtext.layers.decoders import DecoderLayer
2828
from maxtext.layers.initializers import variable_to_logically_partitioned
2929
from maxtext.layers.linears import DenseGeneral
@@ -70,7 +70,7 @@ def __init__(
7070
config: Config,
7171
mesh: Mesh,
7272
layer_number: int,
73-
transformer_layer_module: Type[DecoderLayer],
73+
transformer_layer_module: Type[NNXDecoderLayer],
7474
*,
7575
rngs: nnx.Rngs,
7676
):
@@ -108,22 +108,12 @@ def __init__(
108108
rngs=rngs,
109109
)
110110
# Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically.
111-
mtp_transformer_layer = transformer_layer_module(
111+
self.transformer_layer = transformer_layer_module(
112112
config=cfg,
113113
mesh=mesh,
114114
model_mode=MODEL_MODE_TRAIN,
115115
name=f"mtp_{k}_transformer_layer",
116-
)
117-
self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs)
118-
119-
# ToNNX requires explicit initialization with sample inputs for proper parameter setup.
120-
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN)
121-
self.transformer_layer.lazy_init(
122-
inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype),
123-
decoder_segment_ids=None,
124-
decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),
125-
deterministic=True,
126-
model_mode=MODEL_MODE_TRAIN,
116+
rngs=rngs,
127117
)
128118

129119
@property
@@ -212,7 +202,7 @@ def __init__(
212202
self,
213203
config: Config,
214204
mesh: Mesh,
215-
transformer_layer_module: Type[DecoderLayer],
205+
transformer_layer_module: Type[NNXDecoderLayer],
216206
decoder: nnx.Module,
217207
rngs: nnx.Rngs,
218208
):

0 commit comments

Comments
 (0)