Skip to content

Commit 737bad0

Browse files
committed
Implement Hadamard transform
1 parent 7fd0e83 commit 737bad0

3 files changed

Lines changed: 585 additions & 0 deletions

File tree

benchmarks/benchmark_hadamard.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import argparse
2+
import math
3+
import time
4+
5+
import torch
6+
from triton.testing import do_bench
7+
8+
from quack.hadamard import hadamard_transform, hadamard_transform_ref
9+
10+
try:
11+
from fast_hadamard_transform import hadamard_transform as fast_hadamard_transform
12+
except ImportError:
13+
fast_hadamard_transform = None
14+
15+
16+
# Maps the string accepted by the --dtype command-line option to the torch
# dtype of the same name; the keys double as the argparse choices.
DTYPES = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ("float16", "bfloat16", "float32")
}
21+
22+
# Per-dtype (atol, rtol) pairs; run_hadamard unpacks them in that order and
# hands them to torch.testing.assert_close for the correctness comparison.
TOLERANCES = {
    torch.bfloat16: (0.1, 0.02),
    torch.float16: (0.03, 0.02),
    torch.float32: (0.0001, 0.0001),
}
27+
28+
29+
def _effective_bandwidth_gbps(x: torch.Tensor, latency_ms: float) -> float:
30+
bytes_moved = 2 * x.numel() * x.element_size()
31+
return bytes_moved / (latency_ms / 1000.0) / 1e9
32+
33+
34+
def _bench(name, fn, x, warmup, rep):
    """Time ``fn`` with ``do_bench`` and print one formatted result line.

    Args:
        name: Label printed right-aligned in a 24-character column.
        fn: Zero-argument callable to time.
        x: Tensor used only to derive the effective-bandwidth figure.
        warmup: Forwarded unchanged to ``do_bench`` as ``warmup``.
        rep: Forwarded unchanged to ``do_bench`` as ``rep``.

    Returns:
        The latency in milliseconds reported by ``do_bench``.
    """
    # Three untimed calls before measuring, then wait for the device and
    # pause briefly so the timed run starts from an idle state.
    untimed_calls_left = 3
    while untimed_calls_left > 0:
        fn()
        untimed_calls_left -= 1
    torch.cuda.synchronize()
    time.sleep(0.2)

    latency_ms = do_bench(fn, warmup=warmup, rep=rep)
    bandwidth_gbps = _effective_bandwidth_gbps(x, latency_ms)
    report = f"{name:>24}: {latency_ms:.4f} ms, {bandwidth_gbps:.1f} effective GB/s"
    print(report)
    return latency_ms
45+
46+
47+
def run_hadamard(M, N, dtype, scale, warmup, rep, include_torch):
    """Check the QuACK Hadamard transform against a reference, then benchmark it.

    A random ``[M, N]`` CUDA tensor (seed 0) is transformed by QuACK and by a
    reference: ``fast_hadamard_transform`` when that package was importable,
    otherwise ``hadamard_transform_ref``. The two outputs are compared with
    ``torch.testing.assert_close`` using ``TOLERANCES[dtype]``. Afterwards the
    QuACK kernel, the optional fast-hadamard-transform kernel, ``torch.clone``
    and (on request) the torch reference are timed via ``_bench``.

    Args:
        M: First dimension of the input tensor.
        N: Second dimension of the input tensor; must not exceed 32768.
        dtype: torch dtype of the input tensor; must be a key of TOLERANCES.
        scale: Scale factor passed to every transform implementation.
        warmup: ``warmup`` value forwarded to ``_bench``.
        rep: ``rep`` value forwarded to ``_bench``.
        include_torch: When truthy, also time the torch reference (with a
            fixed warmup of 3 and rep of 10).

    Raises:
        RuntimeError: If CUDA is not available.
        ValueError: If ``N`` is greater than 32768.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to run this benchmark")
    if N > 32768:
        raise ValueError("QuACK Hadamard currently supports N <= 32768")

    have_fast = fast_hadamard_transform is not None

    torch.manual_seed(0)
    x = torch.randn(M, N, device="cuda", dtype=dtype)
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"dtype: {dtype}, scale: {scale}")

    # Named closures so the correctness check and the timing runs share the
    # exact same call expressions.
    def run_quack():
        return hadamard_transform(x, scale=scale)

    def run_fast():
        return fast_hadamard_transform(x, scale)

    def run_torch_ref():
        return hadamard_transform_ref(x, scale=scale)

    out = run_quack()
    if have_fast:
        out_ref, ref_name = run_fast(), "fast-hadamard-transform"
    else:
        out_ref, ref_name = run_torch_ref(), "torch reference"
    atol, rtol = TOLERANCES[dtype]
    torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
    print(f"Correctness: compared QuACK against {ref_name}")

    _bench("QuACK CuTe-DSL", run_quack, x, warmup, rep)

    if have_fast:
        _bench("fast-hadamard-transform", run_fast, x, warmup, rep)
    else:
        print("fast-hadamard-transform: not installed")

    _bench("torch.clone lower bound", lambda: torch.clone(x), x, warmup, rep)

    if include_torch:
        _bench("torch FWHT reference", run_torch_ref, x, 3, 10)
86+
87+
88+
if __name__ == "__main__":
89+
parser = argparse.ArgumentParser(description="Benchmark Hadamard transform")
90+
parser.add_argument("--M", default=8192, type=int)
91+
parser.add_argument("--N", default=4096, type=int)
92+
parser.add_argument("--dtype", choices=DTYPES.keys(), default="bfloat16")
93+
parser.add_argument("--scale", default=None, type=float)
94+
parser.add_argument("--warmup_iterations", default=10, type=int)
95+
parser.add_argument("--iterations", default=100, type=int)
96+
parser.add_argument("--include-torch", action="store_true")
97+
args = parser.parse_args()
98+
99+
dtype = DTYPES[args.dtype]
100+
scale = args.scale
101+
if scale is None:
102+
scale = 1.0 / math.sqrt(1 << (args.N - 1).bit_length())
103+
104+
run_hadamard(
105+
args.M,
106+
args.N,
107+
dtype,
108+
scale,
109+
args.warmup_iterations,
110+
args.iterations,
111+
args.include_torch,
112+
)

0 commit comments

Comments
 (0)