Skip to content

Commit cf96897

Browse files
TimDettmers and claude
authored and committed
test: Add comprehensive Hadamard rotation tests
79 tests covering:
- Orthogonality: H(H(x)) ≈ x for all block sizes and dtypes
- Signed orthogonality: inv(H*D) = D*H inverse recovery
- GEMM equivalence: H(A)@H(B)^T ≈ A@B^T (plain and signed)
- GEMM on Qwen3-Coder-Next 70B shapes
- Edge cases: partial blocks, various sizes, single block, 2D tensors
- Input validation: invalid block sizes and dtypes
- Determinism: identical outputs for identical inputs
- Norm preservation: L2 norm invariance

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9931b25 commit cf96897

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed

tests/test_hadamard.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
"""Tests for the Hadamard rotation kernel (hadamard_rotate)."""

import pytest
import torch

from bitsandbytes.functional import hadamard_rotate

# Every test in this module allocates tensors on "cuda" and drives the kernel
# directly. Skip the whole module cleanly on CPU-only machines instead of
# failing each test with a device error.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")

# Block sizes exercised by the suite (the kernel rejects e.g. 16 and 48;
# see TestEdgeCases.test_invalid_block_size).
BLOCK_SIZES = [32, 64, 128, 256]
# Half-precision dtypes under test; fp32 is expected to be rejected
# (see TestEdgeCases.test_invalid_dtype).
DTYPES = [torch.float16, torch.bfloat16]
class TestOrthogonality:
    """H(H(x)) ≈ x for plain Hadamard (no signs)."""

    def _assert_roundtrip(self, n, block_size, dtype):
        """Apply the rotation twice in place and verify the input is recovered.

        The (normalized) Hadamard transform is its own inverse, so two
        applications must reproduce the original within dtype precision.

        Args:
            n: number of elements in the test tensor.
            block_size: Hadamard block size passed to the kernel.
            dtype: tensor dtype (fp16 or bf16).
        """
        x = torch.randn(n, dtype=dtype, device="cuda")
        x_orig = x.clone()
        hadamard_rotate(x, block_size=block_size)
        hadamard_rotate(x, block_size=block_size)
        # bf16 has fewer mantissa bits than fp16, so allow a looser tolerance.
        atol = 1e-2 if dtype == torch.bfloat16 else 1e-3
        torch.testing.assert_close(x, x_orig, atol=atol, rtol=atol)

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_double_apply_identity(self, block_size, dtype):
        self._assert_roundtrip(1024, block_size, dtype)

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_double_apply_large(self, block_size, dtype):
        """Test on a larger tensor (32K elements)."""
        self._assert_roundtrip(32768, block_size, dtype)
class TestSignedOrthogonality:
    """Randomized Hadamard: R=H*D is orthogonal (R^T*R=I)."""

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_signed_inverse(self, block_size, dtype):
        """Verify inv(H*D) = D*H: forward then inverse recovers original."""
        signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
        x = torch.randn(1024, dtype=dtype, device="cuda")
        x_orig = x.clone()

        # Forward: H*D*x
        hadamard_rotate(x, block_size=block_size, signs=signs)

        # Inverse: D*H*x' = first apply H (no signs), then sign flip
        hadamard_rotate(x, block_size=block_size)  # H
        # Apply D (sign flip): bit i of word j covers lane j*32+i. Unpack the
        # packed bits into one per-lane ±1 vector and apply it with a single
        # in-place multiply, instead of a Python bit loop with a GPU-syncing
        # .item() per word and one strided multiply per set bit.
        lane = torch.arange(block_size, device=signs.device)
        bits = (signs[lane // 32] >> (lane % 32)) & 1
        flip = torch.where(bits.bool(), -1.0, 1.0).to(x.dtype)
        x.view(-1, block_size).mul_(flip)

        atol = 1e-2 if dtype == torch.bfloat16 else 1e-3
        torch.testing.assert_close(x, x_orig, atol=atol, rtol=atol)
class TestGEMMEquivalence:
    """H(A) @ H(B)^T ≈ A @ B^T (within quantization tolerance)."""

    @staticmethod
    def _rotated_matmul(A, B, block_size, signs=None):
        """Rotate clones of A and B in place, return fp32 A_rot @ B_rot^T."""
        A_rot = A.clone()
        B_rot = B.clone()
        if signs is None:
            hadamard_rotate(A_rot, block_size=block_size)
            hadamard_rotate(B_rot, block_size=block_size)
        else:
            hadamard_rotate(A_rot, block_size=block_size, signs=signs)
            hadamard_rotate(B_rot, block_size=block_size, signs=signs)
        return A_rot.float() @ B_rot.float().T

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_gemm_plain(self, block_size, dtype):
        M, K, N = 4, 256, 8
        A = torch.randn(M, K, dtype=dtype, device="cuda")
        B = torch.randn(N, K, dtype=dtype, device="cuda")
        ref = A.float() @ B.float().T

        result = self._rotated_matmul(A, B, block_size)

        atol = 0.1 if dtype == torch.bfloat16 else 0.05
        torch.testing.assert_close(result, ref, atol=atol, rtol=0.05)

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_gemm_signed(self, block_size, dtype):
        """GEMM equivalence with random sign flips."""
        M, K, N = 4, 256, 8
        signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
        A = torch.randn(M, K, dtype=dtype, device="cuda")
        B = torch.randn(N, K, dtype=dtype, device="cuda")
        ref = A.float() @ B.float().T

        result = self._rotated_matmul(A, B, block_size, signs=signs)

        atol = 0.1 if dtype == torch.bfloat16 else 0.05
        torch.testing.assert_close(result, ref, atol=atol, rtol=0.05)

    def test_gemm_qwen3_shapes(self):
        """GEMM equivalence on Qwen3-Coder-Next 70B shapes."""
        shapes = [
            (1, 2048, 5120),  # gate/up at M=1
            (4, 5120, 2048),  # down at M=4
            (1, 2048, 4096),  # Q proj
            (4, 4096, 2048),  # O proj
        ]
        for M, K, N in shapes:
            A = torch.randn(M, K, dtype=torch.float16, device="cuda")
            B = torch.randn(N, K, dtype=torch.float16, device="cuda")
            ref = A.float() @ B.float().T

            result = self._rotated_matmul(A, B, 64)

            # msg= identifies the failing shape; otherwise a failure in this
            # loop is indistinguishable across the four cases.
            torch.testing.assert_close(
                result, ref, atol=0.05, rtol=0.05, msg=f"shape M={M} K={K} N={N}"
            )
class TestEdgeCases:
    """Edge cases: sizes not divisible by block_size, various M values."""

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    def test_size_not_divisible(self, block_size):
        """When n is not divisible by block_size, the last partial block
        should still be processed (padded with zeros internally)."""
        total = block_size * 3 + 7  # three full blocks plus a partial tail
        data = torch.randn(total, dtype=torch.float16, device="cuda")
        baseline = data.clone()
        hadamard_rotate(data, block_size=block_size)
        # After one application the tensor must have actually changed.
        assert not torch.allclose(data, baseline, atol=1e-4)
        # A second application inverts the first.
        hadamard_rotate(data, block_size=block_size)
        # Only the fully-covered prefix is held to tight tolerance; the
        # partial tail block may carry extra error.
        covered = (total // block_size) * block_size
        torch.testing.assert_close(data[:covered], baseline[:covered], atol=1e-3, rtol=1e-3)

    @pytest.mark.parametrize("n", [32, 64, 128, 256, 512, 1024, 4096])
    def test_various_sizes(self, n):
        sample = torch.randn(n, dtype=torch.float16, device="cuda")
        reference = sample.clone()
        for _ in range(2):  # double application is the identity
            hadamard_rotate(sample, block_size=32)
        torch.testing.assert_close(sample, reference, atol=1e-3, rtol=1e-3)

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    def test_single_block(self, block_size):
        """Exactly one block."""
        vec = torch.randn(block_size, dtype=torch.float16, device="cuda")
        saved = vec.clone()
        for _ in range(2):
            hadamard_rotate(vec, block_size=block_size)
        torch.testing.assert_close(vec, saved, atol=1e-3, rtol=1e-3)

    def test_invalid_block_size(self):
        data = torch.randn(128, dtype=torch.float16, device="cuda")
        # Neither a too-small nor a non-power-of-two block size is accepted.
        for bad_size in (16, 48):
            with pytest.raises(RuntimeError):
                hadamard_rotate(data, block_size=bad_size)

    def test_invalid_dtype(self):
        data = torch.randn(128, dtype=torch.float32, device="cuda")
        with pytest.raises(RuntimeError):
            hadamard_rotate(data, block_size=32)

    def test_2d_tensor(self):
        """Rotation should work on 2D tensors (flattened internally)."""
        mat = torch.randn(8, 64, dtype=torch.float16, device="cuda")
        saved = mat.clone()
        for _ in range(2):
            hadamard_rotate(mat, block_size=64)
        torch.testing.assert_close(mat, saved, atol=1e-3, rtol=1e-3)
class TestDeterminism:
    """Same input → same output."""

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_deterministic(self, block_size, dtype):
        source = torch.randn(1024, dtype=dtype, device="cuda")
        first, second = source.clone(), source.clone()
        hadamard_rotate(first, block_size=block_size)
        hadamard_rotate(second, block_size=block_size)
        # Zero tolerances: the two runs must agree bit-for-bit.
        torch.testing.assert_close(first, second, atol=0, rtol=0)

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    def test_deterministic_signed(self, block_size):
        signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
        source = torch.randn(1024, dtype=torch.float16, device="cuda")
        first, second = source.clone(), source.clone()
        hadamard_rotate(first, block_size=block_size, signs=signs)
        hadamard_rotate(second, block_size=block_size, signs=signs)
        # Zero tolerances: identical signs must yield identical outputs.
        torch.testing.assert_close(first, second, atol=0, rtol=0)
class TestNormPreservation:
    """Hadamard rotation preserves L2 norm (orthogonal transform)."""

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_norm_preservation(self, block_size, dtype):
        vec = torch.randn(block_size * 4, dtype=dtype, device="cuda")
        before = vec.float().norm().item()
        hadamard_rotate(vec, block_size=block_size)
        after = vec.float().norm().item()
        # An orthogonal transform changes the norm only by rounding error;
        # 1% relative slack covers fp16/bf16 precision.
        relative_drift = abs(after - before) / before
        assert relative_drift < 0.01

0 commit comments

Comments
 (0)