Skip to content

Commit 22148c7

Browse files
TimDettmersclaude
andcommitted
feat: Add Hadamard rotation kernel for kbit outlier spreading
Templated Walsh-Hadamard transform kernel for FP16/BF16, operating on contiguous blocks of 32/64/128/256 elements. One warp per rotation block using butterfly decomposition: in-register stages for stride>=32, shuffle stages for stride<32. Normalization by 1/sqrt(block_size). In-place operation, CUDA graph safe (no runtime API calls in hot path). Registered as torch.ops.bitsandbytes.hadamard_rotate_ with Python helper in functional.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a9864dc commit 22148c7

File tree

5 files changed

+194
-0
lines changed

5 files changed

+194
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,27 @@ def _(
584584
return packed_tiled, absmax_tiled
585585

586586

587+
# Hadamard rotation (in-place, for kbit quantization outlier spreading)
588+
589+
torch.library.define(
590+
"bitsandbytes::hadamard_rotate_",
591+
"(Tensor(a!) data, int block_size) -> Tensor(a!)",
592+
)
593+
594+
595+
@register_fake("bitsandbytes::hadamard_rotate_")
596+
def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
597+
torch._check(
598+
block_size in (32, 64, 128, 256),
599+
lambda: f"block_size must be 32, 64, 128, or 256, got {block_size}",
600+
)
601+
torch._check(
602+
data.dtype in (torch.float16, torch.bfloat16),
603+
lambda: f"hadamard_rotate only supports float16/bfloat16, got {data.dtype}",
604+
)
605+
return data
606+
607+
587608
# K-bit fused dequant + GEMM (production: fp16 + bf16)
588609

589610
torch.library.define(

bitsandbytes/backends/cuda/ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,30 @@ def _(
10001000
return packed_tiled, absmax_tiled
10011001

10021002

1003+
@register_kernel("bitsandbytes::hadamard_rotate_", "cuda")
1004+
def _(data: torch.Tensor, block_size: int) -> torch.Tensor:
1005+
torch._check(
1006+
block_size in (32, 64, 128, 256),
1007+
lambda: f"block_size must be 32, 64, 128, or 256, got {block_size}",
1008+
)
1009+
torch._check(
1010+
data.dtype in (torch.float16, torch.bfloat16),
1011+
lambda: f"hadamard_rotate only supports float16/bfloat16, got {data.dtype}",
1012+
)
1013+
1014+
tname = _KBIT_DTYPE_SUFFIX[data.dtype]
1015+
with _cuda_device_of(data):
1016+
fn = getattr(lib, f"chadamard_rotate_{tname}")
1017+
fn(
1018+
get_ptr(data),
1019+
ct.c_int(data.numel()),
1020+
ct.c_int(block_size),
1021+
_get_tensor_stream(data),
1022+
)
1023+
1024+
return data
1025+
1026+
10031027
def _kbit_gemm_prod_check(A, B_packed, B_absmax, codebook, N, k, k_chunks):
10041028
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
10051029
torch._check(

bitsandbytes/functional.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,25 @@ def decode_absmax_e4m4(encoded: Tensor, bias: int = 11) -> Tensor:
11351135
return result
11361136

11371137

1138+
def hadamard_rotate(data: Tensor, block_size: int = 32) -> Tensor:
1139+
"""Apply in-place Walsh-Hadamard rotation to contiguous blocks.
1140+
1141+
Spreads outliers across quantization blocks, improving kbit accuracy.
1142+
Since H is orthogonal, rotating both weights and activations preserves
1143+
the GEMM result: H(A) @ H(B)^T = A @ B^T.
1144+
1145+
Args:
1146+
data: Input tensor (float16 or bfloat16). Modified in-place.
1147+
block_size: Rotation block size (32, 64, 128, or 256).
1148+
1149+
Returns:
1150+
The input tensor, rotated in-place.
1151+
"""
1152+
data_flat = data.contiguous().view(-1)
1153+
torch.ops.bitsandbytes.hadamard_rotate_(data_flat, block_size)
1154+
return data
1155+
1156+
11381157
def quantize_kbit(
11391158
A: Tensor,
11401159
k: int = 4,

csrc/ops.cu

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,107 @@ void repackKbit(
10111011
CUDA_CHECK_RETURN(cudaPeekAtLastError());
10121012
}
10131013

1014+
// ===========================================================================
1015+
// Hadamard rotation kernel (in-place, blocksize-templated)
1016+
//
1017+
// Applies a Walsh-Hadamard transform to contiguous blocks of BLOCK_SIZE
1018+
// elements. Used to spread outliers before kbit quantization.
1019+
// Since H is orthogonal, rotating both weights and activations preserves
1020+
// the GEMM result: H(A) @ H(B)^T = A @ B^T.
1021+
//
1022+
// One warp per rotation block:
1023+
// BLOCK_SIZE=32: 1 elem/thread, 5 shuffle stages
1024+
// BLOCK_SIZE=64: 2 elem/thread, 1 register + 5 shuffle stages
1025+
// BLOCK_SIZE=128: 4 elem/thread, 2 register + 5 shuffle stages
1026+
// BLOCK_SIZE=256: 8 elem/thread, 3 register + 5 shuffle stages
1027+
// ===========================================================================
1028+
1029+
template <int BLOCK_SIZE, typename T>
1030+
__global__ void kHadamardRotate(T* __restrict__ data, const int n) {
1031+
constexpr int ELEMS_PER_THREAD = BLOCK_SIZE / 32;
1032+
static_assert(BLOCK_SIZE >= 32 && (BLOCK_SIZE & (BLOCK_SIZE - 1)) == 0,
1033+
"BLOCK_SIZE must be a power of 2 >= 32");
1034+
1035+
const int warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
1036+
const int lane_id = threadIdx.x % 32;
1037+
const int block_start = warp_idx * BLOCK_SIZE;
1038+
1039+
if (block_start >= n)
1040+
return;
1041+
1042+
// Load ELEMS_PER_THREAD elements per thread.
1043+
// Thread t holds elements at global positions: block_start + t, t+32, t+64, ...
1044+
float vals[ELEMS_PER_THREAD];
1045+
#pragma unroll
1046+
for (int j = 0; j < ELEMS_PER_THREAD; j++) {
1047+
int idx = block_start + lane_id + j * 32;
1048+
vals[j] = (idx < n) ? (float)data[idx] : 0.0f;
1049+
}
1050+
1051+
// In-register butterfly stages (strides >= 32).
1052+
// Stride S in global space corresponds to element index s = S/32.
1053+
// Element j pairs with element j ^ s (both in the same thread).
1054+
#pragma unroll
1055+
for (int s = ELEMS_PER_THREAD / 2; s >= 1; s >>= 1) {
1056+
#pragma unroll
1057+
for (int j = 0; j < ELEMS_PER_THREAD; j++) {
1058+
int partner = j ^ s;
1059+
if (partner > j) {
1060+
float a = vals[j], b = vals[partner];
1061+
vals[j] = a + b;
1062+
vals[partner] = a - b;
1063+
}
1064+
}
1065+
}
1066+
1067+
// Shuffle butterfly stages (strides 16, 8, 4, 2, 1).
1068+
// Each stage exchanges values between lanes within the warp.
1069+
#pragma unroll
1070+
for (int s = 16; s >= 1; s >>= 1) {
1071+
#pragma unroll
1072+
for (int j = 0; j < ELEMS_PER_THREAD; j++) {
1073+
float other = __shfl_xor_sync(0xFFFFFFFF, vals[j], s);
1074+
vals[j] = (lane_id & s) ? (other - vals[j]) : (vals[j] + other);
1075+
}
1076+
}
1077+
1078+
// Normalize by 1/sqrt(BLOCK_SIZE).
1079+
const float norm = rsqrtf((float)BLOCK_SIZE);
1080+
#pragma unroll
1081+
for (int j = 0; j < ELEMS_PER_THREAD; j++)
1082+
vals[j] *= norm;
1083+
1084+
// Store back.
1085+
#pragma unroll
1086+
for (int j = 0; j < ELEMS_PER_THREAD; j++) {
1087+
int idx = block_start + lane_id + j * 32;
1088+
if (idx < n)
1089+
data[idx] = (T)vals[j];
1090+
}
1091+
}
1092+
1093+
// ---- Hadamard rotation launch wrapper ----
1094+
1095+
template <int BLOCK_SIZE, typename T>
1096+
void hadamardRotate(T* data, int n, cudaStream_t stream) {
1097+
const int num_blocks = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
1098+
const int num_cuda_blocks = (num_blocks + KBIT_WARPS_PER_BLOCK - 1) / KBIT_WARPS_PER_BLOCK;
1099+
kHadamardRotate<BLOCK_SIZE, T><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0, stream>>>(data, n);
1100+
CUDA_CHECK_RETURN(cudaPeekAtLastError());
1101+
}
1102+
1103+
// Explicit instantiations: 4 block sizes x 2 dtypes
1104+
#define INSTANTIATE_HADAMARD(BS) \
1105+
template void hadamardRotate<BS, half>(half*, int, cudaStream_t); \
1106+
template void hadamardRotate<BS, __nv_bfloat16>(__nv_bfloat16*, int, cudaStream_t);
1107+
1108+
INSTANTIATE_HADAMARD(32)
1109+
INSTANTIATE_HADAMARD(64)
1110+
INSTANTIATE_HADAMARD(128)
1111+
INSTANTIATE_HADAMARD(256)
1112+
1113+
#undef INSTANTIATE_HADAMARD
1114+
10141115
// Datacenter GPU detection: Hopper (sm_90) and Blackwell datacenter (sm_100).
10151116
// NOTE: sm_120 (RTX 5090, Blackwell consumer) lacks TMA/wgmma — must NOT match.
10161117
#if defined(__CUDA_ARCH__)

csrc/pythonInterface.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,26 @@ MAKE_KBIT_SCALAR_GEMV_V2_FP16ABS(5)
796796
// Debug MMA test
797797
void testMMA(const half*, const half*, float*);
798798

799+
// Forward declarations of hadamard rotation template
800+
template <int BLOCK_SIZE, typename T>
801+
void hadamardRotate(T* data, int n, cudaStream_t stream);
802+
803+
// Unmangled hadamard rotation wrappers (dispatch block_size at runtime)
804+
#define MAKE_HADAMARD_ROTATE(tname, T) \
805+
void hadamard_rotate_##tname(T* data, int n, int block_size, cudaStream_t stream) { \
806+
switch (block_size) { \
807+
case 32: hadamardRotate<32, T>(data, n, stream); break; \
808+
case 64: hadamardRotate<64, T>(data, n, stream); break; \
809+
case 128: hadamardRotate<128, T>(data, n, stream); break; \
810+
case 256: hadamardRotate<256, T>(data, n, stream); break; \
811+
} \
812+
}
813+
814+
MAKE_HADAMARD_ROTATE(fp16, half)
815+
MAKE_HADAMARD_ROTATE(bf16, __nv_bfloat16)
816+
817+
#undef MAKE_HADAMARD_ROTATE
818+
799819
#endif // BUILD_CUDA || BUILD_HIP (kbit unmangled)
800820

801821
extern "C" {
@@ -1664,5 +1684,14 @@ MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(3)
16641684
MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(4)
16651685
MAKE_CKBIT_SCALAR_GEMV_V2_FP16ABS(5)
16661686

1687+
// Hadamard rotation extern C wrappers
1688+
void chadamard_rotate_fp16(half* data, int n, int block_size, cudaStream_t stream) {
1689+
hadamard_rotate_fp16(data, n, block_size, stream);
1690+
}
1691+
1692+
void chadamard_rotate_bf16(__nv_bfloat16* data, int n, int block_size, cudaStream_t stream) {
1693+
hadamard_rotate_bf16(data, n, block_size, stream);
1694+
}
1695+
16671696
#endif
16681697
}

0 commit comments

Comments
 (0)