diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 6305fc0c7..d92f9a490 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -212,7 +212,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor torch._check_is_size(blocksize) if ROCM_WARP_SIZE_64: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) @@ -270,7 +270,7 @@ def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: if ROCM_WARP_SIZE_64: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) @@ -304,7 +304,7 @@ def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: if ROCM_WARP_SIZE_64: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) @@ -386,7 +386,7 @@ def _dequantize_4bit_impl( out: torch.Tensor, ) -> None: if ROCM_WARP_SIZE_64: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index f4bbfdd79..7a70d3c6d 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -443,6 +443,94 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float } } +// Specialized kernel for blocksize=64 with 4-bit quantization +// Works on both warp32 and warp64 hardware +// Processes 2 blocks of 64 values per thread block using 64 threads +// Uses logical warps of 32: threads 0-31 handle block 0, threads 32-63 handle block 1 +// - warp32: 2 hardware warps, each reduces naturally +// - warp64: 1 hardware warp split into 2 logical warps of 32 +template +__global__ void kQuantizeBlockwise64( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +) { + constexpr int BLOCK_SIZE = 64; // Size of each quantization block + constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing) + constexpr int THREADS = 64; // Total threads per HIP block + constexpr int THREADS_PER_BLOCK = 32; // Threads handling each quantization block + + const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 quantization blocks per HIP block + + T vals[NUM_PER_TH]; + unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte + float local_abs_max = 0.0f; + + const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-31, 1 for threads 32-63 + const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the quantization block (0-31) + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockStore StoreChar; + // Logical warp size of 32: on warp32 this matches hardware warps, + // on warp64 this splits the single hardware warp into two independent reductions + typedef hipcub::WarpReduce WarpReduce; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp + __shared__ float smem_absmax_value[2]; + + const int i = base_idx + block_id * BLOCK_SIZE; + // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative + // operations that require ALL 64 threads to participate + const bool block_valid = (i < n); + + // All 64 threads participate in the load (out-of-bounds threads get 0.0f) + __syncthreads(); + LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f); + + // Each thread computes max of its values + local_abs_max = -FLT_MAX; +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + // Reduce within each logical warp of 32 threads independently + local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, hipcub::Max()); + + if (local_thread_id == 0) { + if (block_valid) { + smem_absmax_value[block_id] = 1.0f / local_abs_max; + absmax[blockIdx.x * 2 + block_id] = local_abs_max; + } else { + smem_absmax_value[block_id] = 0.0f; + } + } + __syncthreads(); + + local_abs_max = smem_absmax_value[block_id]; + + switch (DATA_TYPE) { + case FP4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH / 2; j++) { + qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); + } + break; + case NF4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH / 2; j++) { + qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); + } + break; + } + + // All 64 threads participate in the store (valid_items limits the actual writes) + __syncthreads(); + StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2)); +} + template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) { @@ -2566,6 +2654,20 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) #endif +// Specialized blocksize=64 4-bit quantization kernel instantiations for ROCm +#define MAKE_kQuantizeBlockwise64(dtype, data_type_name) \ +template __global__ void kQuantizeBlockwise64(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); + +// FP4 instantiations +MAKE_kQuantizeBlockwise64(half, FP4) +MAKE_kQuantizeBlockwise64(float, FP4) +MAKE_kQuantizeBlockwise64(hip_bfloat16, FP4) + +// NF4 instantiations +MAKE_kQuantizeBlockwise64(half, NF4) +MAKE_kQuantizeBlockwise64(float, NF4) +MAKE_kQuantizeBlockwise64(hip_bfloat16, NF4) + template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index 1430d6441..318e87ed9 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -19,6 +19,11 @@ __global__ void kQuantizeBlockwise( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ); +template +__global__ void kQuantizeBlockwise64( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +); template __global__ void kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n); diff --git a/csrc/ops.hip b/csrc/ops.hip index dc3dc091e..1017f58ca 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -56,9 +56,14 @@ template void quantizeBlockwise(floa hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 64 && BNB_WARP_SIZE == 32) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); - + else if(blocksize == 64) { + // For 4-bit (FP4/NF4): use specialized kernel that processes 2 blocks of 64 per thread block + // Works on all warp sizes (32 and 64) by using logical warps of 32 + if constexpr(DATA_TYPE > 0) + hipLaunchKernelGGL(( kQuantizeBlockwise64), dim3((num_blocks + 1) / 2), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + } CUDA_CHECK_RETURN(hipPeekAtLastError()); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 7de72ebc3..26a1a95aa 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1098,7 +1098,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( "blocksize", - [32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096], + [32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -1173,7 +1173,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( - "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize") + "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize") ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): @@ -1212,7 +1212,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( - "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize") + "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize") ) def test_4bit_quant_large(self, device, dtype, quant_type, blocksize): """ diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index aa693713c..1bf6374dd 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -193,7 +193,7 @@ def test_linear_serialization( @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -250,7 +250,7 @@ def test_params4bit_torch_chunk_split(device, quant_type): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -279,7 +279,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): diff --git a/tests/test_ops.py b/tests/test_ops.py index aa20995ee..5f780f2ac 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -152,7 +152,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -176,7 +176,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.")