1010
1111
1212class TestOrthogonality :
13- """H(H(x)) ≈ x for plain Hadamard (no signs )."""
13+ """H(H(x)) ≈ x — Hadamard is its own inverse (involutory )."""
1414
1515 @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
1616 @pytest .mark .parametrize ("dtype" , DTYPES )
@@ -34,41 +34,12 @@ def test_double_apply_large(self, block_size, dtype):
3434 torch .testing .assert_close (x , x_orig , atol = atol , rtol = atol )
3535
3636
37- class TestSignedOrthogonality :
38- """Randomized Hadamard: R=H*D is orthogonal (R^T*R=I)."""
39-
40- @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
41- @pytest .mark .parametrize ("dtype" , DTYPES )
42- def test_signed_inverse (self , block_size , dtype ):
43- """Verify inv(H*D) = D*H: forward then inverse recovers original."""
44- signs = torch .randint (0 , 2 ** 31 , (block_size // 32 ,), dtype = torch .int32 , device = "cuda" )
45- x = torch .randn (1024 , dtype = dtype , device = "cuda" )
46- x_orig = x .clone ()
47-
48- # Forward: H*D*x
49- hadamard_rotate (x , block_size = block_size , signs = signs )
50-
51- # Inverse: D*H*x' = first apply H (no signs), then sign flip
52- hadamard_rotate (x , block_size = block_size ) # H
53- # Apply D (sign flip)
54- x_flat = x .view (- 1 )
55- for j in range (block_size // 32 ):
56- word = signs [j ].item ()
57- for bit in range (32 ):
58- if word & (1 << bit ):
59- pos = j * 32 + bit
60- x_flat [pos ::block_size ] *= - 1
61-
62- atol = 1e-2 if dtype == torch .bfloat16 else 1e-3
63- torch .testing .assert_close (x , x_orig , atol = atol , rtol = atol )
64-
65-
6637class TestGEMMEquivalence :
6738 """H(A) @ H(B)^T ≈ A @ B^T (within quantization tolerance)."""
6839
6940 @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
7041 @pytest .mark .parametrize ("dtype" , DTYPES )
71- def test_gemm_plain (self , block_size , dtype ):
42+ def test_gemm (self , block_size , dtype ):
7243 M , K , N = 4 , 256 , 8
7344 A = torch .randn (M , K , dtype = dtype , device = "cuda" )
7445 B = torch .randn (N , K , dtype = dtype , device = "cuda" )
@@ -83,25 +54,6 @@ def test_gemm_plain(self, block_size, dtype):
8354 atol = 0.1 if dtype == torch .bfloat16 else 0.05
8455 torch .testing .assert_close (result , ref , atol = atol , rtol = 0.05 )
8556
86- @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
87- @pytest .mark .parametrize ("dtype" , DTYPES )
88- def test_gemm_signed (self , block_size , dtype ):
89- """GEMM equivalence with random sign flips."""
90- M , K , N = 4 , 256 , 8
91- signs = torch .randint (0 , 2 ** 31 , (block_size // 32 ,), dtype = torch .int32 , device = "cuda" )
92- A = torch .randn (M , K , dtype = dtype , device = "cuda" )
93- B = torch .randn (N , K , dtype = dtype , device = "cuda" )
94- ref = A .float () @ B .float ().T
95-
96- A_rot = A .clone ()
97- B_rot = B .clone ()
98- hadamard_rotate (A_rot , block_size = block_size , signs = signs )
99- hadamard_rotate (B_rot , block_size = block_size , signs = signs )
100- result = A_rot .float () @ B_rot .float ().T
101-
102- atol = 0.1 if dtype == torch .bfloat16 else 0.05
103- torch .testing .assert_close (result , ref , atol = atol , rtol = 0.05 )
104-
10557 def test_gemm_qwen3_shapes (self ):
10658 """GEMM equivalence on Qwen3-Coder-Next 70B shapes."""
10759 shapes = [
@@ -194,16 +146,6 @@ def test_deterministic(self, block_size, dtype):
194146 hadamard_rotate (b , block_size = block_size )
195147 torch .testing .assert_close (a , b , atol = 0 , rtol = 0 )
196148
197- @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
198- def test_deterministic_signed (self , block_size ):
199- signs = torch .randint (0 , 2 ** 31 , (block_size // 32 ,), dtype = torch .int32 , device = "cuda" )
200- x = torch .randn (1024 , dtype = torch .float16 , device = "cuda" )
201- a = x .clone ()
202- b = x .clone ()
203- hadamard_rotate (a , block_size = block_size , signs = signs )
204- hadamard_rotate (b , block_size = block_size , signs = signs )
205- torch .testing .assert_close (a , b , atol = 0 , rtol = 0 )
206-
207149
208150class TestNormPreservation :
209151 """Hadamard rotation preserves L2 norm (orthogonal transform)."""
0 commit comments