Skip to content

Commit 7c03487

Browse files
authored
Merge pull request #57 from InfiniTensor/standardize-test-parametrization-and-configuration
Standardize test parametrization and configuration
2 parents 679985e + b0a1a77 commit 7c03487

40 files changed

Lines changed: 170 additions & 167 deletions

src/ntops/torch/utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,55 @@
66
import ntops
77

88

9+
class _CachedMakeDefaultConfig:
10+
def __init__(self, num_warps=None, num_stages=None, max_num_configs=None):
11+
self.num_warps = num_warps
12+
13+
self.num_stages = num_stages
14+
15+
self.max_num_configs = max_num_configs
16+
17+
18+
_cached_make_default_config = _CachedMakeDefaultConfig()
19+
20+
21+
def get_default_num_warps():
    """Return the process-wide default ``num_warps`` (``None`` if unset)."""
    defaults = _cached_make_default_config
    return defaults.num_warps
23+
24+
25+
def set_default_num_warps(num_warps):
    """Set the process-wide default ``num_warps`` (pass ``None`` to unset)."""
    defaults = _cached_make_default_config
    defaults.num_warps = num_warps
27+
28+
29+
def get_default_num_stages():
    """Return the process-wide default ``num_stages`` (``None`` if unset)."""
    defaults = _cached_make_default_config
    return defaults.num_stages
31+
32+
33+
def set_default_num_stages(num_stages):
    """Set the process-wide default ``num_stages`` (pass ``None`` to unset)."""
    defaults = _cached_make_default_config
    defaults.num_stages = num_stages
35+
36+
37+
def get_default_max_num_configs():
    """Return the process-wide default ``max_num_configs`` (``None`` if unset)."""
    defaults = _cached_make_default_config
    return defaults.max_num_configs
39+
40+
41+
def set_default_max_num_configs(max_num_configs):
    """Set the process-wide default ``max_num_configs`` (pass ``None`` to unset)."""
    defaults = _cached_make_default_config
    defaults.max_num_configs = max_num_configs
43+
44+
945
@functools.cache
1046
def _cached_make(
1147
premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords
1248
):
49+
if num_warps is None:
50+
num_warps = _cached_make_default_config.num_warps
51+
52+
if num_stages is None:
53+
num_stages = _cached_make_default_config.num_stages
54+
55+
if max_num_configs is None:
56+
max_num_configs = _cached_make_default_config.max_num_configs
57+
1358
return ninetoothed.make(
1459
*premake(*args, **keywords),
1560
num_warps=num_warps,

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
import pytest
55
import torch
66

7+
import ntops.torch.utils
8+
79

810
def pytest_configure():
    """Pytest hook: fix global numeric settings before any test runs.

    Disables reduced-precision fp16/bf16 matmul reductions — presumably so
    tolerance-based comparisons are reproducible — and installs the test-wide
    cap on autotuning configurations.
    """
    matmul = torch.backends.cuda.matmul
    matmul.allow_fp16_reduced_precision_reduction = False
    matmul.allow_bf16_reduced_precision_reduction = False

    ntops.torch.utils.set_default_max_num_configs(_DEFAULT_MAX_NUM_CONFIGS)
15+
1216

1317
def pytest_collectstart(collector):
1418
if isinstance(collector, pytest.Module):
@@ -25,6 +29,9 @@ def set_seed_per_test(request):
2529
_set_random_seed(_hash(_test_case_path_from_request(request)))
2630

2731

32+
# Cap on autotuning configurations applied in pytest_configure; presumably
# keeps test runtime bounded — TODO confirm intended value with maintainers.
_DEFAULT_MAX_NUM_CONFIGS = 3
33+
34+
2835
def _set_random_seed(seed):
2936
random.seed(seed)
3037
torch.manual_seed(seed)

tests/test_abs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88

99
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_abs(shape, dtype, device, rtol, atol):
    """Compare ntops.torch.abs against torch.abs on a random tensor."""
    # Renamed from `input` to avoid shadowing the builtin.
    input_tensor = torch.randn(shape, dtype=dtype, device=device)

    ninetoothed_output = ntops.torch.abs(input_tensor)
    reference_output = torch.abs(input_tensor)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

tests/test_add.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88

99
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_add(shape, dtype, device, rtol, atol):
    """Compare ntops.torch.add against torch.add with a random alpha."""
    # Renamed from `input` to avoid shadowing the builtin.
    input_tensor = torch.randn(shape, dtype=dtype, device=device)
    other = torch.randn(shape, dtype=dtype, device=device)
    alpha = gauss()

    ninetoothed_output = ntops.torch.add(input_tensor, other, alpha=alpha)
    reference_output = torch.add(input_tensor, other, alpha=alpha)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

tests/test_addmm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
@skip_if_cuda_not_available
1111
@pytest.mark.parametrize(*generate_arguments())
12-
def test_cuda(m, n, k, dtype, atol, rtol):
13-
device = "cuda"
14-
12+
def test_addmm(m, n, k, dtype, device, rtol, atol):
1513
input = torch.randn((m, n), dtype=dtype, device=device)
1614
x = torch.randn((m, k), dtype=dtype, device=device)
1715
y = torch.randn((k, n), dtype=dtype, device=device)
@@ -21,4 +19,4 @@ def test_cuda(m, n, k, dtype, atol, rtol):
2119
ninetoothed_output = ntops.torch.addmm(input, x, y, beta=beta, alpha=alpha)
2220
reference_output = torch.addmm(input, x, y, beta=beta, alpha=alpha)
2321

24-
assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
22+
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

tests/test_bitwise_and.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88

99
@skip_if_cuda_not_available
1010
@pytest.mark.parametrize(*generate_arguments(False))
11-
def test_cuda(shape, dtype, atol, rtol):
12-
device = "cuda"
13-
11+
def test_bitwise_and(shape, dtype, device, rtol, atol):
1412
if dtype == torch.bool:
1513
prob = 0.5
1614
input = torch.rand(shape, dtype=torch.float32, device=device) > prob

tests/test_bitwise_not.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88

99
@skip_if_cuda_not_available
1010
@pytest.mark.parametrize(*generate_arguments(False))
11-
def test_cuda(shape, dtype, atol, rtol):
12-
device = "cuda"
13-
11+
def test_bitwise_not(shape, dtype, device, rtol, atol):
1412
if dtype == torch.bool:
1513
prob = 0.5
1614
input = torch.rand(shape, dtype=torch.float32, device=device) > prob

tests/test_bitwise_or.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88

99
@skip_if_cuda_not_available
1010
@pytest.mark.parametrize(*generate_arguments(False))
11-
def test_cuda(shape, dtype, atol, rtol):
12-
device = "cuda"
13-
11+
def test_bitwise_or(shape, dtype, device, rtol, atol):
1412
if dtype == torch.bool:
1513
prob = 0.5
1614
input = torch.rand(shape, dtype=torch.float32, device=device) > prob

tests/test_bmm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,12 @@
1010

1111
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_bmm(m, n, k, dtype, device, rtol, atol):
    """Compare ntops.torch.bmm against torch.bmm on a random batch size."""
    b = random.randint(4, 16)
    # Renamed from `input` to avoid shadowing the builtin.
    input_tensor = torch.randn((b, m, k), dtype=dtype, device=device)
    other = torch.randn((b, k, n), dtype=dtype, device=device)

    ninetoothed_output = ntops.torch.bmm(input_tensor, other)
    reference_output = torch.bmm(input_tensor, other)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

tests/test_clamp.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88

99
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_clamp(shape, dtype, device, rtol, atol):
    """Compare ntops.torch.clamp against torch.clamp with tensor bounds."""
    # Renamed from `input`/`min`/`max` to avoid shadowing builtins.
    input_tensor = torch.randn(shape, dtype=dtype, device=device)
    lower = torch.randn(shape, dtype=dtype, device=device)
    upper = torch.randn(shape, dtype=dtype, device=device)

    ninetoothed_output = ntops.torch.clamp(input_tensor, lower, upper)
    reference_output = torch.clamp(input_tensor, lower, upper)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)