@@ -147,17 +147,15 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
147147 config .padded_vocab_size ,
148148 config .hidden_size ,
149149 padding_idx = config .pad_token_id ,
150- dtype = config .torch_dtype ,
150+ dtype = config .dtype ,
151151 )
152152
153153 if config .layer_norm_after_embedding :
154154 self .layer_norm_1 = (
155- transformer_engine .pytorch .RMSNorm (
156- config .hidden_size , config .norm_eps , params_dtype = config .torch_dtype
157- )
155+ transformer_engine .pytorch .RMSNorm (config .hidden_size , config .norm_eps , params_dtype = config .dtype )
158156 if config .rms_norm
159157 else transformer_engine .pytorch .LayerNorm (
160- config .hidden_size , config .norm_eps , params_dtype = config .torch_dtype
158+ config .hidden_size , config .norm_eps , params_dtype = config .dtype
161159 )
162160 )
163161
@@ -197,7 +195,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
197195 window_size = (- 1 , - 1 ),
198196 rotary_pos_interleaved = True ,
199197 seq_length = config .max_length ,
200- params_dtype = config .torch_dtype ,
198+ params_dtype = config .dtype ,
201199 )
202200 )
203201
@@ -278,7 +276,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
278276 config .hidden_size ,
279277 config .padded_vocab_size ,
280278 config .norm_eps ,
281- params_dtype = config .torch_dtype ,
279+ params_dtype = config .dtype ,
282280 normalization = "RMSNorm" if config .rms_norm else "LayerNorm" ,
283281 init_method = lambda x : torch .nn .init .uniform_ (
284282 x , - self .config .decoder_init_range , self .config .decoder_init_range
@@ -287,7 +285,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
287285
288286 else :
289287 self .decoder = transformer_engine .pytorch .Linear (
290- config .hidden_size , config .vocab_size , params_dtype = config .torch_dtype
288+ config .hidden_size , config .vocab_size , params_dtype = config .dtype
291289 )
292290
293291 def forward (
0 commit comments