1010
1111
1212class TestOrthogonality :
13- """H(H(x)) ≈ x — Hadamard is its own inverse (involutory )."""
13+ """H(H(x)) ≈ x for plain Hadamard (no signs )."""
1414
1515 @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
1616 @pytest .mark .parametrize ("dtype" , DTYPES )
@@ -34,12 +34,41 @@ def test_double_apply_large(self, block_size, dtype):
3434 torch .testing .assert_close (x , x_orig , atol = atol , rtol = atol )
3535
3636
class TestSignedOrthogonality:
    """Randomized Hadamard: R = H*D is orthogonal, so R can be undone exactly."""

    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_signed_inverse(self, block_size, dtype):
        """inv(H*D) = D*H: applying H then the sign flips recovers the input."""
        signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
        x = torch.randn(1024, dtype=dtype, device="cuda")
        x_orig = x.clone()

        # Forward transform: H*D*x (signs applied before the Hadamard).
        hadamard_rotate(x, block_size=block_size, signs=signs)

        # Inverse, step 1: plain Hadamard (no signs) undoes H.
        hadamard_rotate(x, block_size=block_size)

        # Inverse, step 2: apply D. Expand the packed 32-bit sign words into a
        # boolean mask over one block, then flip the matching columns of every
        # block at once (equivalent to the per-bit strided flips, vectorized).
        bit_index = torch.arange(block_size, device="cuda")
        words = signs[bit_index // 32]
        flip = ((words >> (bit_index % 32)) & 1).bool()
        x.view(-1, block_size)[:, flip] *= -1

        atol = 1e-2 if dtype == torch.bfloat16 else 1e-3
        torch.testing.assert_close(x, x_orig, atol=atol, rtol=atol)
64+
65+
3766class TestGEMMEquivalence :
3867 """H(A) @ H(B)^T ≈ A @ B^T (within quantization tolerance)."""
3968
4069 @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
4170 @pytest .mark .parametrize ("dtype" , DTYPES )
42- def test_gemm (self , block_size , dtype ):
71+ def test_gemm_plain (self , block_size , dtype ):
4372 M , K , N = 4 , 256 , 8
4473 A = torch .randn (M , K , dtype = dtype , device = "cuda" )
4574 B = torch .randn (N , K , dtype = dtype , device = "cuda" )
@@ -54,6 +83,25 @@ def test_gemm(self, block_size, dtype):
5483 atol = 0.1 if dtype == torch .bfloat16 else 0.05
5584 torch .testing .assert_close (result , ref , atol = atol , rtol = 0.05 )
5685
86+ @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
87+ @pytest .mark .parametrize ("dtype" , DTYPES )
88+ def test_gemm_signed (self , block_size , dtype ):
89+ """GEMM equivalence with random sign flips."""
90+ M , K , N = 4 , 256 , 8
91+ signs = torch .randint (0 , 2 ** 31 , (block_size // 32 ,), dtype = torch .int32 , device = "cuda" )
92+ A = torch .randn (M , K , dtype = dtype , device = "cuda" )
93+ B = torch .randn (N , K , dtype = dtype , device = "cuda" )
94+ ref = A .float () @ B .float ().T
95+
96+ A_rot = A .clone ()
97+ B_rot = B .clone ()
98+ hadamard_rotate (A_rot , block_size = block_size , signs = signs )
99+ hadamard_rotate (B_rot , block_size = block_size , signs = signs )
100+ result = A_rot .float () @ B_rot .float ().T
101+
102+ atol = 0.1 if dtype == torch .bfloat16 else 0.05
103+ torch .testing .assert_close (result , ref , atol = atol , rtol = 0.05 )
104+
57105 def test_gemm_qwen3_shapes (self ):
58106 """GEMM equivalence on Qwen3-Coder-Next 70B shapes."""
59107 shapes = [
@@ -146,6 +194,16 @@ def test_deterministic(self, block_size, dtype):
146194 hadamard_rotate (b , block_size = block_size )
147195 torch .testing .assert_close (a , b , atol = 0 , rtol = 0 )
148196
197+ @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
198+ def test_deterministic_signed (self , block_size ):
199+ signs = torch .randint (0 , 2 ** 31 , (block_size // 32 ,), dtype = torch .int32 , device = "cuda" )
200+ x = torch .randn (1024 , dtype = torch .float16 , device = "cuda" )
201+ a = x .clone ()
202+ b = x .clone ()
203+ hadamard_rotate (a , block_size = block_size , signs = signs )
204+ hadamard_rotate (b , block_size = block_size , signs = signs )
205+ torch .testing .assert_close (a , b , atol = 0 , rtol = 0 )
206+
149207
150208class TestNormPreservation :
151209 """Hadamard rotation preserves L2 norm (orthogonal transform)."""
0 commit comments