@@ -147,17 +147,15 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
147147 config .padded_vocab_size ,
148148 config .hidden_size ,
149149 padding_idx = config .pad_token_id ,
150- dtype = config .torch_dtype ,
150+ dtype = config .dtype ,
151151 )
152152
153153 if config .layer_norm_after_embedding :
154154 self .layer_norm_1 = (
155- transformer_engine .pytorch .RMSNorm (
156- config .hidden_size , config .norm_eps , params_dtype = config .torch_dtype
157- )
155+ transformer_engine .pytorch .RMSNorm (config .hidden_size , config .norm_eps , params_dtype = config .dtype )
158156 if config .rms_norm
159157 else transformer_engine .pytorch .LayerNorm (
160- config .hidden_size , config .norm_eps , params_dtype = config .torch_dtype
158+ config .hidden_size , config .norm_eps , params_dtype = config .dtype
161159 )
162160 )
163161
@@ -169,6 +167,9 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
169167 intermediate_size = int (2 * config .intermediate_size / 3 )
170168 intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1 ) // multiple_of )
171169
170+ else :
171+ intermediate_size = config .intermediate_size
172+
172173 self .transformer_encoder = nn .ModuleList ()
173174 for layer_num in range (config .num_hidden_layers ):
174175 self .transformer_encoder .append (
@@ -194,7 +195,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
194195 window_size = (- 1 , - 1 ),
195196 rotary_pos_interleaved = True ,
196197 seq_length = config .max_length ,
197- params_dtype = config .torch_dtype ,
198+ params_dtype = config .dtype ,
198199 )
199200 )
200201
@@ -212,7 +213,6 @@ def forward(
212213 output_hidden_states = False ,
213214 output_attentions = False ,
214215 labels = None ,
215- ** kwargs ,
216216 ) -> BaseModelOutput :
217217 """Forward pass of the AMPLIFY model.
218218
@@ -222,7 +222,6 @@ def forward(
222222 output_hidden_states (bool): Whether to output the hidden states.
223223 output_attentions (bool): Whether to output the attention weights.
224224 labels (torch.Tensor): The labels.
225- **kwargs: Additional arguments.
226225
227226 Returns:
228227 BaseModelOutput: The output of the model.
@@ -277,7 +276,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
277276 config .hidden_size ,
278277 config .padded_vocab_size ,
279278 config .norm_eps ,
280- params_dtype = config .torch_dtype ,
279+ params_dtype = config .dtype ,
281280 normalization = "RMSNorm" if config .rms_norm else "LayerNorm" ,
282281 init_method = lambda x : torch .nn .init .uniform_ (
283282 x , - self .config .decoder_init_range , self .config .decoder_init_range
@@ -286,7 +285,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
286285
287286 else :
288287 self .decoder = transformer_engine .pytorch .Linear (
289- config .hidden_size , config .vocab_size , params_dtype = config .torch_dtype
288+ config .hidden_size , config .vocab_size , params_dtype = config .dtype
290289 )
291290
292291 def forward (
@@ -296,7 +295,6 @@ def forward(
296295 output_hidden_states = False ,
297296 output_attentions = False ,
298297 labels = None ,
299- ** kwargs ,
300298 ) -> MaskedLMOutput :
301299 """Forward pass of the AMPLIFYForMaskedLM model.
302300
@@ -306,7 +304,6 @@ def forward(
306304 output_hidden_states (bool): Whether to output the hidden states.
307305 output_attentions (bool): Whether to output the attention weights.
308306 labels (torch.Tensor): The labels.
309- **kwargs: Additional arguments.
310307
311308 Returns:
312309 MaskedLMOutput: The output of the model.
@@ -317,7 +314,6 @@ def forward(
317314 output_hidden_states ,
318315 output_attentions ,
319316 labels ,
320- ** kwargs ,
321317 )
322318
323319 # Classification head with layer norm
0 commit comments