Skip to content

Commit 16eb28e

Browse files
kQuantizeBlockwise64 addition for rocm (#1856)
1 parent 7b6c76f commit 16eb28e

File tree

7 files changed

+127
-15
lines changed

7 files changed

+127
-15
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
212212
torch._check_is_size(blocksize)
213213

214214
if ROCM_WARP_SIZE_64:
215-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
215+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
216216
else:
217217
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
218218

@@ -270,7 +270,7 @@ def _dequantize_blockwise_impl(
270270
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
271271
) -> None:
272272
if ROCM_WARP_SIZE_64:
273-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
273+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
274274
else:
275275
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
276276

@@ -304,7 +304,7 @@ def _(
304304
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
305305
) -> tuple[torch.Tensor, torch.Tensor]:
306306
if ROCM_WARP_SIZE_64:
307-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
307+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
308308
else:
309309
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
310310

@@ -386,7 +386,7 @@ def _dequantize_4bit_impl(
386386
out: torch.Tensor,
387387
) -> None:
388388
if ROCM_WARP_SIZE_64:
389-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
389+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
390390
else:
391391
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
392392

csrc/kernels.hip

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,94 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
443443
}
444444
}
445445

446+
// Specialized kernel for blocksize=64 with 4-bit quantization
447+
// Works on both warp32 and warp64 hardware
448+
// Processes 2 blocks of 64 values per thread block using 64 threads
449+
// Uses logical warps of 32: threads 0-31 handle block 0, threads 32-63 handle block 1
450+
// - warp32: 2 hardware warps, each reduces naturally
451+
// - warp64: 1 hardware warp split into 2 logical warps of 32
452+
template <typename T, int DATA_TYPE>
453+
__global__ void kQuantizeBlockwise64(
454+
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
455+
const int rand_offset, const int n
456+
) {
457+
constexpr int BLOCK_SIZE = 64; // Size of each quantization block
458+
constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing)
459+
constexpr int THREADS = 64; // Total threads per HIP block
460+
constexpr int THREADS_PER_BLOCK = 32; // Threads handling each quantization block
461+
462+
const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 quantization blocks per HIP block
463+
464+
T vals[NUM_PER_TH];
465+
unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte
466+
float local_abs_max = 0.0f;
467+
468+
const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-31, 1 for threads 32-63
469+
const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the quantization block (0-31)
470+
471+
typedef hipcub::BlockLoad<T, THREADS, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
472+
typedef hipcub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
473+
// Logical warp size of 32: on warp32 this matches hardware warps,
474+
// on warp64 this splits the single hardware warp into two independent reductions
475+
typedef hipcub::WarpReduce<float, 32> WarpReduce;
476+
477+
__shared__ typename LoadT::TempStorage loadt;
478+
__shared__ typename StoreChar::TempStorage storec;
479+
__shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp
480+
__shared__ float smem_absmax_value[2];
481+
482+
const int i = base_idx + block_id * BLOCK_SIZE;
483+
// Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
484+
// operations that require ALL 64 threads to participate
485+
const bool block_valid = (i < n);
486+
487+
// All 64 threads participate in the load (out-of-bounds threads get 0.0f)
488+
__syncthreads();
489+
LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);
490+
491+
// Each thread computes max of its values
492+
local_abs_max = -FLT_MAX;
493+
#pragma unroll NUM_PER_TH
494+
for (int j = 0; j < NUM_PER_TH; j++)
495+
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
496+
497+
// Reduce within each logical warp of 32 threads independently
498+
local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, hipcub::Max());
499+
500+
if (local_thread_id == 0) {
501+
if (block_valid) {
502+
smem_absmax_value[block_id] = 1.0f / local_abs_max;
503+
absmax[blockIdx.x * 2 + block_id] = local_abs_max;
504+
} else {
505+
smem_absmax_value[block_id] = 0.0f;
506+
}
507+
}
508+
__syncthreads();
509+
510+
local_abs_max = smem_absmax_value[block_id];
511+
512+
switch (DATA_TYPE) {
513+
case FP4:
514+
#pragma unroll NUM_PER_TH
515+
for (int j = 0; j < NUM_PER_TH / 2; j++) {
516+
qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
517+
qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
518+
}
519+
break;
520+
case NF4:
521+
#pragma unroll NUM_PER_TH
522+
for (int j = 0; j < NUM_PER_TH / 2; j++) {
523+
qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
524+
qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
525+
}
526+
break;
527+
}
528+
529+
// All 64 threads participate in the store (valid_items limits the actual writes)
530+
__syncthreads();
531+
StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
532+
}
533+
446534
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
447535
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n)
448536
{
@@ -2566,6 +2654,20 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
25662654
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
25672655
#endif
25682656

2657+
// Specialized blocksize=64 4-bit quantization kernel instantiations for ROCm
2658+
#define MAKE_kQuantizeBlockwise64(dtype, data_type_name) \
2659+
template __global__ void kQuantizeBlockwise64<dtype, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2660+
2661+
// FP4 instantiations
2662+
MAKE_kQuantizeBlockwise64(half, FP4)
2663+
MAKE_kQuantizeBlockwise64(float, FP4)
2664+
MAKE_kQuantizeBlockwise64(hip_bfloat16, FP4)
2665+
2666+
// NF4 instantiations
2667+
MAKE_kQuantizeBlockwise64(half, NF4)
2668+
MAKE_kQuantizeBlockwise64(float, NF4)
2669+
MAKE_kQuantizeBlockwise64(hip_bfloat16, NF4)
2670+
25692671
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
25702672
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
25712673
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);

csrc/kernels_hip.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ __global__ void kQuantizeBlockwise(
1919
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
2020
const int rand_offset, const int n
2121
);
22+
template <typename T, int DATA_TYPE>
23+
__global__ void kQuantizeBlockwise64(
24+
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
25+
const int rand_offset, const int n
26+
);
2227
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
2328
__global__ void
2429
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);

csrc/ops.hip

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,14 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
5656
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n);
5757
else if(blocksize == 128)
5858
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n);
59-
else if(blocksize == 64 && BNB_WARP_SIZE == 32)
60-
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);
61-
59+
else if(blocksize == 64) {
60+
// For 4-bit (FP4/NF4): use specialized kernel that processes 2 blocks of 64 per thread block
61+
// Works on all warp sizes (32 and 64) by using logical warps of 32
62+
if constexpr(DATA_TYPE > 0)
63+
hipLaunchKernelGGL(( kQuantizeBlockwise64<T, DATA_TYPE>), dim3((num_blocks + 1) / 2), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n);
64+
else
65+
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);
66+
}
6267

6368
CUDA_CHECK_RETURN(hipPeekAtLastError());
6469
}

tests/test_functional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ class TestQuantize4BitFunctional:
10981098
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
10991099
@pytest.mark.parametrize(
11001100
"blocksize",
1101-
[32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
1101+
[32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512, 1024, 2048, 4096],
11021102
)
11031103
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11041104
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1173,7 +1173,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11731173
@pytest.mark.parametrize("device", get_available_devices())
11741174
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
11751175
@pytest.mark.parametrize(
1176-
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
1176+
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize")
11771177
)
11781178
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
11791179
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
@@ -1212,7 +1212,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
12121212
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
12131213
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
12141214
@pytest.mark.parametrize(
1215-
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
1215+
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize")
12161216
)
12171217
def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
12181218
"""

tests/test_linear4bit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_linear_serialization(
193193

194194
@pytest.mark.parametrize("device", get_available_devices())
195195
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
196-
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
196+
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
197197
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
198198
def test_copy_param(device, quant_type, blocksize, compress_statistics):
199199
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -250,7 +250,7 @@ def test_params4bit_torch_chunk_split(device, quant_type):
250250

251251
@pytest.mark.parametrize("device", get_available_devices())
252252
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
253-
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
253+
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
254254
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
255255
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
256256
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -279,7 +279,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
279279

280280
@pytest.mark.parametrize("device", get_available_devices())
281281
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
282-
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
282+
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
283283
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
284284
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
285285
if device == "hpu" and not is_supported_on_hpu(quant_type):

tests/test_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class Test4bitBlockwiseQuantOps:
152152
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
153153
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
154154
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
155-
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
155+
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512])
156156
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
157157
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
158158
pytest.skip("This configuration is not supported on HPU.")
@@ -176,7 +176,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
176176
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
177177
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
178178
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
179-
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
179+
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512])
180180
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
181181
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
182182
pytest.skip("This configuration is not supported on HPU.")

0 commit comments

Comments (0)