2828from MaxText .common_types import Config , MODEL_MODE_TRAIN , MODEL_MODE_AUTOREGRESSIVE , DECODING_ACTIVE_SEQUENCE_INDICATOR
2929from MaxText .layers import nnx_wrappers
3030from MaxText .layers .decoders import Decoder
31+ from MaxText .layers .nnx_decoders import NNXDecoder , decoder_as_linen
3132from MaxText .layers .embeddings import Embed , embed_as_linen
3233from MaxText .layers .encoders import VisionEncoder , vision_encoder_as_linen , AudioEncoder , audio_encoder_as_linen
3334from MaxText .layers .quantizations import AqtQuantization as Quant
@@ -86,7 +87,13 @@ def setup(self):
8687 )
8788 self .vision_encoder = vision_encoder_as_linen (config = cfg , mesh = mesh ) if cfg .use_multimodal else None
8889 self .audio_encoder = audio_encoder_as_linen (config = cfg , mesh = mesh ) if cfg .use_audio else None
89- self .decoder = Decoder (config = cfg , mesh = mesh , quant = self .quant , model_mode = self .model_mode )
90+ if cfg .pure_nnx_decoder :
91+ self .decoder = decoder_as_linen (
92+ config = cfg , mesh = mesh , quant = self .quant , model_mode = self .model_mode , rngs = nnx .Rngs (0 )
93+ )
94+ else :
95+ self .decoder = Decoder (config = cfg , mesh = mesh , quant = self .quant , model_mode = self .model_mode )
96+
9097 # If MTP is enabled via config, set up the MTP block.
9198 if self .config .mtp_num_layers > 0 :
9299 # Get the list of layer blueprints for the current model.
@@ -325,9 +332,11 @@ def __init__(
325332 )
326333 self .vision_encoder = VisionEncoder (config = cfg , mesh = mesh , rngs = rngs ) if cfg .use_multimodal else None
327334 self .audio_encoder = AudioEncoder (config = cfg , mesh = mesh , rngs = rngs ) if cfg .use_audio else None
328-
329- decoder_linen = Decoder (config = cfg , mesh = mesh , quant = self .quant , model_mode = self .model_mode )
330- self .decoder = nnx_wrappers .ToNNX (decoder_linen , rngs = rngs )
335+ if cfg .pure_nnx_decoder :
336+ self .decoder = NNXDecoder (config = cfg , mesh = mesh , quant = self .quant , model_mode = self .model_mode , rngs = rngs )
337+ else :
338+ self .decoder = Decoder (config = cfg , mesh = mesh , quant = self .quant , model_mode = self .model_mode )
339+ self .decoder = nnx_wrappers .ToNNX (self .decoder , rngs = rngs )
331340 self .hidden_states = None
332341
333342 batch_size , seq_len = max_utils .get_batch_seq_len_for_mode (config = cfg , model_mode = model_mode )
@@ -353,12 +362,13 @@ def __init__(
353362 else :
354363 dummy_attention_metadata = None
355364
356- self .decoder .lazy_init (
357- shared_embedding = self .token_embedder ,
358- decoder_input_tokens = dummy_decoder_input_tokens ,
359- decoder_positions = dummy_decoder_positions ,
360- attention_metadata = dummy_attention_metadata ,
361- )
365+ if not cfg .pure_nnx_decoder :
366+ self .decoder .lazy_init (
367+ shared_embedding = self .token_embedder ,
368+ decoder_input_tokens = dummy_decoder_input_tokens ,
369+ decoder_positions = dummy_decoder_positions ,
370+ attention_metadata = dummy_attention_metadata ,
371+ )
362372
363373 # If MTP is enabled via config, set up the MTP block.
364374 if self .config .mtp_num_layers > 0 :
0 commit comments