File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1818import tempfile
1919from absl .testing import absltest
2020from absl .testing import parameterized
21+ import pytest
2122from maxtext .trainers .pre_train import train
2223from tests .utils .test_helpers import get_test_config_path
2324
@@ -45,9 +46,8 @@ class Train(parameterized.TestCase):
4546 "use_sparsity" : True ,
4647 },
4748 )
48- def test_different_quant_sparsity_configs (
49- self , quantization : str , use_sparsity : bool
50- ):
49+ @pytest .mark .tpu_only
50+ def test_different_quant_sparsity_configs (self , quantization : str , use_sparsity : bool ):
5151 test_tmpdir = os .environ .get ("TEST_TMPDIR" , gettempdir ())
5252 outputs_dir = os .environ .get ("TEST_UNDECLARED_OUTPUTS_DIR" , test_tmpdir )
5353 args = [
@@ -81,11 +81,13 @@ def test_different_quant_sparsity_configs(
8181 f"metrics_file={ os .path .join (outputs_dir , 'metrics.json' )} " ,
8282 ]
8383 if use_sparsity :
84- args .extend ([
85- "weight_sparsity_n=2" ,
86- "weight_sparsity_m=4" ,
87- "weight_sparsity_update_step=1" ,
88- ])
84+ args .extend (
85+ [
86+ "weight_sparsity_n=2" ,
87+ "weight_sparsity_m=4" ,
88+ "weight_sparsity_update_step=1" ,
89+ ]
90+ )
8991 train_main (args )
9092
9193
You can’t perform that action at this time.
0 commit comments