Skip to content

Commit 21c4433

Browse files
Merge pull request #3670 from AI-Hypercomputer:parambole/502806272
PiperOrigin-RevId: 900869830
2 parents (9c68b1a + 82798d8), commit 21c4433

3 files changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxtext/layers/multi_token_prediction.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,6 @@ def __init__(
7979
self.layer_number = layer_number
8080
self.transformer_layer_module = transformer_layer_module
8181
self.rngs = rngs
82-
k = layer_number
8382
cfg = self.config
8483

8584
self.embedding_norm = RMSNorm(
@@ -112,7 +111,6 @@ def __init__(
112111
config=cfg,
113112
mesh=mesh,
114113
model_mode=MODEL_MODE_TRAIN,
115-
name=f"mtp_{k}_transformer_layer",
116114
rngs=rngs,
117115
)
118116

src/maxtext/layers/nnx_wrappers.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -627,5 +627,6 @@ def __init_subclass__(cls, **kwargs):
627627
ToLinenPartial.__qualname__ = class_name
628628

629629
ToLinenPartial.__init__ = __init__
630+
ToLinenPartial.module_class = base_nnx_class
630631

631632
return ToLinenPartial

src/maxtext/models/models.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -93,14 +93,17 @@ def setup(self):
9393
# If MTP is enabled via config, set up the MTP block.
9494
if self.config.mtp_num_layers > 0:
9595
# Get the list of layer blueprints for the current model.
96-
layer_types = self.decoder.get_decoder_layers()
9796
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
9897
# By convention, this is the last layer in the list.
99-
mtp_layer = layer_types[-1]
98+
layer_types = self.decoder.get_decoder_layers()
99+
mtp_layer_linen = layer_types[-1]
100+
# UNWRAP: The MTP block is pure NNX. If the decoder returned a Linen wrapper,
101+
# extract the native NNX class to preserve parameter tracing/scoping.
102+
mtp_layer_nnx = getattr(mtp_layer_linen, "module_class", mtp_layer_linen)
100103
self.mtp_block = multi_token_prediction_block_as_linen(
101104
config=self.config,
102105
mesh=self.mesh,
103-
transformer_layer_module=mtp_layer,
106+
transformer_layer_module=mtp_layer_nnx,
104107
decoder=self.decoder,
105108
rngs=self.make_rng("mtp_block"),
106109
)

0 commit comments

Comments (0)