@@ -168,3 +168,52 @@ def test_convert_state_dict():
168168 te_state_dict_keys .remove ("decoder.bias" )
169169
170170 assert len (te_state_dict_keys ) == 0
171+
172+
def test_hf_trained_model_loss(input_data):
    """Check the loss of the pretrained HF AMPLIFY model against a fixed value.

    Loads ``chandar-lab/AMPLIFY_120M`` through ``amp_hf.AMPLIFY.from_pretrained``,
    moves the model to CUDA in bfloat16, runs a single forward pass under
    bfloat16 autocast, and asserts that ``output.loss`` is close to the
    hard-coded value 2.4 (``atol=1e-1``, ``rtol=1e-2``).

    Args:
        input_data: Mapping of forward-pass keyword names to values that
            support ``.to("cuda")`` (presumably a pytest fixture of tensors
            that includes labels, since ``output.loss`` is read -- confirm
            against the fixture definition).
    """
    model = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
    model.to("cuda", dtype=torch.bfloat16)
    model.eval()

    # Build a device-side copy instead of rebinding the fixture argument.
    batch = {name: value.to("cuda") for name, value in input_data.items()}

    # Evaluation only: no_grad keeps autograd from recording the forward pass.
    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**batch)

    torch.testing.assert_close(output.loss.detach().cpu(), torch.tensor(2.4), atol=1e-1, rtol=1e-2)
182+
183+
def test_te_trained_model_loss(input_data):
    """Check the loss of the converted (TE) pretrained model against a fixed value.

    Loads ``chandar-lab/AMPLIFY_120M`` through ``amp_hf.AMPLIFY.from_pretrained``,
    converts it with ``convert_amplify_hf_to_te``, moves the result to CUDA in
    bfloat16, runs a single forward pass under bfloat16 autocast, and asserts
    that ``output.loss`` is close to the hard-coded value 2.4
    (``atol=1e-1``, ``rtol=1e-2``).

    Args:
        input_data: Mapping of forward-pass keyword names to values that
            support ``.to("cuda")`` (presumably a pytest fixture of tensors
            that includes labels, since ``output.loss`` is read -- confirm
            against the fixture definition).
    """
    model_hf = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
    model = convert_amplify_hf_to_te(model_hf)
    model.to("cuda", dtype=torch.bfloat16)
    model.eval()

    # Build a device-side copy instead of rebinding the fixture argument.
    batch = {name: value.to("cuda") for name, value in input_data.items()}

    # Evaluation only: no_grad keeps autograd from recording the forward pass.
    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**batch)

    torch.testing.assert_close(output.loss.detach().cpu(), torch.tensor(2.4), atol=1e-1, rtol=1e-2)
194+
195+
def test_hf_reinitialized_model_loss(input_data):
    """Check that a freshly constructed HF AMPLIFY model yields a loss below 3.5.

    Only the configuration is loaded from ``chandar-lab/AMPLIFY_120M``; the
    model is built with ``amp_hf.AMPLIFY(config)`` rather than
    ``from_pretrained``, so the pretrained weights are not loaded here. The
    model is moved to CUDA in bfloat16 and run once under bfloat16 autocast.

    NOTE(review): no RNG seed is set in this test; if the initial weights are
    random, the loss may vary from run to run -- confirm whether a fixture
    seeds torch.

    Args:
        input_data: Mapping of forward-pass keyword names to values that
            support ``.to("cuda")`` (presumably a pytest fixture of tensors
            that includes labels, since ``output.loss`` is read -- confirm
            against the fixture definition).
    """
    config = amp_hf.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
    model = amp_hf.AMPLIFY(config)
    model.to("cuda", dtype=torch.bfloat16)
    model.eval()

    # Build a device-side copy instead of rebinding the fixture argument.
    batch = {name: value.to("cuda") for name, value in input_data.items()}

    # Evaluation only: no_grad keeps autograd from recording the forward pass.
    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**batch)

    loss = output.loss.detach().cpu()
    assert loss < 3.5, f"Loss is {loss} , expected less than 3.5"
207+
208+
def test_te_reinitialized_model_loss(input_data):
    """Check that a freshly constructed TE AMPLIFY model yields a loss below 3.5.

    Only the configuration is loaded from ``chandar-lab/AMPLIFY_120M`` (via
    ``amp_te.AMPLIFYConfig``); the model is built with
    ``amp_te.AMPLIFYForMaskedLM(config)`` rather than ``from_pretrained``, so
    the pretrained weights are not loaded here. The model is moved to CUDA in
    bfloat16 and run once under bfloat16 autocast.

    NOTE(review): no RNG seed is set in this test; if the initial weights are
    random, the loss may vary from run to run -- confirm whether a fixture
    seeds torch.

    Args:
        input_data: Mapping of forward-pass keyword names to values that
            support ``.to("cuda")`` (presumably a pytest fixture of tensors
            that includes labels, since ``output.loss`` is read -- confirm
            against the fixture definition).
    """
    config = amp_te.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
    model = amp_te.AMPLIFYForMaskedLM(config)
    model.to("cuda", dtype=torch.bfloat16)
    model.eval()

    # Build a device-side copy instead of rebinding the fixture argument.
    batch = {name: value.to("cuda") for name, value in input_data.items()}

    # Evaluation only: no_grad keeps autograd from recording the forward pass.
    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**batch)

    loss = output.loss.detach().cpu()
    assert loss < 3.5, f"Loss is {loss} , expected less than 3.5"
0 commit comments