|
22 | 22 | import jax.numpy as jnp |
23 | 23 | from jax.sharding import Mesh |
24 | 24 | from maxtext.common.common_types import Config, MODEL_MODE_TRAIN |
| 25 | +from maxtext.layers.nnx_decoders import NNXDecoderLayer |
25 | 26 | from maxtext.utils.globals import EPS |
26 | | -from maxtext.layers import nnx_wrappers |
27 | 27 | from maxtext.layers.decoders import DecoderLayer |
28 | 28 | from maxtext.layers.initializers import variable_to_logically_partitioned |
29 | 29 | from maxtext.layers.linears import DenseGeneral |
@@ -70,7 +70,7 @@ def __init__( |
70 | 70 | config: Config, |
71 | 71 | mesh: Mesh, |
72 | 72 | layer_number: int, |
73 | | - transformer_layer_module: Type[DecoderLayer], |
| 73 | + transformer_layer_module: Type[NNXDecoderLayer], |
74 | 74 | *, |
75 | 75 | rngs: nnx.Rngs, |
76 | 76 | ): |
@@ -108,22 +108,12 @@ def __init__( |
108 | 108 | rngs=rngs, |
109 | 109 | ) |
110 | 110 | # Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically. |
111 | | - mtp_transformer_layer = transformer_layer_module( |
| 111 | + self.transformer_layer = transformer_layer_module( |
112 | 112 | config=cfg, |
113 | 113 | mesh=mesh, |
114 | 114 | model_mode=MODEL_MODE_TRAIN, |
115 | 115 | name=f"mtp_{k}_transformer_layer", |
116 | | - ) |
117 | | - self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs) |
118 | | - |
119 | | - # ToNNX requires explicit initialization with sample inputs for proper parameter setup. |
120 | | - batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=MODEL_MODE_TRAIN) |
121 | | - self.transformer_layer.lazy_init( |
122 | | - inputs=jnp.zeros((batch_size, seq_len, self.config.emb_dim), dtype=self.config.dtype), |
123 | | - decoder_segment_ids=None, |
124 | | - decoder_positions=jnp.zeros((batch_size, seq_len), dtype=jnp.int32), |
125 | | - deterministic=True, |
126 | | - model_mode=MODEL_MODE_TRAIN, |
| 116 | + rngs=rngs, |
127 | 117 | ) |
128 | 118 |
|
129 | 119 | @property |
@@ -212,7 +202,7 @@ def __init__( |
212 | 202 | self, |
213 | 203 | config: Config, |
214 | 204 | mesh: Mesh, |
215 | | - transformer_layer_module: Type[DecoderLayer], |
| 205 | + transformer_layer_module: Type[NNXDecoderLayer], |
216 | 206 | decoder: nnx.Module, |
217 | 207 | rngs: nnx.Rngs, |
218 | 208 | ): |
|
0 commit comments