@@ -147,17 +147,15 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
147147 config .padded_vocab_size ,
148148 config .hidden_size ,
149149 padding_idx = config .pad_token_id ,
150- dtype = config .torch_dtype ,
150+ dtype = config .dtype ,
151151 )
152152
153153 if config .layer_norm_after_embedding :
154154 self .layer_norm_1 = (
155- transformer_engine .pytorch .RMSNorm (
156- config .hidden_size , config .norm_eps , params_dtype = config .torch_dtype
157- )
155+ transformer_engine .pytorch .RMSNorm (config .hidden_size , config .norm_eps , params_dtype = config .dtype )
158156 if config .rms_norm
159157 else transformer_engine .pytorch .LayerNorm (
160- config .hidden_size , config .norm_eps , params_dtype = config .torch_dtype
158+ config .hidden_size , config .norm_eps , params_dtype = config .dtype
161159 )
162160 )
163161
@@ -194,7 +192,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
194192 window_size = (- 1 , - 1 ),
195193 rotary_pos_interleaved = True ,
196194 seq_length = config .max_length ,
197- params_dtype = config .torch_dtype ,
195+ params_dtype = config .dtype ,
198196 )
199197 )
200198
@@ -277,7 +275,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
277275 config .hidden_size ,
278276 config .padded_vocab_size ,
279277 config .norm_eps ,
280- params_dtype = config .torch_dtype ,
278+ params_dtype = config .dtype ,
281279 normalization = "RMSNorm" if config .rms_norm else "LayerNorm" ,
282280 init_method = lambda x : torch .nn .init .uniform_ (
283281 x , - self .config .decoder_init_range , self .config .decoder_init_range
@@ -286,7 +284,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
286284
287285 else :
288286 self .decoder = transformer_engine .pytorch .Linear (
289- config .hidden_size , config .vocab_size , params_dtype = config .torch_dtype
287+ config .hidden_size , config .vocab_size , params_dtype = config .dtype
290288 )
291289
292290 def forward (
0 commit comments