diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 8e20c01bba..0d1fcab700 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -335,8 +335,8 @@ def __init__( if cfg.pure_nnx_decoder: self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs) else: - self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) - self.decoder = nnx_wrappers.ToNNX(self.decoder, rngs=rngs) + decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) self.hidden_states = None batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode) @@ -529,7 +529,7 @@ def __call__( attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, mutable=mutable_collections, - ) # pytype: disable=wrong-keyword-args + ) # pytype: disable=wrong-keyword-args # Materialize hidden state when vocab tiling is enabled if self.config.num_vocab_tiling > 1: