
Commit e2309e3

Merge pull request #8 from InfiniTensor/unify-tests
Unify correctness and benchmark testing with `pytest`
2 parents 3b5a6f9 + 2dea118 commit e2309e3

31 files changed

Lines changed: 1123 additions & 1271 deletions

add.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

addmm.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

bench.py

Lines changed: 116 additions & 0 deletions
New file contents:
import torch
import triton

DISPLAY_NAMES = {
    "ninetoothed": "NineToothed",
    "torch": "PyTorch",
    "triton": "Triton",
}

STYLES = [
    ("blue", "-"),
    ("green", "-"),
    ("orange", "-"),
    ("red", "-"),
    ("purple", "-"),
    ("cyan", "-"),
]


def assert_match(impls, args, kwargs=None, tolerances=None):
    """Assert that all implementations produce matching outputs.

    Same API as ``check``, but raises ``AssertionError`` on mismatch
    instead of printing. Intended for use in test suites.

    :param impls: Ordered dict mapping provider name to callable.
    :param args: Tuple of positional arguments.
    :param kwargs: Dict of keyword arguments.
    :param tolerances: Dict mapping provider name to ``torch.allclose`` kwargs.
    """
    kwargs = kwargs or {}
    tolerances = tolerances or {}
    results = {name: fn(*args, **kwargs) for name, fn in impls.items()}

    names = list(impls)
    reference_name = names[0]
    reference = results[reference_name]

    for name in names[1:]:
        tol = tolerances.get(name, {})
        ref_display = _display_name(reference_name)
        other_display = _display_name(name)

        assert torch.allclose(reference, results[name], **tol), (
            f"{ref_display} and {other_display} outputs differ."
        )


def benchmark(
    impls,
    make_inputs,
    x_names,
    x_vals,
    name,
    benchmark_args=None,
    x_log=True,
    assert_correctness=True,
    tolerances=None,
    save_path=".",
):
    """Create and run a performance benchmark.

    :param impls: Ordered dict mapping provider name to callable.
    :param make_inputs: Callable returning ``(args_tuple, kwargs_dict)``.
    :param x_names: List of benchmark parameter names.
    :param x_vals: List of benchmark parameter values.
    :param name: Operator name, used for the plot filename.
    :param benchmark_args: Fixed benchmark args dict.
    :param x_log: Whether to use log scale for the x-axis.
    :param assert_correctness: Whether to assert correctness at each point.
    :param tolerances: Dict mapping provider name to ``torch.allclose`` kwargs.
    :param save_path: Directory to save plot files, or ``None`` to skip saving.
    """
    providers = list(impls)
    tolerances = tolerances or {}

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=x_names,
            x_vals=x_vals,
            line_arg="provider",
            line_vals=providers,
            line_names=[_display_name(p) for p in providers],
            plot_name=f"{name}-performance",
            args=benchmark_args or {},
            ylabel="ms",
            x_log=x_log,
            styles=[_style(i) for i in range(len(providers))],
        )
    )
    def bench(provider, **params):
        args, kwargs = make_inputs(**params)

        if assert_correctness:
            results = {p: impls[p](*args, **kwargs) for p in providers}
            reference = results[providers[0]]

            for p in providers[1:]:
                tol = tolerances.get(p, {})
                assert torch.allclose(reference, results[p], **tol)

        return triton.testing.do_bench(lambda: impls[provider](*args, **kwargs))

    bench.run(print_data=True, save_path=save_path)


def _display_name(name):
    """Return the display name for a provider."""

    return DISPLAY_NAMES.get(name, name)


def _style(index):
    """Return a plot style, cycling through available options."""

    return STYLES[index % len(STYLES)]
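
For context, a minimal usage sketch of these helpers. The element-wise add providers, sizes, and device below are illustrative assumptions, not taken from this commit, and triton.testing.do_bench generally needs a CUDA-capable environment:

# usage_sketch.py -- hypothetical, for illustration only.
import torch

from bench import assert_match, benchmark

# Two interchangeable providers for element-wise addition; in the real
# suite these would be the PyTorch, Triton, and NineToothed kernels.
# The first entry ("torch") acts as the reference implementation.
impls = {
    "torch": lambda x, y: x + y,
    "torch_addfn": torch.add,
}


def make_inputs(size):
    # Returns (args_tuple, kwargs_dict), as `benchmark` expects.
    args = (torch.randn(size, device="cuda"), torch.randn(size, device="cuda"))
    return args, {}


# Check correctness once, then sweep the input size; this prints a table
# and saves a plot named "add-performance" in the current directory.
assert_match(impls, make_inputs(2**12)[0])
benchmark(
    impls,
    make_inputs,
    x_names=["size"],
    x_vals=[2**i for i in range(12, 24)],
    name="add",
)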

bmm.py

Lines changed: 0 additions & 70 deletions
This file was deleted.
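
The deleted add.py, addmm.py, and bmm.py were standalone check/bench scripts; per the commit title, their coverage presumably moves into pytest tests built on bench.py. A sketch of what such a test might look like (the file name tests/test_add.py and the provider set are assumptions, not shown in this diff):

# tests/test_add.py -- hypothetical layout, not part of this diff.
import pytest
import torch

from bench import assert_match


@pytest.mark.parametrize("size", [2**10, 2**15, 2**20])
def test_add_matches_reference(size):
    x = torch.randn(size)
    y = torch.randn(size)

    # The first entry is the reference; real tests would register the
    # Triton and NineToothed kernels as additional providers here.
    impls = {
        "torch": lambda a, b: a + b,
        "torch_addfn": torch.add,
    }
    assert_match(impls, (x, y))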
