Skip to content

Commit 3a2cf58

Browse files
TimDettmers and claude
committed
feat: Add optional random sign flips to Hadamard rotation
Support randomized Hadamard transform R = H*D where D is a diagonal sign matrix. The sign vector (block_size/32 uint32 words as a bitmask) is applied element-wise before the butterfly stages. Since R is orthogonal (D^2=I), rotating both weights and activations with the same signs preserves the GEMM result. Random sign flips improve outlier destruction vs plain Hadamard by breaking deterministic alignment patterns. Generate signs once per model with torch.randint(0, 2**32, (block_size//32,), dtype=torch.int32). Passing signs=None preserves the previous behavior (plain Hadamard). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 22148c7 commit 3a2cf58

File tree

5 files changed

+64
-28
lines changed

5 files changed

+64
-28
lines changed

bitsandbytes/_ops.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,12 +588,12 @@ def _(
588588

589589
torch.library.define(
590590
"bitsandbytes::hadamard_rotate_",
591-
"(Tensor(a!) data, int block_size) -> Tensor(a!)",
591+
"(Tensor(a!) data, int block_size, Tensor? signs) -> Tensor(a!)",
592592
)
593593

594594

595595
@register_fake("bitsandbytes::hadamard_rotate_")
596-
def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
596+
def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
597597
torch._check(
598598
block_size in (32, 64, 128, 256),
599599
lambda: f"block_size must be 32, 64, 128, or 256, got {block_size}",
@@ -602,6 +602,15 @@ def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
602602
data.dtype in (torch.float16, torch.bfloat16),
603603
lambda: f"hadamard_rotate only supports float16/bfloat16, got {data.dtype}",
604604
)
605+
if signs is not None:
606+
torch._check(
607+
signs.dtype == torch.int32,
608+
lambda: f"signs must be int32, got {signs.dtype}",
609+
)
610+
torch._check(
611+
signs.numel() == block_size // 32,
612+
lambda: f"signs must have {block_size // 32} elements for block_size={block_size}, got {signs.numel()}",
613+
)
605614
return data
606615

607616

bitsandbytes/backends/cuda/ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ def _(
10011001

10021002

10031003
@register_kernel("bitsandbytes::hadamard_rotate_", "cuda")
1004-
def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
1004+
def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
10051005
torch._check(
10061006
block_size in (32, 64, 128, 256),
10071007
lambda: f"block_size must be 32, 64, 128, or 256, got {block_size}",
@@ -1012,12 +1012,14 @@ def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
10121012
)
10131013

10141014
tname = _KBIT_DTYPE_SUFFIX[data.dtype]
1015+
signs_ptr = get_ptr(signs) if signs is not None else None
10151016
with _cuda_device_of(data):
10161017
fn = getattr(lib, f"chadamard_rotate_{tname}")
10171018
fn(
10181019
get_ptr(data),
10191020
ct.c_int(data.numel()),
10201021
ct.c_int(block_size),
1022+
signs_ptr,
10211023
_get_tensor_stream(data),
10221024
)
10231025

bitsandbytes/functional.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,22 +1135,30 @@ def decode_absmax_e4m4(encoded: Tensor, bias: int = 11) -> Tensor:
11351135
return result
11361136

11371137

1138-
def hadamard_rotate(data: Tensor, block_size: int = 32) -> Tensor:
1139-
"""Apply in-place Walsh-Hadamard rotation to contiguous blocks.
1138+
def hadamard_rotate(
1139+
data: Tensor,
1140+
block_size: int = 32,
1141+
signs: Optional[Tensor] = None,
1142+
) -> Tensor:
1143+
"""Apply in-place randomized Walsh-Hadamard rotation (H*D) to contiguous blocks.
11401144
11411145
Spreads outliers across quantization blocks, improving kbit accuracy.
1142-
Since H is orthogonal, rotating both weights and activations preserves
1143-
the GEMM result: H(A) @ H(B)^T = A @ B^T.
1146+
Since H*D is orthogonal, rotating both weights and activations with the
1147+
same signs preserves the GEMM result: (H*D)(A) @ (H*D)(B)^T = A @ B^T.
11441148
11451149
Args:
11461150
data: Input tensor (float16 or bfloat16). Modified in-place.
11471151
block_size: Rotation block size (32, 64, 128, or 256).
1152+
signs: Optional int32 tensor of block_size//32 words. Each bit controls
1153+
the sign flip for one element within the block. If None, no sign
1154+
flips are applied (plain Hadamard). Generate once per model with
1155+
``torch.randint(0, 2**32, (block_size // 32,), dtype=torch.int32)``.
11481156
11491157
Returns:
11501158
The input tensor, rotated in-place.
11511159
"""
11521160
data_flat = data.contiguous().view(-1)
1153-
torch.ops.bitsandbytes.hadamard_rotate_(data_flat, block_size)
1161+
torch.ops.bitsandbytes.hadamard_rotate_(data_flat, block_size, signs)
11541162
return data
11551163

11561164

csrc/ops.cu

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,20 +1014,25 @@ void repackKbit(
10141014
// ===========================================================================
10151015
// Hadamard rotation kernel (in-place, blocksize-templated)
10161016
//
1017-
// Applies a Walsh-Hadamard transform to contiguous blocks of BLOCK_SIZE
1018-
// elements. Used to spread outliers before kbit quantization.
1019-
// Since H is orthogonal, rotating both weights and activations preserves
1020-
// the GEMM result: H(A) @ H(B)^T = A @ B^T.
1017+
// Applies a randomized Walsh-Hadamard transform (H*D) to contiguous blocks
1018+
// of BLOCK_SIZE elements. D is a diagonal sign-flip matrix (optional).
1019+
// Used to spread outliers before kbit quantization.
1020+
// Since H*D is orthogonal, rotating both weights and activations preserves
1021+
// the GEMM result: (H*D)(A) @ (H*D)(B)^T = A @ B^T.
10211022
//
10221023
// One warp per rotation block:
10231024
// BLOCK_SIZE=32: 1 elem/thread, 5 shuffle stages
10241025
// BLOCK_SIZE=64: 2 elem/thread, 1 register + 5 shuffle stages
10251026
// BLOCK_SIZE=128: 4 elem/thread, 2 register + 5 shuffle stages
10261027
// BLOCK_SIZE=256: 8 elem/thread, 3 register + 5 shuffle stages
1028+
//
1029+
// signs: optional bitmask of BLOCK_SIZE/32 uint32 words. If non-null, bit i
1030+
// set means element i is negated before the Hadamard butterfly. Same sign
1031+
// vector is applied to every block.
10271032
// ===========================================================================
10281033

10291034
template <int BLOCK_SIZE, typename T>
1030-
__global__ void kHadamardRotate(T* __restrict__ data, const int n) {
1035+
__global__ void kHadamardRotate(T* __restrict__ data, const int n, const unsigned int* __restrict__ signs) {
10311036
constexpr int ELEMS_PER_THREAD = BLOCK_SIZE / 32;
10321037
static_assert(BLOCK_SIZE >= 32 && (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0,
10331038
"BLOCK_SIZE must be a power of 2 >= 32");
@@ -1048,6 +1053,16 @@ __global__ void kHadamardRotate(T* __restrict__ data, const int n) {
10481053
vals[j] = (idx < n) ? (float)data[idx] : 0.0f;
10491054
}
10501055

1056+
// Apply random sign flips (D matrix) before butterfly.
1057+
// Element at position lane_id + j*32 uses word j, bit lane_id.
1058+
if (signs != nullptr) {
1059+
#pragma unroll
1060+
for (int j = 0; j < ELEMS_PER_THREAD; j++) {
1061+
if (signs[j] & (1u << lane_id))
1062+
vals[j] = -vals[j];
1063+
}
1064+
}
1065+
10511066
// In-register butterfly stages (strides >= 32).
10521067
// Stride S in global space corresponds to element index s = S/32.
10531068
// Element j pairs with element j ^ s (both in the same thread).
@@ -1093,17 +1108,17 @@ __global__ void kHadamardRotate(T* __restrict__ data, const int n) {
10931108
// ---- Hadamard rotation launch wrapper ----
10941109

10951110
template <int BLOCK_SIZE, typename T>
1096-
void hadamardRotate(T* data, int n, cudaStream_t stream) {
1111+
void hadamardRotate(T* data, int n, const unsigned int* signs, cudaStream_t stream) {
10971112
const int num_blocks = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
10981113
const int num_cuda_blocks = (num_blocks + KBIT_WARPS_PER_BLOCK - 1) / KBIT_WARPS_PER_BLOCK;
1099-
kHadamardRotate<BLOCK_SIZE, T><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0, stream>>>(data, n);
1114+
kHadamardRotate<BLOCK_SIZE, T><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0, stream>>>(data, n, signs);
11001115
CUDA_CHECK_RETURN(cudaPeekAtLastError());
11011116
}
11021117

11031118
// Explicit instantiations: 4 block sizes x 2 dtypes
1104-
#define INSTANTIATE_HADAMARD(BS) \
1105-
template void hadamardRotate<BS, half>(half*, int, cudaStream_t); \
1106-
template void hadamardRotate<BS, __nv_bfloat16>(__nv_bfloat16*, int, cudaStream_t);
1119+
#define INSTANTIATE_HADAMARD(BS) \
1120+
template void hadamardRotate<BS, half>(half*, int, const unsigned int*, cudaStream_t); \
1121+
template void hadamardRotate<BS, __nv_bfloat16>(__nv_bfloat16*, int, const unsigned int*, cudaStream_t);
11071122

11081123
INSTANTIATE_HADAMARD(32)
11091124
INSTANTIATE_HADAMARD(64)

csrc/pythonInterface.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -798,16 +798,16 @@ void testMMA(const half*, const half*, float*);
798798

799799
// Forward declarations of hadamard rotation template
800800
template <int BLOCK_SIZE, typename T>
801-
void hadamardRotate(T* data, int n, cudaStream_t stream);
801+
void hadamardRotate(T* data, int n, const unsigned int* signs, cudaStream_t stream);
802802

803803
// Unmangled hadamard rotation wrappers (dispatch block_size at runtime)
804804
#define MAKE_HADAMARD_ROTATE(tname, T) \
805-
void hadamard_rotate_##tname(T* data, int n, int block_size, cudaStream_t stream) { \
805+
void hadamard_rotate_##tname(T* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream) { \
806806
switch (block_size) { \
807-
case 32: hadamardRotate<32, T>(data, n, stream); break; \
808-
case 64: hadamardRotate<64, T>(data, n, stream); break; \
809-
case 128: hadamardRotate<128, T>(data, n, stream); break; \
810-
case 256: hadamardRotate<256, T>(data, n, stream); break; \
807+
case 32: hadamardRotate<32, T>(data, n, signs, stream); break; \
808+
case 64: hadamardRotate<64, T>(data, n, signs, stream); break; \
809+
case 128: hadamardRotate<128, T>(data, n, signs, stream); break; \
810+
case 256: hadamardRotate<256, T>(data, n, signs, stream); break; \
811811
} \
812812
}
813813

@@ -1685,12 +1685,14 @@ MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(4)
16851685
MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(5)
16861686

16871687
// Hadamard rotation extern C wrappers
1688-
void chadamard_rotate_fp16(half* data, int n, int block_size, cudaStream_t stream) {
1689-
hadamard_rotate_fp16(data, n, block_size, stream);
1688+
void chadamard_rotate_fp16(half* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream) {
1689+
hadamard_rotate_fp16(data, n, block_size, signs, stream);
16901690
}
16911691

1692-
void chadamard_rotate_bf16(__nv_bfloat16* data, int n, int block_size, cudaStream_t stream) {
1693-
hadamard_rotate_bf16(data, n, block_size, stream);
1692+
void chadamard_rotate_bf16(
1693+
__nv_bfloat16* data, int n, int block_size, const unsigned int* signs, cudaStream_t stream
1694+
) {
1695+
hadamard_rotate_bf16(data, n, block_size, signs, stream);
16941696
}
16951697

16961698
#endif

0 commit comments

Comments (0)