Skip to content

Commit a5b1328

Browse files
committed
Fix blocksize-32/64 4-bit quantization and GEMV on CDNA (warp size 64)
Made-with: Cursor
1 parent 83892a5 commit a5b1328

File tree

5 files changed

+79
-89
lines changed

5 files changed

+79
-89
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
99

1010
from ..._ops import register_kernel
11-
from ...cextension import ROCM_WARP_SIZE_64, lib
11+
from ...cextension import lib
1212

1313

1414
@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -212,10 +212,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
212212
A = A.contiguous()
213213
torch._check_is_size(blocksize)
214214

215-
if ROCM_WARP_SIZE_64:
216-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
217-
else:
218-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
215+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
219216

220217
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
221218

@@ -271,10 +268,7 @@ def _dequantize_blockwise_impl(
271268
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
272269
) -> None:
273270
A = A.contiguous()
274-
if ROCM_WARP_SIZE_64:
275-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
276-
else:
277-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
271+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
278272

279273
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
280274
torch._check(
@@ -306,10 +300,7 @@ def _(
306300
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
307301
) -> tuple[torch.Tensor, torch.Tensor]:
308302
A = A.contiguous()
309-
if ROCM_WARP_SIZE_64:
310-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
311-
else:
312-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
303+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
313304

314305
torch._check(quant_type in ["fp4", "nf4"])
315306
torch._check(
@@ -389,10 +380,7 @@ def _dequantize_4bit_impl(
389380
out: torch.Tensor,
390381
) -> None:
391382
A = A.contiguous()
392-
if ROCM_WARP_SIZE_64:
393-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
394-
else:
395-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
383+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
396384

397385
torch._check(quant_type in ["fp4", "nf4"])
398386
torch._check(

bitsandbytes/cextension.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
get_cuda_specs,
1515
get_cuda_version_tuple,
1616
get_rocm_gpu_arch,
17-
get_rocm_warpsize,
1817
)
1918

2019
logger = logging.getLogger(__name__)
@@ -317,7 +316,6 @@ def get_native_library() -> BNBNativeLibrary:
317316

318317

319318
ROCM_GPU_ARCH = get_rocm_gpu_arch()
320-
ROCM_WARP_SIZE_64 = True if get_rocm_warpsize() == 64 else False
321319

322320
HIP_ENVIRONMENT = False
323321
BNB_BACKEND = "CPU"

csrc/kernels.cu

Lines changed: 63 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -374,69 +374,72 @@ __global__ void kQuantizeBlockwise(
374374
}
375375
}
376376

377-
// Unified small-blocksize kernel for 4-bit quantization
378-
// Processes 2 blocks of BNB_WARP_SIZE values per thread block
379-
// On CUDA (warp=32): blocksize=32, 32 threads, WarpReduce<16>
380-
// On HIP (warp=64): blocksize=64, 64 threads, WarpReduce<32>
381-
// On HIP (warp=32): blocksize=32, 32 threads, WarpReduce<16>
382-
template <typename T, int DATA_TYPE>
377+
// Small-blocksize kernel for 4-bit quantization, parameterized on quantization
378+
// block size (QBLOCK_SIZE). Always launches exactly BNB_WARP_SIZE threads so
379+
// every lane in the wavefront is productive. Multiple quantization blocks are
380+
// packed into one wavefront when QBLOCK_SIZE < BNB_WARP_SIZE * NUM_PER_TH:
381+
//
382+
// CDNA (64), QBLOCK_SIZE=32 -> 4 quant blocks per wavefront
383+
// CDNA (64), QBLOCK_SIZE=64 -> 2 quant blocks per wavefront
384+
// CUDA/RDNA (32), QBLOCK_SIZE=32 -> 2 quant blocks per wavefront
385+
//
386+
// Uses logical-warp WarpReduce<THREADS_PER_QB> so each quantization block's
387+
// threads reduce independently via warp shuffles.
388+
template <typename T, int QBLOCK_SIZE, int DATA_TYPE>
383389
__global__ void kQuantizeBlockwiseSmall(
384390
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
385391
const int rand_offset, const int n
386392
) {
387-
constexpr int BLOCK_SIZE = BNB_WARP_SIZE; // Size of each quantization block
388-
constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing)
389-
constexpr int THREADS = BNB_WARP_SIZE; // Total threads (one full warp)
390-
constexpr int THREADS_PER_BLOCK = BNB_WARP_SIZE / 2; // Half-warp per quantization block
393+
static_assert(QBLOCK_SIZE <= BNB_WARP_SIZE * 2, "QBLOCK_SIZE too large for one warp");
391394

392-
const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per thread block
395+
constexpr int NUM_PER_TH = 2;
396+
constexpr int THREADS = BNB_WARP_SIZE;
397+
constexpr int THREADS_PER_QB = QBLOCK_SIZE / NUM_PER_TH;
398+
constexpr int NUM_QB = THREADS / THREADS_PER_QB;
399+
constexpr int TOTAL_VALUES = QBLOCK_SIZE * NUM_QB;
400+
401+
const int base_idx = blockIdx.x * TOTAL_VALUES;
393402

394403
T vals[NUM_PER_TH];
395-
unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte
404+
unsigned char qvals[NUM_PER_TH / 2];
396405
float local_abs_max = 0.0f;
397406

398-
const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31
399-
const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15)
407+
const int qb_id = threadIdx.x / THREADS_PER_QB;
408+
const int local_tid = threadIdx.x % THREADS_PER_QB;
400409

401410
typedef bnb_cub::BlockLoad<T, THREADS, NUM_PER_TH, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
402411
typedef bnb_cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
403-
typedef bnb_cub::WarpReduce<float, THREADS_PER_BLOCK>
404-
WarpReduce; // Half-warp logical reduction: each half reduces independently
412+
typedef bnb_cub::WarpReduce<float, THREADS_PER_QB> WarpReduce;
405413

406414
__shared__ typename LoadT::TempStorage loadt;
407415
__shared__ typename StoreChar::TempStorage storec;
408-
__shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp
409-
__shared__ float smem_absmax_value[2];
416+
__shared__ typename WarpReduce::TempStorage warp_reduce[NUM_QB];
417+
__shared__ float smem_absmax_value[NUM_QB];
410418

411-
const int i = base_idx + block_id * BLOCK_SIZE;
412-
// Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
413-
// operations that require ALL 32 threads to participate
414-
const bool block_valid = (i < n);
419+
const int qi = base_idx + qb_id * QBLOCK_SIZE;
420+
const bool qb_valid = (qi < n);
415421

416-
// All 32 threads participate in the load (out-of-bounds threads get 0.0f)
417422
__syncthreads();
418-
LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);
423+
LoadT(loadt).Load(&(A[base_idx]), vals, min(TOTAL_VALUES, n - base_idx), (T)0.0f);
419424

420-
// Each thread computes max of its values
421425
local_abs_max = -FLT_MAX;
422426
#pragma unroll NUM_PER_TH
423427
for (int j = 0; j < NUM_PER_TH; j++)
424428
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
425429

426-
// Reduce within each logical warp of 16 threads independently
427-
local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, BNB_MAX_OP);
430+
local_abs_max = WarpReduce(warp_reduce[qb_id]).Reduce(local_abs_max, BNB_MAX_OP);
428431

429-
if (local_thread_id == 0) {
430-
if (block_valid) {
431-
smem_absmax_value[block_id] = 1.0f / local_abs_max;
432-
absmax[blockIdx.x * 2 + block_id] = local_abs_max;
432+
if (local_tid == 0) {
433+
if (qb_valid) {
434+
smem_absmax_value[qb_id] = 1.0f / local_abs_max;
435+
absmax[blockIdx.x * NUM_QB + qb_id] = local_abs_max;
433436
} else {
434-
smem_absmax_value[block_id] = 0.0f;
437+
smem_absmax_value[qb_id] = 0.0f;
435438
}
436439
}
437440
__syncthreads();
438441

439-
local_abs_max = smem_absmax_value[block_id];
442+
local_abs_max = smem_absmax_value[qb_id];
440443

441444
switch (DATA_TYPE) {
442445
case FP4:
@@ -455,9 +458,8 @@ __global__ void kQuantizeBlockwiseSmall(
455458
break;
456459
}
457460

458-
// All 32 threads participate in the store (valid_items limits the actual writes)
459461
__syncthreads();
460-
StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
462+
StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((TOTAL_VALUES + 1) / 2, (n - base_idx + 1) / 2));
461463
}
462464

463465
template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
@@ -1446,15 +1448,15 @@ __global__ void kgemm_4bit_inference_naive(
14461448
) {
14471449

14481450
// per threadblock:
1449-
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
1450-
// 4 warps -> 4 loads per iter
1451-
// 1x32 * 32x4 -> 1x4 outputs per thread block
1451+
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
1452+
// THREADS/BNB_WARP_SIZE warps -> that many loads per iter
1453+
// 1xwarp_size * warp_size x warps -> 1 x warps outputs per thread block
14521454
typedef bnb_cub::WarpReduce<float> WarpReduce;
1453-
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32];
1455+
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS / BNB_WARP_SIZE];
14541456

1455-
const int warp_idx = threadIdx.x / 32;
1456-
const int warp_lane = threadIdx.x % 32;
1457-
const int row_B = (THREADS / 32) * blockIdx.x + warp_idx;
1457+
const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
1458+
const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
1459+
const int row_B = (THREADS / BNB_WARP_SIZE) * blockIdx.x + warp_idx;
14581460
const int offset_B = ldb * row_B;
14591461
const int num_values_8bit = num_values_4bit / 2;
14601462
float local_C = 0.0f;
@@ -1473,7 +1475,7 @@ __global__ void kgemm_4bit_inference_naive(
14731475

14741476
// A: [1, K]
14751477
// B: [N, K]
1476-
for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
1478+
for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE * num_values_4bit) {
14771479
const int inner_idx_halved = inner_idx / 2;
14781480

14791481
// Since blocksize will always be a power-of-2, we avoid more expensive
@@ -1766,22 +1768,28 @@ MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4)
17661768
MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, NF4)
17671769
MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, NF4)
17681770

1769-
// Template instantiations for blocksize=32 specialized kernel (4-bit only)
1770-
#define MAKE_kQuantizeBlockwiseSmall(dtype, data_type_name) \
1771-
template __global__ void kQuantizeBlockwiseSmall<dtype, data_type_name>( \
1771+
// Template instantiations for kQuantizeBlockwiseSmall (4-bit only)
1772+
#define MAKE_kQuantizeBlockwiseSmall(dtype, qblock_size, data_type_name) \
1773+
template __global__ void kQuantizeBlockwiseSmall<dtype, qblock_size, data_type_name>( \
17721774
float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \
17731775
const int rand_offset, const int n \
17741776
);
17751777

1776-
// FP4 instantiations for blocksize=32
1777-
MAKE_kQuantizeBlockwiseSmall(half, FP4) MAKE_kQuantizeBlockwiseSmall(float, FP4) MAKE_kQuantizeBlockwiseSmall(
1778-
bnb_bfloat16, FP4
1779-
)
1780-
1781-
// NF4 instantiations for blocksize=32
1782-
MAKE_kQuantizeBlockwiseSmall(half, NF4) MAKE_kQuantizeBlockwiseSmall(float, NF4) MAKE_kQuantizeBlockwiseSmall(
1783-
bnb_bfloat16, NF4
1784-
)
1778+
// QBLOCK_SIZE=32 instantiations
1779+
MAKE_kQuantizeBlockwiseSmall(half, 32, FP4)
1780+
MAKE_kQuantizeBlockwiseSmall(float, 32, FP4)
1781+
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, FP4)
1782+
MAKE_kQuantizeBlockwiseSmall(half, 32, NF4)
1783+
MAKE_kQuantizeBlockwiseSmall(float, 32, NF4)
1784+
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, NF4)
1785+
1786+
// QBLOCK_SIZE=64 instantiations (used on HIP for blocksize=64)
1787+
MAKE_kQuantizeBlockwiseSmall(half, 64, FP4)
1788+
MAKE_kQuantizeBlockwiseSmall(float, 64, FP4)
1789+
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, FP4)
1790+
MAKE_kQuantizeBlockwiseSmall(half, 64, NF4)
1791+
MAKE_kQuantizeBlockwiseSmall(float, 64, NF4)
1792+
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, NF4)
17851793

17861794
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
17871795
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n

csrc/kernels.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ __global__ void kQuantizeBlockwise(
1414
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
1515
const int rand_offset, const int n
1616
);
17-
template <typename T, int DATA_TYPE>
17+
template <typename T, int QBLOCK_SIZE, int DATA_TYPE>
1818
__global__ void kQuantizeBlockwiseSmall(
1919
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
2020
const int rand_offset, const int n

csrc/ops.cu

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,24 @@ void quantizeBlockwise(
5353
else if (blocksize == 64) {
5454
#if BNB_HIP
5555
if constexpr (DATA_TYPE > 0) {
56-
if (bnb_host_warp_size() == 64) {
57-
// CDNA: kQuantizeBlockwiseSmall is compiled with THREADS=64
58-
kQuantizeBlockwiseSmall<T, DATA_TYPE>
59-
<<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n);
60-
} else {
61-
// RDNA: standard kernel (same as CUDA path)
62-
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>
63-
<<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
64-
}
56+
const int ws = bnb_host_warp_size();
57+
const int num_qb = ws / (64 / 2);
58+
int grid = (num_blocks + num_qb - 1) / num_qb;
59+
kQuantizeBlockwiseSmall<T, 64, DATA_TYPE>
60+
<<<grid, ws>>>(code, A, absmax, out, rand, rand_offset, n);
6561
} else {
6662
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
6763
}
6864
#else
6965
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
7066
#endif
7167
} else if (blocksize == 32) {
72-
// For 4-bit: use specialized kernel that processes 2 blocks per warp
73-
// Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2
7468
if constexpr (DATA_TYPE > 0) {
75-
int num_blocks_adjusted = (num_blocks + 1) / 2;
76-
kQuantizeBlockwiseSmall<T, DATA_TYPE>
77-
<<<num_blocks_adjusted, 32>>>(code, A, absmax, out, rand, rand_offset, n);
69+
const int ws = bnb_host_warp_size();
70+
const int num_qb = ws / (32 / 2);
71+
int grid = (num_blocks + num_qb - 1) / num_qb;
72+
kQuantizeBlockwiseSmall<T, 32, DATA_TYPE>
73+
<<<grid, ws>>>(code, A, absmax, out, rand, rand_offset, n);
7874
}
7975
}
8076

0 commit comments

Comments
 (0)