@@ -68,7 +68,7 @@ def test_te_model_has_all_te_layers(config):
6868
6969
7070def test_models_have_identical_outputs (input_data ):
71- model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
71+ model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" , revision = "d918a9e8" )
7272 model_te = convert_amplify_hf_to_te (model_hf )
7373 input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
7474
@@ -85,7 +85,7 @@ def test_models_have_identical_outputs(input_data):
8585
8686
8787def test_converted_model_roundtrip (input_data , tmp_path ):
88- model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
88+ model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" , revision = "d918a9e8" )
8989 model_te = convert_amplify_hf_to_te (model_hf )
9090
9191 model_te .save_pretrained (tmp_path / "AMPLIFY_120M" )
@@ -108,7 +108,7 @@ def test_converted_model_roundtrip(input_data, tmp_path):
108108
109109
110110def test_convert_state_dict ():
111- model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
111+ model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" , revision = "d918a9e8" )
112112 model_te = convert_amplify_hf_to_te (model_hf )
113113
114114 from amplify .state_dict_convert import _pack_qkv_weight , _pad_bias , _pad_weights , mapping
@@ -171,7 +171,7 @@ def test_convert_state_dict():
171171
172172
173173def test_hf_trained_model_loss (input_data ):
174- model = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
174+ model = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" , revision = "d918a9e8" )
175175 model .to ("cuda" , dtype = torch .bfloat16 )
176176 input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
177177 model .eval ()
@@ -182,7 +182,7 @@ def test_hf_trained_model_loss(input_data):
182182
183183
184184def test_te_trained_model_loss (input_data ):
185- model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" )
185+ model_hf = amp_hf .AMPLIFY .from_pretrained ("chandar-lab/AMPLIFY_120M" , revision = "d918a9e8" )
186186 model = convert_amplify_hf_to_te (model_hf )
187187 model .to ("cuda" , dtype = torch .bfloat16 )
188188 input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
@@ -194,7 +194,7 @@ def test_te_trained_model_loss(input_data):
194194
195195
196196def test_hf_reinitialized_model_loss (input_data ):
197- config = amp_hf .AMPLIFYConfig .from_pretrained ("chandar-lab/AMPLIFY_120M" )
197+ config = amp_hf .AMPLIFYConfig .from_pretrained ("chandar-lab/AMPLIFY_120M" , revision = "d918a9e8" )
198198 model = amp_hf .AMPLIFY (config )
199199 model .to ("cuda" , dtype = torch .bfloat16 )
200200 input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
@@ -207,7 +207,7 @@ def test_hf_reinitialized_model_loss(input_data):
207207
208208
209209def test_te_reinitialized_model_loss (input_data ):
210- config = amp_te .AMPLIFYConfig .from_pretrained ("chandar-lab/AMPLIFY_120M" )
210+ config = amp_te .AMPLIFYConfig .from_pretrained ("chandar-lab/AMPLIFY_120M" , revision = "d918a9e8" )
211211 model = amp_te .AMPLIFYForMaskedLM (config )
212212 model .to ("cuda" , dtype = torch .bfloat16 )
213213 input_data = {k : v .to ("cuda" ) for k , v in input_data .items ()}
0 commit comments