2323from conftest import requires_fp8
2424from transformer_engine .common .recipe import DelayedScaling , Format
2525
26+ import amplify .amplify_hf as amp_hf
2627import amplify .amplify_te as amp_te
2728from amplify .state_dict_convert import convert_amplify_hf_to_te
2829
2930
30- try :
31- import xformers
32- except ImportError :
33- xformers = None
34-
35- if xformers is not None :
36- import amplify .amplify_hf as amp_hf
37- else :
38- amp_hf = None
39-
40-
41- @pytest .mark .skipif (amp_hf is None , reason = "xformers is not installed" )
4231def test_amplify_hf_model (config , input_data ):
4332 model = amp_hf .AMPLIFY (config )
4433 model .to ("cuda" )
@@ -78,7 +67,6 @@ def test_te_model_has_all_te_layers(config):
7867 assert not isinstance (module , nn .RMSNorm ), f"Vanilla RMSNorm layer found in { name } "
7968
8069
81- @pytest .mark .skipif (amp_hf is None , reason = "xformers is not installed" )
8270def test_models_have_identical_outputs (input_data ):
8371 model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
8472 model_te = convert_amplify_hf_to_te (model_hf )
@@ -96,7 +84,6 @@ def test_models_have_identical_outputs(input_data):
9684 torch .testing .assert_close (outputs_hf .loss , outputs_te .loss , atol = 1e-2 , rtol = 1e-3 )
9785
9886
99- @pytest .mark .skipif (amp_hf is None , reason = "xformers is not installed" )
10087def test_converted_model_roundtrip (input_data , tmp_path ):
10188 model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
10289 model_te = convert_amplify_hf_to_te (model_hf )
@@ -120,7 +107,6 @@ def test_converted_model_roundtrip(input_data, tmp_path):
120107 torch .testing .assert_close (outputs_hf .loss , outputs_te .loss , atol = 1e-2 , rtol = 1e-3 )
121108
122109
123- @pytest .mark .skipif (amp_hf is None , reason = "xformers is not installed" )
124110def test_convert_state_dict ():
125111 model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
126112 model_te = convert_amplify_hf_to_te (model_hf )
@@ -182,52 +168,3 @@ def test_convert_state_dict():
182168 te_state_dict_keys .remove ("decoder.bias" )
183169
184170 assert len (te_state_dict_keys ) == 0
185-
186-
187- def test_hf_trained_model_loss (input_data ):
188- model = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
189- model .to ("cuda" , dtype = torch .bfloat16 )
190- input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
191- model .eval ()
192- with torch .amp .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
193- output = model (** input_data )
194-
195- torch .testing .assert_close (output .loss .detach ().cpu (), torch .tensor (2.4 ), atol = 1e-1 , rtol = 1e-2 )
196-
197-
198- def test_te_trained_model_loss (input_data ):
199- model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
200- model = convert_amplify_hf_to_te (model_hf )
201- model .to ("cuda" , dtype = torch .bfloat16 )
202- input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
203- model .eval ()
204- with torch .amp .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
205- output = model (** input_data )
206-
207- torch .testing .assert_close (output .loss .detach ().cpu (), torch .tensor (2.4 ), atol = 1e-1 , rtol = 1e-2 )
208-
209-
210- def test_hf_reinitialized_model_loss (input_data ):
211- config = amp_hf .AMPLIFYConfig .from_pretrained ("chandar-lab/AMPLIFY_120M" )
212- model = amp_hf .AMPLIFY (config )
213- model .to ("cuda" , dtype = torch .bfloat16 )
214- input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
215- model .eval ()
216- with torch .amp .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
217- output = model (** input_data )
218-
219- loss = output .loss .detach ().cpu ()
220- assert loss < 3.5 , f"Loss is { loss } , expected less than 3.5"
221-
222-
223- def test_te_reinitialized_model_loss (input_data ):
224- config = amp_te .AMPLIFYConfig .from_pretrained ("chandar-lab/AMPLIFY_120M" )
225- model = amp_te .AMPLIFYForMaskedLM (config )
226- model .to ("cuda" , dtype = torch .bfloat16 )
227- input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
228- model .eval ()
229- with torch .amp .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
230- output = model (** input_data )
231-
232- loss = output .loss .detach ().cpu ()
233- assert loss < 3.5 , f"Loss is { loss } , expected less than 3.5"
0 commit comments