Skip to content

Commit 6a872e3

Browse files
TimDettmers and claude committed
feat: Add full-dimension Hadamard rotation kernel
Add kHadamardRotateFull kernel that rotates across the entire last dimension (512-8192), matching the approach used by QuIP#, QuaRot, and SpinQuant for maximal outlier suppression. The existing block-diagonal kernel (block_size 32-256) remains for use cases where block rotation suffices. Kernel design: one thread block per row with 3-4 butterfly levels: 1. In-thread butterfly (strides 1, 2, 4) 2. Warp shuffle butterfly (strides 8-128) 3. Cross-warp butterfly via shared memory (strides 256+) 4. Cross-chunk butterfly in registers (dims > 2048) API: hadamard_rotate(data, block_size=0) for full-dimension mode. Signs vector has dim//32 words (one per full row, not per block). 144 tests passing (79 existing + 65 new full-dimension tests). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 72ee3e8 commit 6a872e3

6 files changed

Lines changed: 486 additions & 8 deletions

File tree

bitsandbytes/_ops.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,43 @@ def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> tor
614614
return data
615615

616616

617+
# Full-dimension Hadamard rotation (in-place, for kbit quantization outlier spreading)
618+
# Unlike hadamard_rotate_ which uses block-diagonal Hadamard, this rotates across
619+
# the entire last dimension of the input tensor.
620+
621+
torch.library.define(
622+
"bitsandbytes::hadamard_rotate_full_",
623+
"(Tensor(a!) data, int dim, Tensor? signs) -> Tensor(a!)",
624+
)
625+
626+
627+
@register_fake("bitsandbytes::hadamard_rotate_full_")
628+
def _(data: torch.Tensor, dim: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
629+
supported_dims = (512, 1024, 2048, 4096, 8192)
630+
torch._check(
631+
dim in supported_dims,
632+
lambda: f"dim must be one of {supported_dims}, got {dim}",
633+
)
634+
torch._check(
635+
data.numel() % dim == 0,
636+
lambda: f"data.numel() ({data.numel()}) must be divisible by dim ({dim})",
637+
)
638+
torch._check(
639+
data.dtype in (torch.float16, torch.bfloat16),
640+
lambda: f"hadamard_rotate_full only supports float16/bfloat16, got {data.dtype}",
641+
)
642+
if signs is not None:
643+
torch._check(
644+
signs.dtype == torch.int32,
645+
lambda: f"signs must be int32, got {signs.dtype}",
646+
)
647+
torch._check(
648+
signs.numel() == dim // 32,
649+
lambda: f"signs must have {dim // 32} elements for dim={dim}, got {signs.numel()}",
650+
)
651+
return data
652+
653+
617654
# K-bit fused dequant + GEMM (production: fp16 + bf16)
618655

619656
torch.library.define(

bitsandbytes/backends/cuda/ops.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,34 @@ def _(data: torch.Tensor, block_size: int, signs: Optional[torch.Tensor]) -> tor
10261026
return data
10271027

10281028

1029+
@register_kernel("bitsandbytes::hadamard_rotate_full_", "cuda")
def _(data: torch.Tensor, dim: int, signs: Optional[torch.Tensor]) -> torch.Tensor:
    """CUDA implementation: in-place full-dimension Hadamard rotation.

    ``data`` is treated as ``(num_rows, dim)`` row-major; one kernel block
    rotates each row. Returns ``data``, rotated in-place.
    """
    supported_dims = (512, 1024, 2048, 4096, 8192)
    torch._check(
        dim in supported_dims,
        lambda: f"dim must be one of {supported_dims}, got {dim}",
    )
    torch._check(
        data.dtype in (torch.float16, torch.bfloat16),
        lambda: f"hadamard_rotate_full only supports float16/bfloat16, got {data.dtype}",
    )
    # Keep validation in sync with the fake registration in _ops.py: without
    # these checks a ragged `data` or undersized `signs` would reach the
    # kernel and read/write out of bounds.
    torch._check(
        data.numel() % dim == 0,
        lambda: f"data.numel() ({data.numel()}) must be divisible by dim ({dim})",
    )
    if signs is not None:
        torch._check(
            signs.dtype == torch.int32,
            lambda: f"signs must be int32, got {signs.dtype}",
        )
        torch._check(
            signs.numel() == dim // 32,
            lambda: f"signs must have {dim // 32} elements for dim={dim}, got {signs.numel()}",
        )

    num_rows = data.numel() // dim
    tname = _KBIT_DTYPE_SUFFIX[data.dtype]
    signs_ptr = get_ptr(signs) if signs is not None else None
    with _cuda_device_of(data):
        # Dispatch to chadamard_rotate_full_{fp16,bf16}; the C wrapper
        # selects the kernel instantiation from `dim` at runtime.
        fn = getattr(lib, f"chadamard_rotate_full_{tname}")
        fn(
            get_ptr(data),
            ct.c_int(num_rows),
            ct.c_int(dim),
            signs_ptr,
            _get_tensor_stream(data),
        )

    return data
1055+
1056+
10291057
def _kbit_gemm_prod_check(A, B_packed, B_absmax, codebook, N, k, k_chunks):
10301058
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
10311059
torch._check(

bitsandbytes/functional.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
def hadamard_rotate(
    data: Tensor,
    block_size: int = 32,
    signs: Optional[Tensor] = None,
) -> Tensor:
    """Apply in-place randomized Walsh-Hadamard rotation (H*D).

    Spreads outliers across quantization blocks, improving kbit accuracy.
    Since H*D is orthogonal, rotating both weights and activations with the
    same signs preserves the GEMM result: (H*D)(A) @ (H*D)(B)^T = A @ B^T.

    Two modes:

    **Block-diagonal** (block_size in {32, 64, 128, 256}): Applies independent
    Hadamard rotations to contiguous blocks of ``block_size`` elements across
    the flattened tensor. Fast and parallel, but only spreads outliers within
    each block.

    **Full-dimension** (block_size=0): Applies the Hadamard rotation across
    the entire last dimension of the tensor. Matches the approach used by
    QuIP#, QuaRot, and SpinQuant for maximal outlier suppression. The last
    dimension must be a power of 2 in {512, 1024, 2048, 4096, 8192}.

    Args:
        data: Input tensor (float16 or bfloat16). Modified in-place.
        block_size: Rotation block size (32, 64, 128, 256) for block-diagonal
            mode, or 0 for full-dimension mode.
        signs: Optional int32 tensor of sign-flip bits. For block-diagonal
            mode: ``block_size // 32`` words (repeated per block). For
            full-dimension mode: ``dim // 32`` words where ``dim`` is the
            last dimension. Each bit controls the sign flip for one element.
            If None, no sign flips (plain Hadamard). Generate once per model
            with ``torch.randint(0, 2**32, (n_words,), dtype=torch.int32)``.

    Returns:
        The input tensor, rotated in-place.
    """
    # Both kernels operate on a flat, contiguous buffer.
    data_flat = data.contiguous().view(-1)
    if block_size == 0:
        # Full-dimension mode: rotate across the entire last dimension.
        torch.ops.bitsandbytes.hadamard_rotate_full_(data_flat, data.shape[-1], signs)
    else:
        # Block-diagonal mode: independent rotations per block.
        torch.ops.bitsandbytes.hadamard_rotate_(data_flat, block_size, signs)
    # If `data` was non-contiguous, .contiguous() returned a *copy* and the
    # rotation above touched only that copy. Write the result back so the
    # documented in-place contract holds for the caller's tensor.
    if data_flat.data_ptr() != data.data_ptr():
        data.copy_(data_flat.view(data.shape))
    return data
11631185

11641186

csrc/ops.cu

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,200 @@ INSTANTIATE_HADAMARD(256)
11271127

11281128
#undef INSTANTIATE_HADAMARD
11291129

1130+
// ===========================================================================
// Full-dimension Hadamard rotation kernel.
// One thread block processes one row of DIM elements using 3-4 butterfly levels:
//   1. In-thread butterfly (strides 1..kNElts/2)
//   2. Warp shuffle butterfly (strides kNElts..kNElts*16)
//   3. Cross-warp butterfly via shared memory (strides across warps)
//   4. Cross-chunk butterfly in registers (when kNChunks > 1)
//
// Grid: (num_rows,). Signs: DIM/32 uint32 words (one per full row, not per block).
// ===========================================================================

template <int kLogDim, int kNThreads, typename T>
__global__ void kHadamardRotateFull(T* __restrict__ data, const int num_rows, const unsigned int* __restrict__ signs) {
    constexpr int DIM = 1 << kLogDim;
    constexpr int kNElts = 8; // elements per thread per chunk
    constexpr int kNChunks = DIM / (kNThreads * kNElts);
    constexpr int kNWarps = kNThreads / 32;

    // The four butterfly levels together must cover every stride 1..DIM/2:
    // the decomposition DIM = kNThreads * kNElts * kNChunks guarantees this.
    static_assert(DIM == kNThreads * kNElts * kNChunks, "dimension decomposition mismatch");
    static_assert(kNElts == 8, "kNElts must be 8");
    static_assert((kNThreads & (kNThreads - 1)) == 0, "kNThreads must be power of 2");

    const int row = blockIdx.x;
    if (row >= num_rows)
        return;

    // Each block owns one contiguous row of DIM elements.
    T* row_data = data + (long long)row * DIM;

    // Shared memory for cross-warp butterfly (only needed when kNWarps > 1).
    // Use char[] to match other kernels in this TU, then cast to float*.
    extern __shared__ char smem_raw[];
    float* smem = reinterpret_cast<float*>(smem_raw);

    const int tid = threadIdx.x;
    const int warp_id = tid / 32;
    const int lane_id = tid % 32;

    // ---- Load elements (contiguous per thread) ----
    // Thread `tid` owns elements [base, base + kNElts) of each chunk, i.e.
    // vals[c][i] holds row element c*kNThreads*kNElts + tid*kNElts + i.
    // Accumulation is done in float for accuracy, converted back on store.
    float vals[kNChunks][kNElts];
#pragma unroll
    for (int c = 0; c < kNChunks; c++) {
        const int base = c * kNThreads * kNElts + tid * kNElts;
#pragma unroll
        for (int i = 0; i < kNElts; i++) {
            vals[c][i] = (float)row_data[base + i];
        }
    }

    // ---- Apply sign flips (D matrix) before butterfly ----
    // 8 contiguous elements at position 'base' always fit within one uint32 word
    // since base is always a multiple of 8: element index = linear*8 + i, so
    // word = linear/4 and the byte offset inside it is (linear%4)*8.
    if (signs != nullptr) {
#pragma unroll
        for (int c = 0; c < kNChunks; c++) {
            const int linear = c * kNThreads + tid; // which group of 8
            const int word_idx = linear / 4;
            const int byte_pos = (linear % 4) * 8;
            const unsigned int byte_bits = (signs[word_idx] >> byte_pos) & 0xFFu;
#pragma unroll
            for (int i = 0; i < kNElts; i++) {
                if (byte_bits & (1u << i))
                    vals[c][i] = -vals[c][i];
            }
        }
    }

    // ---- Level 1: In-thread butterfly (strides 1, 2, 4) ----
    // Classic Hadamard pair update (a, b) -> (a + b, a - b), entirely in
    // registers; `partner > i` ensures each pair is processed exactly once.
#pragma unroll
    for (int c = 0; c < kNChunks; c++) {
#pragma unroll
        for (int s = 1; s < kNElts; s <<= 1) {
#pragma unroll
            for (int i = 0; i < kNElts; i++) {
                int partner = i ^ s;
                if (partner > i) {
                    float a = vals[c][i], b = vals[c][partner];
                    vals[c][i] = a + b;
                    vals[c][partner] = a - b;
                }
            }
        }
    }

    // ---- Level 2: Warp shuffle butterfly (shfl_xor s=1..16) ----
    // Lane xor s pairs threads whose elements sit kNElts*s apart, covering
    // strides 8..128. The upper lane of each pair (lane_id & s set) holds the
    // 'b' side and so computes (partner - mine) = a - b.
#pragma unroll
    for (int s = 1; s <= 16; s <<= 1) {
#pragma unroll
        for (int c = 0; c < kNChunks; c++) {
#pragma unroll
            for (int i = 0; i < kNElts; i++) {
                float other = __shfl_xor_sync(0xFFFFFFFF, vals[c][i], s);
                vals[c][i] = (lane_id & s) ? (other - vals[c][i]) : (vals[c][i] + other);
            }
        }
    }

    // ---- Level 3: Cross-warp butterfly via shared memory ----
    // Warp xor ws pairs threads 32*ws apart, i.e. element strides 256*ws.
    if constexpr (kNWarps > 1) {
        constexpr int VALS_PER_THREAD = kNChunks * kNElts;
        // smem layout: smem[tid * VALS_PER_THREAD + c * kNElts + i]
#pragma unroll
        for (int ws = 1; ws < kNWarps; ws <<= 1) {
            // Write my values to shared memory
#pragma unroll
            for (int c = 0; c < kNChunks; c++) {
#pragma unroll
                for (int i = 0; i < kNElts; i++) {
                    smem[tid * VALS_PER_THREAD + c * kNElts + i] = vals[c][i];
                }
            }
            __syncthreads();

            // Read partner warp's values; the upper warp of each pair
            // (warp_id & ws set) computes (partner - mine) = a - b.
            const int partner_tid = (warp_id ^ ws) * 32 + lane_id;
            const bool negate = (warp_id & ws) != 0;
#pragma unroll
            for (int c = 0; c < kNChunks; c++) {
#pragma unroll
                for (int i = 0; i < kNElts; i++) {
                    float pval = smem[partner_tid * VALS_PER_THREAD + c * kNElts + i];
                    vals[c][i] = negate ? (pval - vals[c][i]) : (vals[c][i] + pval);
                }
            }
            // Second barrier: smem is rewritten next iteration, so all reads
            // of this round must finish before any thread writes again.
            __syncthreads();
        }
    }

    // ---- Level 4: Cross-chunk butterfly (in-register, no communication) ----
    // Chunks of one thread sit kNThreads*kNElts elements apart, covering the
    // largest strides (2048+) without any inter-thread traffic.
    if constexpr (kNChunks > 1) {
#pragma unroll
        for (int cs = 1; cs < kNChunks; cs <<= 1) {
#pragma unroll
            for (int c = 0; c < kNChunks; c++) {
                int pc = c ^ cs;
                if (pc > c) {
#pragma unroll
                    for (int i = 0; i < kNElts; i++) {
                        float a = vals[c][i], b = vals[pc][i];
                        vals[c][i] = a + b;
                        vals[pc][i] = a - b;
                    }
                }
            }
        }
    }

    // ---- Normalize by 1/sqrt(DIM) ----
    // Makes the transform orthonormal so weight/activation rotations cancel.
    const float norm = rsqrtf((float)DIM);
#pragma unroll
    for (int c = 0; c < kNChunks; c++) {
#pragma unroll
        for (int i = 0; i < kNElts; i++)
            vals[c][i] *= norm;
    }

    // ---- Store back ----
#pragma unroll
    for (int c = 0; c < kNChunks; c++) {
        const int base = c * kNThreads * kNElts + tid * kNElts;
#pragma unroll
        for (int i = 0; i < kNElts; i++) {
            row_data[base + i] = (T)vals[c][i];
        }
    }
}
1295+
1296+
// ---- Full-dimension Hadamard launch wrapper ----
// kLogDim must match the dimension. kNThreads is the thread block size.

template <int kLogDim, int kNThreads, typename T>
void hadamardRotateFull(T* data, int num_rows, const unsigned int* signs, cudaStream_t stream) {
    // One thread block per row; the dynamic shared-memory allocation holds a
    // full row as floats (DIM = kNThreads * kNElts * kNChunks, asserted in
    // the kernel), used by the cross-warp butterfly level.
    constexpr int DIM = 1 << kLogDim;
    constexpr int smem_bytes = DIM * sizeof(float);

    kHadamardRotateFull<kLogDim, kNThreads, T><<<num_rows, kNThreads, smem_bytes, stream>>>(data, num_rows, signs);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
1308+
1309+
// Explicit instantiations: dim 512..8192, 2 dtypes.
// The (LOG_DIM, NTHREADS) pairs must match the runtime dispatch table in
// pythonInterface.cpp (MAKE_HADAMARD_ROTATE_FULL), which selects among
// exactly these instantiations by `dim`.
#define INSTANTIATE_HADAMARD_FULL(LOG_DIM, NTHREADS)                                                          \
    template void hadamardRotateFull<LOG_DIM, NTHREADS, half>(half*, int, const unsigned int*, cudaStream_t); \
    template void hadamardRotateFull<LOG_DIM, NTHREADS, __nv_bfloat16>(                                       \
        __nv_bfloat16*, int, const unsigned int*, cudaStream_t                                                \
    );

INSTANTIATE_HADAMARD_FULL(9, 64)   // dim=512
INSTANTIATE_HADAMARD_FULL(10, 128) // dim=1024
INSTANTIATE_HADAMARD_FULL(11, 256) // dim=2048
INSTANTIATE_HADAMARD_FULL(12, 256) // dim=4096
INSTANTIATE_HADAMARD_FULL(13, 256) // dim=8192

#undef INSTANTIATE_HADAMARD_FULL
1323+
11301324
// Datacenter GPU detection: Hopper (sm_90) and Blackwell datacenter (sm_100).
11311325
// NOTE: sm_120 (RTX 5090, Blackwell consumer) lacks TMA/wgmma — must NOT match.
11321326
#if defined(__CUDA_ARCH__)

csrc/pythonInterface.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,39 @@ MAKE_HADAMARD_ROTATE(bf16, __nv_bfloat16)
827827

828828
#undef MAKE_HADAMARD_ROTATE
829829

830+
// Forward declarations of full-dimension hadamard rotation template
831+
template <int kLogDim, int kNThreads, typename T>
832+
void hadamardRotateFull(T* data, int num_rows, const unsigned int* signs, cudaStream_t stream);
833+
834+
// Unmangled full-dimension hadamard rotation wrappers (dispatch dim at runtime)
835+
#define MAKE_HADAMARD_ROTATE_FULL(tname, T) \
836+
void hadamard_rotate_full_##tname( \
837+
T* data, int num_rows, int dim, const unsigned int* signs, cudaStream_t stream \
838+
) { \
839+
switch (dim) { \
840+
case 512: \
841+
hadamardRotateFull<9, 64, T>(data, num_rows, signs, stream); \
842+
break; \
843+
case 1024: \
844+
hadamardRotateFull<10, 128, T>(data, num_rows, signs, stream); \
845+
break; \
846+
case 2048: \
847+
hadamardRotateFull<11, 256, T>(data, num_rows, signs, stream); \
848+
break; \
849+
case 4096: \
850+
hadamardRotateFull<12, 256, T>(data, num_rows, signs, stream); \
851+
break; \
852+
case 8192: \
853+
hadamardRotateFull<13, 256, T>(data, num_rows, signs, stream); \
854+
break; \
855+
} \
856+
}
857+
858+
MAKE_HADAMARD_ROTATE_FULL(fp16, half)
859+
MAKE_HADAMARD_ROTATE_FULL(bf16, __nv_bfloat16)
860+
861+
#undef MAKE_HADAMARD_ROTATE_FULL
862+
830863
#endif // BUILD_CUDA || BUILD_HIP (kbit unmangled)
831864

832865
extern "C" {
@@ -1707,5 +1740,16 @@ void chadamard_rotate_bf16(__nv_bfloat16* data, int n, int block_size, const uns
17071740
hadamard_rotate_bf16(data, n, block_size, signs, stream);
17081741
}
17091742

1743+
// Full-dimension Hadamard rotation extern C wrappers
1744+
void chadamard_rotate_full_fp16(half* data, int num_rows, int dim, const unsigned int* signs, cudaStream_t stream) {
1745+
hadamard_rotate_full_fp16(data, num_rows, dim, signs, stream);
1746+
}
1747+
1748+
void chadamard_rotate_full_bf16(
1749+
__nv_bfloat16* data, int num_rows, int dim, const unsigned int* signs, cudaStream_t stream
1750+
) {
1751+
hadamard_rotate_full_bf16(data, num_rows, dim, signs, stream);
1752+
}
1753+
17101754
#endif
17111755
}

0 commit comments

Comments
 (0)