2323from conftest import requires_fp8
2424from transformer_engine .common .recipe import DelayedScaling , Format
2525
26- import amplify .amplify_hf as amp_hf
2726import amplify .amplify_te as amp_te
2827from amplify .state_dict_convert import convert_amplify_hf_to_te
2928
3029
30+ try :
31+ import xformers
32+ except ImportError :
33+ xformers = None
34+
35+ if xformers is not None :
36+ import amplify .amplify_hf as amp_hf
37+ else :
38+ amp_hf = None
39+
40+
41+ @pytest .mark .skipif (amp_hf is None , reason = "xformers is not installed" )
3142def test_amplify_hf_model (config , input_data ):
3243 model = amp_hf .AMPLIFY (config )
3344 model .to ("cuda" )
@@ -67,6 +78,7 @@ def test_te_model_has_all_te_layers(config):
6778 assert not isinstance (module , nn .RMSNorm ), f"Vanilla RMSNorm layer found in { name } "
6879
6980
81+ @pytest .mark .skipif (amp_hf is None , reason = "xformers is not installed" )
7082def test_models_have_identical_outputs (input_data ):
7183 model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
7284 model_te = convert_amplify_hf_to_te (model_hf )
@@ -84,6 +96,7 @@ def test_models_have_identical_outputs(input_data):
8496 torch .testing .assert_close (outputs_hf .loss , outputs_te .loss , atol = 1e-2 , rtol = 1e-3 )
8597
8698
99+ @pytest .mark .skipif (amp_hf is None , reason = "xformers is not installed" )
87100def test_converted_model_roundtrip (input_data , tmp_path ):
88101 model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
89102 model_te = convert_amplify_hf_to_te (model_hf )
@@ -107,6 +120,7 @@ def test_converted_model_roundtrip(input_data, tmp_path):
107120 torch .testing .assert_close (outputs_hf .loss , outputs_te .loss , atol = 1e-2 , rtol = 1e-3 )
108121
109122
123+ @pytest .mark .skipif (amp_hf is None , reason = "xformers is not installed" )
110124def test_convert_state_dict ():
111125 model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
112126 model_te = convert_amplify_hf_to_te (model_hf )
@@ -168,3 +182,52 @@ def test_convert_state_dict():
168182 te_state_dict_keys .remove ("decoder.bias" )
169183
170184 assert len (te_state_dict_keys ) == 0
185+
186+
187+ def test_hf_trained_model_loss (input_data ):
188+ model = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
189+ model .to ("cuda" , dtype = torch .bfloat16 )
190+ input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
191+ model .eval ()
192+ with torch .amp .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
193+ output = model (** input_data )
194+
195+ torch .testing .assert_close (output .loss .detach ().cpu (), torch .tensor (2.4 ), atol = 1e-1 , rtol = 1e-2 )
196+
197+
198+ def test_te_trained_model_loss (input_data ):
199+ model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
200+ model = convert_amplify_hf_to_te (model_hf )
201+ model .to ("cuda" , dtype = torch .bfloat16 )
202+ input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
203+ model .eval ()
204+ with torch .amp .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
205+ output = model (** input_data )
206+
207+ torch .testing .assert_close (output .loss .detach ().cpu (), torch .tensor (2.4 ), atol = 1e-1 , rtol = 1e-2 )
208+
209+
210+ def test_hf_reinitialized_model_loss (input_data ):
211+ config = amp_hf .AMPLIFYConfig .from_pretrained ("chandar-lab/AMPLIFY_120M" )
212+ model = amp_hf .AMPLIFY (config )
213+ model .to ("cuda" , dtype = torch .bfloat16 )
214+ input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
215+ model .eval ()
216+ with torch .amp .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
217+ output = model (** input_data )
218+
219+ loss = output .loss .detach ().cpu ()
220+ assert loss < 3.5 , f"Loss is { loss } , expected less than 3.5"
221+
222+
223+ def test_te_reinitialized_model_loss (input_data ):
224+ config = amp_te .AMPLIFYConfig .from_pretrained ("chandar-lab/AMPLIFY_120M" )
225+ model = amp_te .AMPLIFYForMaskedLM (config )
226+ model .to ("cuda" , dtype = torch .bfloat16 )
227+ input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
228+ model .eval ()
229+ with torch .amp .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
230+ output = model (** input_data )
231+
232+ loss = output .loss .detach ().cpu ()
233+ assert loss < 3.5 , f"Loss is { loss } , expected less than 3.5"
0 commit comments