Skip to content

Commit fcfca9f

Browse files
TimDettmers and claude committed
refactor: Remove random sign flips from Hadamard rotation
Simplify the Hadamard rotation API by removing the optional signs parameter. Plain Walsh-Hadamard is sufficient for outlier spreading and keeps the interface minimal. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 28fa6c2 commit fcfca9f

File tree

6 files changed

+29
-124
lines changed

6 files changed

+29
-124
lines changed

bitsandbytes/_ops.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -588,12 +588,12 @@ def _(
588588

589589
torch.library.define(
590590
"bitsandbytes::hadamard_rotate_",
591-
"(Tensor(a!) data, int block_size, Tensor? signs) -> Tensor(a!)",
591+
"(Tensor(a!) data, int block_size) -> Tensor(a!)",
592592
)
593593

594594

595595
@register_fake("bitsandbytes::hadamard_rotate_")
596-
def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
596+
def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
597597
torch._check(
598598
block_size in (32, 64, 128, 256),
599599
lambda: f"block_size must be 32, 64, 128, or 256, got {block_size}",
@@ -602,15 +602,6 @@ def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> tor
602602
data.dtype in (torch.float16, torch.bfloat16),
603603
lambda: f"hadamard_rotate only supports float16/bfloat16, got {data.dtype}",
604604
)
605-
if signs is not None:
606-
torch._check(
607-
signs.dtype == torch.int32,
608-
lambda: f"signs must be int32, got {signs.dtype}",
609-
)
610-
torch._check(
611-
signs.numel() == block_size // 32,
612-
lambda: f"signs must have {block_size // 32} elements for block_size={block_size}, got {signs.numel()}",
613-
)
614605
return data
615606

616607

bitsandbytes/backends/cuda/ops.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ def _(
10011001

10021002

10031003
@register_kernel("bitsandbytes::hadamard_rotate_", "cuda")
1004-
def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
1004+
def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
10051005
torch._check(
10061006
block_size in (32, 64, 128, 256),
10071007
lambda: f"block_size must be 32, 64, 128, or 256, got {block_size}",
@@ -1012,14 +1012,12 @@ def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> tor
10121012
)
10131013

10141014
tname = _KBIT_DTYPE_SUFFIX[data.dtype]
1015-
signs_ptr = get_ptr(signs) if signs is not None else None
10161015
with _cuda_device_of(data):
10171016
fn = getattr(lib, f"chadamard_rotate_{tname}")
10181017
fn(
10191018
get_ptr(data),
10201019
ct.c_int(data.numel()),
10211020
ct.c_int(block_size),
1022-
signs_ptr,
10231021
_get_tensor_stream(data),
10241022
)
10251023

bitsandbytes/functional.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,30 +1135,22 @@ def decode_absmax_e4m4(encoded: Tensor, bias: int = 11) -> Tensor:
11351135
return result
11361136

11371137

1138-
def hadamard_rotate(
1139-
data: Tensor,
1140-
block_size: int = 32,
1141-
signs: Optional[Tensor] = None,
1142-
) -> Tensor:
1143-
"""Apply in-place randomized Walsh-Hadamard rotation (H*D) to contiguous blocks.
1138+
def hadamard_rotate(data: Tensor, block_size: int = 32) -> Tensor:
1139+
"""Apply in-place Walsh-Hadamard rotation to contiguous blocks.
11441140
11451141
Spreads outliers across quantization blocks, improving kbit accuracy.
1146-
Since H*D is orthogonal, rotating both weights and activations with the
1147-
same signs preserves the GEMM result: (H*D)(A) @ (H*D)(B)^T = A @ B^T.
1142+
Since H is orthogonal, rotating both weights and activations preserves
1143+
the GEMM result: H(A) @ H(B)^T = A @ B^T.
11481144
11491145
Args:
11501146
data: Input tensor (float16 or bfloat16). Modified in-place.
11511147
block_size: Rotation block size (32, 64, 128, or 256).
1152-
signs: Optional int32 tensor of block_size//32 words. Each bit controls
1153-
the sign flip for one element within the block. If None, no sign
1154-
flips are applied (plain Hadamard). Generate once per model with
1155-
``torch.randint(0, 2**32, (block_size // 32,), dtype=torch.int32)``.
11561148
11571149
Returns:
11581150
The input tensor, rotated in-place.
11591151
"""
11601152
data_flat = data.contiguous().view(-1)
1161-
torch.ops.bitsandbytes.hadamard_rotate_(data_flat, block_size, signs)
1153+
torch.ops.bitsandbytes.hadamard_rotate_(data_flat, block_size)
11621154
return data
11631155

11641156

csrc/ops.cu

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,25 +1015,19 @@ void repackKbit(
10151015
// ===========================================================================
10161016
// Hadamard rotation kernel (in-place, blocksize-templated)
10171017
//
1018-
// Applies a randomized Walsh-Hadamard transform (H*D) to contiguous blocks
1019-
// of BLOCK_SIZE elements. D is a diagonal sign-flip matrix (optional).
1020-
// Used to spread outliers before kbit quantization.
1021-
// Since H*D is orthogonal, rotating both weights and activations preserves
1022-
// the GEMM result: (H*D)(A) @ (H*D)(B)^T = A @ B^T.
1018+
// Applies a Walsh-Hadamard transform to contiguous blocks of BLOCK_SIZE
1019+
// elements. Used to spread outliers before kbit quantization.
1020+
// Since H is orthogonal, rotating both weights and activations preserves
1021+
// the GEMM result: H(A) @ H(B)^T = A @ B^T.
10231022
//
10241023
// One warp per rotation block:
10251024
// BLOCK_SIZE=32: 1 elem/thread, 5 shuffle stages
10261025
// BLOCK_SIZE=64: 2 elem/thread, 1 register + 5 shuffle stages
10271026
// BLOCK_SIZE=128: 4 elem/thread, 2 register + 5 shuffle stages
10281027
// BLOCK_SIZE=256: 8 elem/thread, 3 register + 5 shuffle stages
1029-
//
1030-
// signs: optional bitmask of BLOCK_SIZE/32 uint32 words. If non-null, bit i
1031-
// set means element i is negated before the Hadamard butterfly. Same sign
1032-
// vector is applied to every block.
10331028
// ===========================================================================
10341029

1035-
template <int BLOCK_SIZE, typename T>
1036-
__global__ void kHadamardRotate(T* __restrict__ data, const int n, const unsigned int* __restrict__ signs) {
1030+
template <int BLOCK_SIZE, typename T> __global__ void kHadamardRotate(T* __restrict__ data, const int n) {
10371031
constexpr int ELEMS_PER_THREAD = BLOCK_SIZE / 32;
10381032
static_assert(BLOCK_SIZE >= 32 && (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0, "BLOCK_SIZE must be a power of 2 >= 32");
10391033

@@ -1053,16 +1047,6 @@ __global__ void kHadamardRotate(T* __restrict__ data, const int n, const unsigne
10531047
vals[j] = (idx < n) ? (float)data[idx] : 0.0f;
10541048
}
10551049

1056-
// Apply random sign flips (D matrix) before butterfly.
1057-
// Element at position lane_id + j*32 uses word j, bit lane_id.
1058-
if (signs != nullptr) {
1059-
#pragma unroll
1060-
for (int j = 0; j < ELEMS_PER_THREAD; j++) {
1061-
if (signs[j] & (1u << lane_id))
1062-
vals[j] = -vals[j];
1063-
}
1064-
}
1065-
10661050
// In-register butterfly stages (strides >= 32).
10671051
// Stride S in global space corresponds to element index s = S/32.
10681052
// Element j pairs with element j ^ s (both in the same thread).
@@ -1107,18 +1091,17 @@ __global__ void kHadamardRotate(T* __restrict__ data, const int n, const unsigne
11071091

11081092
// ---- Hadamard rotation launch wrapper ----
11091093

1110-
template <int BLOCK_SIZE, typename T>
1111-
void hadamardRotate(T* data, int n, const unsigned int* signs, cudaStream_t stream) {
1094+
template <int BLOCK_SIZE, typename T> void hadamardRotate(T* data, int n, cudaStream_t stream) {
11121095
const int num_blocks = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
11131096
const int num_cuda_blocks = (num_blocks + KBIT_WARPS_PER_BLOCK - 1) / KBIT_WARPS_PER_BLOCK;
1114-
kHadamardRotate<BLOCK_SIZE, T><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0, stream>>>(data, n, signs);
1097+
kHadamardRotate<BLOCK_SIZE, T><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0, stream>>>(data, n);
11151098
CUDA_CHECK_RETURN(cudaPeekAtLastError());
11161099
}
11171100

11181101
// Explicit instantiations: 4 block sizes x 2 dtypes
11191102
#define INSTANTIATE_HADAMARD(BS) \
1120-
template void hadamardRotate<BS, half>(half*, int, const unsigned int*, cudaStream_t); \
1121-
template void hadamardRotate<BS, __nv_bfloat16>(__nv_bfloat16*, int, const unsigned int*, cudaStream_t);
1103+
template void hadamardRotate<BS, half>(half*, int, cudaStream_t); \
1104+
template void hadamardRotate<BS, __nv_bfloat16>(__nv_bfloat16*, int, cudaStream_t);
11221105

11231106
INSTANTIATE_HADAMARD(32)
11241107
INSTANTIATE_HADAMARD(64)

csrc/pythonInterface.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -800,24 +800,23 @@ MAKE_KBIT_SCALAR_GEMV_V2_FP16ABS(5)
800800
void testMMA(const half*, const half*, float*);
801801

802802
// Forward declarations of hadamard rotation template
803-
template <int BLOCK_SIZE, typename T>
804-
void hadamardRotate(T* data, int n, const unsigned int* signs, cudaStream_t stream);
803+
template <int BLOCK_SIZE, typename T> void hadamardRotate(T* data, int n, cudaStream_t stream);
805804

806805
// Unmangled hadamard rotation wrappers (dispatch block_size at runtime)
807806
#define MAKE_HADAMARD_ROTATE(tname, T) \
808-
void hadamard_rotate_##tname(T* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream) { \
807+
void hadamard_rotate_##tname(T* data, int n, int block_size, cudaStream_t stream) { \
809808
switch (block_size) { \
810809
case 32: \
811-
hadamardRotate<32, T>(data, n, signs, stream); \
810+
hadamardRotate<32, T>(data, n, stream); \
812811
break; \
813812
case 64: \
814-
hadamardRotate<64, T>(data, n, signs, stream); \
813+
hadamardRotate<64, T>(data, n, stream); \
815814
break; \
816815
case 128: \
817-
hadamardRotate<128, T>(data, n, signs, stream); \
816+
hadamardRotate<128, T>(data, n, stream); \
818817
break; \
819818
case 256: \
820-
hadamardRotate<256, T>(data, n, signs, stream); \
819+
hadamardRotate<256, T>(data, n, stream); \
821820
break; \
822821
} \
823822
}
@@ -1699,12 +1698,12 @@ MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(4)
16991698
MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(5)
17001699

17011700
// Hadamard rotation extern C wrappers
1702-
void chadamard_rotate_fp16(half* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream) {
1703-
hadamard_rotate_fp16(data, n, block_size, signs, stream);
1701+
void chadamard_rotate_fp16(half* data, int n, int block_size, cudaStream_t stream) {
1702+
hadamard_rotate_fp16(data, n, block_size, stream);
17041703
}
17051704

1706-
void chadamard_rotate_bf16(__nv_bfloat16* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream) {
1707-
hadamard_rotate_bf16(data, n, block_size, signs, stream);
1705+
void chadamard_rotate_bf16(__nv_bfloat16* data, int n, int block_size, cudaStream_t stream) {
1706+
hadamard_rotate_bf16(data, n, block_size, stream);
17081707
}
17091708

17101709
#endif

tests/test_hadamard.py

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class TestOrthogonality:
13-
"""H(H(x)) ≈ x for plain Hadamard (no signs)."""
13+
"""H(H(x)) ≈ x Hadamard is its own inverse (involutory)."""
1414

1515
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
1616
@pytest.mark.parametrize("dtype", DTYPES)
@@ -34,41 +34,12 @@ def test_double_apply_large(self, block_size, dtype):
3434
torch.testing.assert_close(x, x_orig, atol=atol, rtol=atol)
3535

3636

37-
class TestSignedOrthogonality:
38-
"""Randomized Hadamard: R=H*D is orthogonal (R^T*R=I)."""
39-
40-
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
41-
@pytest.mark.parametrize("dtype", DTYPES)
42-
def test_signed_inverse(self, block_size, dtype):
43-
"""Verify inv(H*D) = D*H: forward then inverse recovers original."""
44-
signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
45-
x = torch.randn(1024, dtype=dtype, device="cuda")
46-
x_orig = x.clone()
47-
48-
# Forward: H*D*x
49-
hadamard_rotate(x, block_size=block_size, signs=signs)
50-
51-
# Inverse: D*H*x' = first apply H (no signs), then sign flip
52-
hadamard_rotate(x, block_size=block_size) # H
53-
# Apply D (sign flip)
54-
x_flat = x.view(-1)
55-
for j in range(block_size // 32):
56-
word = signs[j].item()
57-
for bit in range(32):
58-
if word & (1 << bit):
59-
pos = j * 32 + bit
60-
x_flat[pos::block_size] *= -1
61-
62-
atol = 1e-2 if dtype == torch.bfloat16 else 1e-3
63-
torch.testing.assert_close(x, x_orig, atol=atol, rtol=atol)
64-
65-
6637
class TestGEMMEquivalence:
6738
"""H(A) @ H(B)^T ≈ A @ B^T (within quantization tolerance)."""
6839

6940
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
7041
@pytest.mark.parametrize("dtype", DTYPES)
71-
def test_gemm_plain(self, block_size, dtype):
42+
def test_gemm(self, block_size, dtype):
7243
M, K, N = 4, 256, 8
7344
A = torch.randn(M, K, dtype=dtype, device="cuda")
7445
B = torch.randn(N, K, dtype=dtype, device="cuda")
@@ -83,25 +54,6 @@ def test_gemm_plain(self, block_size, dtype):
8354
atol = 0.1 if dtype == torch.bfloat16 else 0.05
8455
torch.testing.assert_close(result, ref, atol=atol, rtol=0.05)
8556

86-
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
87-
@pytest.mark.parametrize("dtype", DTYPES)
88-
def test_gemm_signed(self, block_size, dtype):
89-
"""GEMM equivalence with random sign flips."""
90-
M, K, N = 4, 256, 8
91-
signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
92-
A = torch.randn(M, K, dtype=dtype, device="cuda")
93-
B = torch.randn(N, K, dtype=dtype, device="cuda")
94-
ref = A.float() @ B.float().T
95-
96-
A_rot = A.clone()
97-
B_rot = B.clone()
98-
hadamard_rotate(A_rot, block_size=block_size, signs=signs)
99-
hadamard_rotate(B_rot, block_size=block_size, signs=signs)
100-
result = A_rot.float() @ B_rot.float().T
101-
102-
atol = 0.1 if dtype == torch.bfloat16 else 0.05
103-
torch.testing.assert_close(result, ref, atol=atol, rtol=0.05)
104-
10557
def test_gemm_qwen3_shapes(self):
10658
"""GEMM equivalence on Qwen3-Coder-Next 70B shapes."""
10759
shapes = [
@@ -194,16 +146,6 @@ def test_deterministic(self, block_size, dtype):
194146
hadamard_rotate(b, block_size=block_size)
195147
torch.testing.assert_close(a, b, atol=0, rtol=0)
196148

197-
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
198-
def test_deterministic_signed(self, block_size):
199-
signs = torch.randint(0, 2**31, (block_size // 32,), dtype=torch.int32, device="cuda")
200-
x = torch.randn(1024, dtype=torch.float16, device="cuda")
201-
a = x.clone()
202-
b = x.clone()
203-
hadamard_rotate(a, block_size=block_size, signs=signs)
204-
hadamard_rotate(b, block_size=block_size, signs=signs)
205-
torch.testing.assert_close(a, b, atol=0, rtol=0)
206-
207149

208150
class TestNormPreservation:
209151
"""Hadamard rotation preserves L2 norm (orthogonal transform)."""

0 commit comments

Comments (0)