Skip to content

Commit c720805

Browse files
committed
Mark sparsity test tpu_only
1 parent dff1d0c commit c720805

1 file changed

Lines changed: 10 additions & 8 deletions

File tree

tests/sparsity_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import tempfile
1919
from absl.testing import absltest
2020
from absl.testing import parameterized
21+
import pytest
2122
from maxtext.trainers.pre_train import train
2223
from tests.utils.test_helpers import get_test_config_path
2324

@@ -45,9 +46,8 @@ class Train(parameterized.TestCase):
4546
"use_sparsity": True,
4647
},
4748
)
48-
def test_different_quant_sparsity_configs(
49-
self, quantization: str, use_sparsity: bool
50-
):
49+
@pytest.mark.tpu_only
50+
def test_different_quant_sparsity_configs(self, quantization: str, use_sparsity: bool):
5151
test_tmpdir = os.environ.get("TEST_TMPDIR", gettempdir())
5252
outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", test_tmpdir)
5353
args = [
@@ -81,11 +81,13 @@ def test_different_quant_sparsity_configs(
8181
f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}",
8282
]
8383
if use_sparsity:
84-
args.extend([
85-
"weight_sparsity_n=2",
86-
"weight_sparsity_m=4",
87-
"weight_sparsity_update_step=1",
88-
])
84+
args.extend(
85+
[
86+
"weight_sparsity_n=2",
87+
"weight_sparsity_m=4",
88+
"weight_sparsity_update_step=1",
89+
]
90+
)
8991
train_main(args)
9092

9193

0 commit comments

Comments
 (0)