Skip to content

Commit 72ee3e8

Browse files
committed
Revert "refactor: Remove random sign flips from Hadamard rotation"
This reverts commit fcfca9f.
1 parent 14ee2e9 commit 72ee3e8

File tree

6 files changed

+124
-29
lines changed

6 files changed

+124
-29
lines changed

bitsandbytes/_ops.py

Lines changed: 11 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -588,12 +588,12 @@ def _(
588588

589589
torch.library.define(
590590
"bitsandbytes::hadamard_rotate_",
591-
"(Tensor(a!) data, int block_size) -> Tensor(a!)",
591+
"(Tensor(a!) data, int block_size, Tensor? signs) -> Tensor(a!)",
592592
)
593593

594594

595595
@register_fake("bitsandbytes::hadamard_rotate_")
596-
def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
596+
def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
597597
torch._check(
598598
block_size in (32, 64, 128, 256),
599599
lambda: f"block_size must be 32, 64, 128, or 256, got {block_size}",
@@ -602,6 +602,15 @@ def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
602602
data.dtype in (torch.float16, torch.bfloat16),
603603
lambda: f"hadamard_rotate only supports float16/bfloat16, got {data.dtype}",
604604
)
605+
if signs is not None:
606+
torch._check(
607+
signs.dtype == torch.int32,
608+
lambda: f"signs must be int32, got {signs.dtype}",
609+
)
610+
torch._check(
611+
signs.numel() == block_size // 32,
612+
lambda: f"signs must have {block_size // 32} elements for block_size={block_size}, got {signs.numel()}",
613+
)
605614
return data
606615

607616

bitsandbytes/backends/cuda/ops.py

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1001,7 +1001,7 @@ def _(
10011001

10021002

10031003
@register_kernel("bitsandbytes::hadamard_rotate_", "cuda")
1004-
def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
1004+
def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
10051005
torch._check(
10061006
block_size in (32, 64, 128, 256),
10071007
lambda: f"block_size must be 32, 64, 128, or 256, got {block_size}",
@@ -1012,12 +1012,14 @@ def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
10121012
)
10131013

10141014
tname = _KBIT_DTYPE_SUFFIX[data.dtype]
1015+
signs_ptr = get_ptr(signs) if signs is not None else None
10151016
with _cuda_device_of(data):
10161017
fn = getattr(lib, f"chadamard_rotate_{tname}")
10171018
fn(
10181019
get_ptr(data),
10191020
ct.c_int(data.numel()),
10201021
ct.c_int(block_size),
1022+
signs_ptr,
10211023
_get_tensor_stream(data),
10221024
)
10231025

bitsandbytes/functional.py

Lines changed: 13 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -1135,22 +1135,30 @@ def decode_absmax_e4m4(encoded: Tensor, bias: int = 11) -> Tensor:
11351135
return result
11361136

11371137

1138-
def hadamard_rotate(data: Tensor, block_size: int = 32) -> Tensor:
1139-
"""Apply in-place Walsh-Hadamard rotation to contiguous blocks.
1138+
def hadamard_rotate(
1139+
data: Tensor,
1140+
block_size: int = 32,
1141+
signs: Optional[Tensor] = None,
1142+
) -> Tensor:
1143+
"""Apply in-place randomized Walsh-Hadamard rotation (H*D) to contiguous blocks.
11401144
11411145
Spreads outliers across quantization blocks, improving kbit accuracy.
1142-
Since H is orthogonal, rotating both weights and activations preserves
1143-
the GEMM result: H(A) @ H(B)^T = A @ B^T.
1146+
Since H*D is orthogonal, rotating both weights and activations with the
1147+
same signs preserves the GEMM result: (H*D)(A) @ (H*D)(B)^T = A @ B^T.
11441148
11451149
Args:
11461150
data: Input tensor (float16 or bfloat16). Modified in-place.
11471151
block_size: Rotation block size (32, 64, 128, or 256).
1152+
signs: Optional int32 tensor of block_size//32 words. Each bit controls
1153+
the sign flip for one element within the block. If None, no sign
1154+
flips are applied (plain Hadamard). Generate once per model with
1155+
``torch.randint(0, 2**32, (block_size // 32,), dtype=torch.int32)``.
11481156
11491157
Returns:
11501158
The input tensor, rotated in-place.
11511159
"""
11521160
data_flat = data.contiguous().view(-1)
1153-
torch.ops.bitsandbytes.hadamard_rotate_(data_flat, block_size)
1161+
torch.ops.bitsandbytes.hadamard_rotate_(data_flat, block_size, signs)
11541162
return data
11551163

11561164

csrc/ops.cu

Lines changed: 26 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -1015,19 +1015,25 @@ void repackKbit(
10151015
// ===========================================================================
10161016
// Hadamard rotation kernel (in-place, blocksize-templated)
10171017
//
1018-
// Applies a Walsh-Hadamard transform to contiguous blocks of BLOCK_SIZE
1019-
// elements. Used to spread outliers before kbit quantization.
1020-
// Since H is orthogonal, rotating both weights and activations preserves
1021-
// the GEMM result: H(A) @ H(B)^T = A @ B^T.
1018+
// Applies a randomized Walsh-Hadamard transform (H*D) to contiguous blocks
1019+
// of BLOCK_SIZE elements. D is a diagonal sign-flip matrix (optional).
1020+
// Used to spread outliers before kbit quantization.
1021+
// Since H*D is orthogonal, rotating both weights and activations preserves
1022+
// the GEMM result: (H*D)(A) @ (H*D)(B)^T = A @ B^T.
10221023
//
10231024
// One warp per rotation block:
10241025
// BLOCK_SIZE=32: 1 elem/thread, 5 shuffle stages
10251026
// BLOCK_SIZE=64: 2 elem/thread, 1 register + 5 shuffle stages
10261027
// BLOCK_SIZE=128: 4 elem/thread, 2 register + 5 shuffle stages
10271028
// BLOCK_SIZE=256: 8 elem/thread, 3 register + 5 shuffle stages
1029+
//
1030+
// signs: optional bitmask of BLOCK_SIZE/32 uint32 words. If non-null, bit i
1031+
// set means element i is negated before the Hadamard butterfly. Same sign
1032+
// vector is applied to every block.
10281033
// ===========================================================================
10291034

1030-
template <int BLOCK_SIZE, typename T> __global__ void kHadamardRotate(T* __restrict__ data, const int n) {
1035+
template <int BLOCK_SIZE, typename T>
1036+
__global__ void kHadamardRotate(T* __restrict__ data, const int n, const unsigned int* __restrict__ signs) {
10311037
constexpr int ELEMS_PER_THREAD = BLOCK_SIZE / 32;
10321038
static_assert(BLOCK_SIZE >= 32 && (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0, "BLOCK_SIZE must be a power of 2 >= 32");
10331039

@@ -1047,6 +1053,16 @@ template <int BLOCK_SIZE, typename T> __global__ void kHadamardRotate(T* __restr
10471053
vals[j] = (idx < n) ? (float)data[idx] : 0.0f;
10481054
}
10491055

1056+
// Apply random sign flips (D matrix) before butterfly.
1057+
// Element at position lane_id + j*32 uses word j, bit lane_id.
1058+
if (signs != nullptr) {
1059+
#pragma unroll
1060+
for (int j = 0; j < ELEMS_PER_THREAD; j++) {
1061+
if (signs[j] & (1u << lane_id))
1062+
vals[j] = -vals[j];
1063+
}
1064+
}
1065+
10501066
// In-register butterfly stages (strides >= 32).
10511067
// Stride S in global space corresponds to element index s = S/32.
10521068
// Element j pairs with element j ^ s (both in the same thread).
@@ -1091,17 +1107,18 @@ template <int BLOCK_SIZE, typename T> __global__ void kHadamardRotate(T* __restr
10911107

10921108
// ---- Hadamard rotation launch wrapper ----
10931109

1094-
template <int BLOCK_SIZE, typename T> void hadamardRotate(T* data, int n, cudaStream_t stream) {
1110+
template <int BLOCK_SIZE, typename T>
1111+
void hadamardRotate(T* data, int n, const unsigned int* signs, cudaStream_t stream) {
10951112
const int num_blocks = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
10961113
const int num_cuda_blocks = (num_blocks + KBIT_WARPS_PER_BLOCK - 1) / KBIT_WARPS_PER_BLOCK;
1097-
kHadamardRotate<BLOCK_SIZE, T><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0, stream>>>(data, n);
1114+
kHadamardRotate<BLOCK_SIZE, T><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0, stream>>>(data, n, signs);
10981115
CUDA_CHECK_RETURN(cudaPeekAtLastError());
10991116
}
11001117

11011118
// Explicit instantiations: 4 block sizes x 2 dtypes
11021119
#define INSTANTIATE_HADAMARD(BS) \
1103-
template void hadamardRotate<BS, half>(half*, int, cudaStream_t); \
1104-
template void hadamardRotate<BS, __nv_bfloat16>(__nv_bfloat16*, int, cudaStream_t);
1120+
template void hadamardRotate<BS, half>(half*, int, const unsigned int*, cudaStream_t); \
1121+
template void hadamardRotate<BS, __nv_bfloat16>(__nv_bfloat16*, int, const unsigned int*, cudaStream_t);
11051122

11061123
INSTANTIATE_HADAMARD(32)
11071124
INSTANTIATE_HADAMARD(64)

csrc/pythonInterface.cpp

Lines changed: 11 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -800,23 +800,24 @@ MAKE_KBIT_SCALAR_GEMV_V2_FP16ABS(5)
800800
void testMMA(const half*, const half*, float*);
801801

802802
// Forward declarations of hadamard rotation template
803-
template <int BLOCK_SIZE, typename T> void hadamardRotate(T* data, int n, cudaStream_t stream);
803+
template <int BLOCK_SIZE, typename T>
804+
void hadamardRotate(T* data, int n, const unsigned int* signs, cudaStream_t stream);
804805

805806
// Unmangled hadamard rotation wrappers (dispatch block_size at runtime)
806807
#define MAKE_HADAMARD_ROTATE(tname, T) \
807-
void hadamard_rotate_##tname(T* data, int n, int block_size, cudaStream_t stream) { \
808+
void hadamard_rotate_##tname(T* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream) { \
808809
switch (block_size) { \
809810
case 32: \
810-
hadamardRotate<32, T>(data, n, stream); \
811+
hadamardRotate<32, T>(data, n, signs, stream); \
811812
break; \
812813
case 64: \
813-
hadamardRotate<64, T>(data, n, stream); \
814+
hadamardRotate<64, T>(data, n, signs, stream); \
814815
break; \
815816
case 128: \
816-
hadamardRotate<128, T>(data, n, stream); \
817+
hadamardRotate<128, T>(data, n, signs, stream); \
817818
break; \
818819
case 256: \
819-
hadamardRotate<256, T>(data, n, stream); \
820+
hadamardRotate<256, T>(data, n, signs, stream); \
820821
break; \
821822
} \
822823
}
@@ -1698,12 +1699,12 @@ MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(4)
16981699
MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(5)
16991700

17001701
// Hadamard rotation extern C wrappers
1701-
void chadamard_rotate_fp16(half* data, int n, int block_size, cudaStream_t stream) {
1702-
hadamard_rotate_fp16(data, n, block_size, stream);
1702+
void chadamard_rotate_fp16(half* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream) {
1703+
hadamard_rotate_fp16(data, n, block_size, signs, stream);
17031704
}
17041705

1705-
void chadamard_rotate_bf16(__nv_bfloat16* data, int n, int block_size, cudaStream_t stream) {
1706-
hadamard_rotate_bf16(data, n, block_size, stream);
1706+
void chadamard_rotate_bf16(__nv_bfloat16* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream) {
1707+
hadamard_rotate_bf16(data, n, block_size, signs, stream);
17071708
}
17081709

17091710
#endif

tests/test_hadamard.py

Lines changed: 60 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class TestOrthogonality:
13-
"""H(H(x)) ≈ x Hadamard is its own inverse (involutory)."""
13+
"""H(H(x)) ≈ x for plain Hadamard (no signs)."""
1414

1515
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
1616
@pytest.mark.parametrize("dtype", DTYPES)
@@ -34,12 +34,41 @@ def test_double_apply_large(self, block_size, dtype):
3434
torch.testing.assert_close(x, x_orig, atol=atol, rtol=atol)
3535

3636

37+
class TestSignedOrthogonality:
38+
"""Randomized Hadamard: R=H*D is orthogonal (R^T*R=I)."""
39+
40+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
41+
@pytest.mark.parametrize("dtype", DTYPES)
42+
def test_signed_inverse(self, block_size, dtype):
43+
"""Verify inv(H*D) = D*H: forward then inverse recovers original."""
44+
signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
45+
x = torch.randn(1024, dtype=dtype, device="cuda")
46+
x_orig = x.clone()
47+
48+
# Forward: H*D*x
49+
hadamard_rotate(x, block_size=block_size, signs=signs)
50+
51+
# Inverse: D*H*x' = first apply H (no signs), then sign flip
52+
hadamard_rotate(x, block_size=block_size) # H
53+
# Apply D (sign flip)
54+
x_flat = x.view(-1)
55+
for j in range(block_size // 32):
56+
word = signs[j].item()
57+
for bit in range(32):
58+
if word & (1 << bit):
59+
pos = j * 32 + bit
60+
x_flat[pos::block_size] *= -1
61+
62+
atol = 1e-2 if dtype == torch.bfloat16 else 1e-3
63+
torch.testing.assert_close(x, x_orig, atol=atol, rtol=atol)
64+
65+
3766
class TestGEMMEquivalence:
3867
"""H(A) @ H(B)^T ≈ A @ B^T (within quantization tolerance)."""
3968

4069
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
4170
@pytest.mark.parametrize("dtype", DTYPES)
42-
def test_gemm(self, block_size, dtype):
71+
def test_gemm_plain(self, block_size, dtype):
4372
M, K, N = 4, 256, 8
4473
A = torch.randn(M, K, dtype=dtype, device="cuda")
4574
B = torch.randn(N, K, dtype=dtype, device="cuda")
@@ -54,6 +83,25 @@ def test_gemm(self, block_size, dtype):
5483
atol = 0.1 if dtype == torch.bfloat16 else 0.05
5584
torch.testing.assert_close(result, ref, atol=atol, rtol=0.05)
5685

86+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
87+
@pytest.mark.parametrize("dtype", DTYPES)
88+
def test_gemm_signed(self, block_size, dtype):
89+
"""GEMM equivalence with random sign flips."""
90+
M, K, N = 4, 256, 8
91+
signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
92+
A = torch.randn(M, K, dtype=dtype, device="cuda")
93+
B = torch.randn(N, K, dtype=dtype, device="cuda")
94+
ref = A.float() @ B.float().T
95+
96+
A_rot = A.clone()
97+
B_rot = B.clone()
98+
hadamard_rotate(A_rot, block_size=block_size, signs=signs)
99+
hadamard_rotate(B_rot, block_size=block_size, signs=signs)
100+
result = A_rot.float() @ B_rot.float().T
101+
102+
atol = 0.1 if dtype == torch.bfloat16 else 0.05
103+
torch.testing.assert_close(result, ref, atol=atol, rtol=0.05)
104+
57105
def test_gemm_qwen3_shapes(self):
58106
"""GEMM equivalence on Qwen3-Coder-Next 70B shapes."""
59107
shapes = [
@@ -146,6 +194,16 @@ def test_deterministic(self, block_size, dtype):
146194
hadamard_rotate(b, block_size=block_size)
147195
torch.testing.assert_close(a, b, atol=0, rtol=0)
148196

197+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
198+
def test_deterministic_signed(self, block_size):
199+
signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
200+
x = torch.randn(1024, dtype=torch.float16, device="cuda")
201+
a = x.clone()
202+
b = x.clone()
203+
hadamard_rotate(a, block_size=block_size, signs=signs)
204+
hadamard_rotate(b, block_size=block_size, signs=signs)
205+
torch.testing.assert_close(a, b, atol=0, rtol=0)
206+
149207

150208
class TestNormPreservation:
151209
"""Hadamard rotation preserves L2 norm (orthogonal transform)."""

0 commit comments

Comments (0)