Skip to content

Commit 7b6c76f

Browse files
Add CUDA kernel support for 4-bit quantization with blocksize=32 (#1854)
* kQuantizeBlockwise32 kernel addition * fix * adding tests * fix * docstring update * independent logical warp * update values
1 parent 17d32f1 commit 7b6c76f

File tree

8 files changed

+138
-18
lines changed

8 files changed

+138
-18
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
214214
if ROCM_WARP_SIZE_64:
215215
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
216216
else:
217-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
217+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
218218

219219
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
220220

@@ -272,7 +272,7 @@ def _dequantize_blockwise_impl(
272272
if ROCM_WARP_SIZE_64:
273273
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
274274
else:
275-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
275+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
276276

277277
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
278278
torch._check(
@@ -306,7 +306,7 @@ def _(
306306
if ROCM_WARP_SIZE_64:
307307
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
308308
else:
309-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
309+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
310310

311311
torch._check(quant_type in ["fp4", "nf4"])
312312
torch._check(
@@ -388,7 +388,7 @@ def _dequantize_4bit_impl(
388388
if ROCM_WARP_SIZE_64:
389389
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
390390
else:
391-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
391+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
392392

393393
torch._check(quant_type in ["fp4", "nf4"])
394394
torch._check(

bitsandbytes/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ def quantize_4bit(
842842
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
843843
blocksize (`int`, *optional*):
844844
The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
845-
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
845+
Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
846846
compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
847847
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
848848
quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`.
@@ -953,7 +953,7 @@ def dequantize_4bit(
953953
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
954954
blocksize (`int`, *optional*):
955955
The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
956-
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
956+
Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
957957
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
958958
959959
Raises:

csrc/kernels.cu

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,90 @@ __global__ void kQuantizeBlockwise(
423423
}
424424
}
425425

426+
// Specialized kernel for blocksize=32 with 4-bit quantization
427+
// Processes 2 blocks of 32 values per warp to maintain full thread utilization
428+
// Uses 32 threads total: threads 0-15 handle block 0, threads 16-31 handle block 1
429+
template <typename T, int DATA_TYPE>
430+
__global__ void kQuantizeBlockwise32(
431+
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
432+
const int rand_offset, const int n
433+
) {
434+
constexpr int BLOCK_SIZE = 32; // Size of each quantization block
435+
constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing)
436+
constexpr int THREADS = 32; // Total threads (full warp)
437+
constexpr int THREADS_PER_BLOCK = 16; // Threads handling each quantization block
438+
439+
const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per CUDA block
440+
441+
T vals[NUM_PER_TH];
442+
unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte
443+
float local_abs_max = 0.0f;
444+
445+
const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31
446+
const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15)
447+
448+
typedef cub::BlockLoad<T, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
449+
typedef cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
450+
typedef cub::WarpReduce<float, 16>
451+
WarpReduce; // Logical warp size of 16: threads 0-15 and 16-31 reduce independently
452+
453+
__shared__ typename LoadT::TempStorage loadt;
454+
__shared__ typename StoreChar::TempStorage storec;
455+
__shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp
456+
__shared__ float smem_absmax_value[2];
457+
458+
const int i = base_idx + block_id * BLOCK_SIZE;
459+
// Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
460+
// operations that require ALL 32 threads to participate
461+
const bool block_valid = (i < n);
462+
463+
// All 32 threads participate in the load (out-of-bounds threads get 0.0f)
464+
__syncthreads();
465+
LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);
466+
467+
// Each thread computes max of its values
468+
local_abs_max = -FLT_MAX;
469+
#pragma unroll NUM_PER_TH
470+
for (int j = 0; j < NUM_PER_TH; j++)
471+
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
472+
473+
// Reduce within each logical warp of 16 threads independently
474+
local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX);
475+
476+
if (local_thread_id == 0) {
477+
if (block_valid) {
478+
smem_absmax_value[block_id] = 1.0f / local_abs_max;
479+
absmax[blockIdx.x * 2 + block_id] = local_abs_max;
480+
} else {
481+
smem_absmax_value[block_id] = 0.0f;
482+
}
483+
}
484+
__syncthreads();
485+
486+
local_abs_max = smem_absmax_value[block_id];
487+
488+
switch (DATA_TYPE) {
489+
case FP4:
490+
#pragma unroll NUM_PER_TH
491+
for (int j = 0; j < NUM_PER_TH / 2; j++) {
492+
qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
493+
qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
494+
}
495+
break;
496+
case NF4:
497+
#pragma unroll NUM_PER_TH
498+
for (int j = 0; j < NUM_PER_TH / 2; j++) {
499+
qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
500+
qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
501+
}
502+
break;
503+
}
504+
505+
// All 32 threads participate in the store (valid_items limits the actual writes)
506+
__syncthreads();
507+
StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
508+
}
509+
426510
template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
427511
__global__ void
428512
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) {
@@ -2440,9 +2524,24 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
24402524
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
24412525
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)
24422526

2443-
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
2444-
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
2445-
);
2527+
// Template instantiations for blocksize=32 specialized kernel (4-bit only)
2528+
#define MAKE_kQuantizeBlockwise32(dtype, data_type_name) \
2529+
template __global__ void kQuantizeBlockwise32<dtype, data_type_name>( \
2530+
float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \
2531+
const int rand_offset, const int n \
2532+
);
2533+
2534+
// FP4 instantiations for blocksize=32
2535+
MAKE_kQuantizeBlockwise32(half, FP4) MAKE_kQuantizeBlockwise32(float, FP4) MAKE_kQuantizeBlockwise32(__nv_bfloat16, FP4)
2536+
2537+
// NF4 instantiations for blocksize=32
2538+
MAKE_kQuantizeBlockwise32(half, NF4) MAKE_kQuantizeBlockwise32(float, NF4) MAKE_kQuantizeBlockwise32(
2539+
__nv_bfloat16, NF4
2540+
)
2541+
2542+
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
2543+
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
2544+
);
24462545
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(
24472546
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
24482547
);

csrc/kernels.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ __global__ void kQuantizeBlockwise(
1717
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
1818
const int rand_offset, const int n
1919
);
20+
template <typename T, int DATA_TYPE>
21+
__global__ void kQuantizeBlockwise32(
22+
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
23+
const int rand_offset, const int n
24+
);
2025
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
2126
__global__ void
2227
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);

csrc/ops.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ void quantizeBlockwise(
5050
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
5151
else if (blocksize == 64)
5252
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
53+
else if (blocksize == 32) {
54+
// For 4-bit: use specialized kernel (kQuantizeBlockwise32) that processes 2 blocks per warp
55+
// Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2
56+
if (DATA_TYPE > 0) {
57+
int num_blocks_adjusted = (num_blocks + 1) / 2;
58+
kQuantizeBlockwise32<T, DATA_TYPE><<<num_blocks_adjusted, 32>>>(code, A, absmax, out, rand, rand_offset, n);
59+
}
60+
}
5361

5462
CUDA_CHECK_RETURN(cudaPeekAtLastError());
5563
}

tests/test_functional.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ class TestQuantize4BitFunctional:
10981098
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
10991099
@pytest.mark.parametrize(
11001100
"blocksize",
1101-
[64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
1101+
[32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
11021102
)
11031103
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11041104
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1122,6 +1122,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11221122
error_dict["fp4"] = dict()
11231123
error_dict["nf4"] = dict()
11241124
error_dict["fp4"]["err"] = {
1125+
32: 0.088918,
11251126
64: 0.096545,
11261127
128: 0.102947,
11271128
256: 0.108685,
@@ -1131,6 +1132,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11311132
4096: 0.129573,
11321133
}
11331134
error_dict["fp4"]["rel_err"] = {
1135+
32: 0.242380,
11341136
64: 0.260130,
11351137
128: 0.275734,
11361138
256: 0.289842,
@@ -1141,6 +1143,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11411143
}
11421144

11431145
error_dict["nf4"]["err"] = {
1146+
32: 0.067745,
11441147
64: 0.072792,
11451148
128: 0.076835,
11461149
256: 0.080326,
@@ -1150,6 +1153,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11501153
4096: 0.092537,
11511154
}
11521155
error_dict["nf4"]["rel_err"] = {
1156+
32: 0.189700,
11531157
64: 0.203299,
11541158
128: 0.215252,
11551159
256: 0.226044,
@@ -1168,7 +1172,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11681172

11691173
@pytest.mark.parametrize("device", get_available_devices())
11701174
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
1171-
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
1175+
@pytest.mark.parametrize(
1176+
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
1177+
)
11721178
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
11731179
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
11741180
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1205,7 +1211,9 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
12051211
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
12061212
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
12071213
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
1208-
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
1214+
@pytest.mark.parametrize(
1215+
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
1216+
)
12091217
def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
12101218
"""
12111219
Test that we can successfully quantize a large tensor. Note that the following limitations apply:

tests/test_linear4bit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_linear_serialization(
193193

194194
@pytest.mark.parametrize("device", get_available_devices())
195195
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
196-
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
196+
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
197197
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
198198
def test_copy_param(device, quant_type, blocksize, compress_statistics):
199199
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -250,7 +250,7 @@ def test_params4bit_torch_chunk_split(device, quant_type):
250250

251251
@pytest.mark.parametrize("device", get_available_devices())
252252
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
253-
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
253+
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
254254
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
255255
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
256256
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -279,7 +279,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
279279

280280
@pytest.mark.parametrize("device", get_available_devices())
281281
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
282-
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
282+
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
283283
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
284284
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
285285
if device == "hpu" and not is_supported_on_hpu(quant_type):

tests/test_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class Test4bitBlockwiseQuantOps:
152152
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
153153
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
154154
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
155-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
155+
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
156156
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
157157
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
158158
pytest.skip("This configuration is not supported on HPU.")
@@ -176,7 +176,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
176176
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
177177
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
178178
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
179-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
179+
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
180180
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
181181
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
182182
pytest.skip("This configuration is not supported on HPU.")
@@ -210,7 +210,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
210210
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
211211
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
212212
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
213-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
213+
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
214214
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
215215
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
216216
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):

0 commit comments

Comments
 (0)