Skip to content

Commit fe784a7

Browse files
authored
Fix xpu 4bit kernel (#1839)
* Fix xpu 4bit kernel Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix xpu kernel without api change Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 31610c9 commit fe784a7

File tree

4 files changed

+50
-21
lines changed

4 files changed

+50
-21
lines changed

bitsandbytes/backends/triton/kernels_4bit.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def quantize_fp4_blockwise_kernel(
6666

6767
packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
6868
out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
69-
out_mask = out_offsets < n_elements // 2
69+
# Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
70+
out_mask = out_offsets < (n_elements - n_elements // 2)
7071
tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
7172

7273

@@ -148,7 +149,8 @@ def quantize_nf4_blockwise_kernel(
148149

149150
packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
150151
out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
151-
out_mask = out_offsets < n_elements // 2
152+
# Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
153+
out_mask = out_offsets < (n_elements - n_elements // 2)
152154
tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
153155

154156

@@ -330,7 +332,14 @@ def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.const
330332
# )
331333
@triton.jit
332334
def dequant_4bit_kernel(
333-
a_ptr, c_ptr, quant_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
335+
a_ptr,
336+
c_ptr,
337+
quant_ptr,
338+
absmax_ptr,
339+
num_paired_elements,
340+
num_output_elements,
341+
QUANT_BLOCK: tl.constexpr,
342+
SPLIT_SIZE: tl.constexpr,
334343
):
335344
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
336345
block_start = pid * SPLIT_SIZE
@@ -350,7 +359,7 @@ def dequant_4bit_kernel(
350359

351360
out_block_start = pid * SPLIT_SIZE * 2
352361
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
353-
mask = offs < num_paired_elements * 2
362+
mask = offs < num_output_elements
354363
tl.store(c_ptr + offs, out_dq, mask)
355364

356365

@@ -367,7 +376,13 @@ def dequant_4bit_kernel(
367376
# )
368377
@triton.jit
369378
def dequant_fp4_kernel(
370-
a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
379+
a_ptr,
380+
c_ptr,
381+
absmax_ptr,
382+
num_paired_elements,
383+
num_output_elements,
384+
QUANT_BLOCK: tl.constexpr,
385+
SPLIT_SIZE: tl.constexpr,
371386
):
372387
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
373388
block_start = pid * SPLIT_SIZE
@@ -386,7 +401,7 @@ def dequant_fp4_kernel(
386401

387402
out_block_start = pid * SPLIT_SIZE * 2
388403
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
389-
mask = offs < num_paired_elements * 2
404+
mask = offs < num_output_elements
390405
tl.store(c_ptr + offs, out_dq, mask)
391406

392407

@@ -403,7 +418,13 @@ def dequant_fp4_kernel(
403418
# )
404419
@triton.jit
405420
def dequant_nf4_kernel(
406-
a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
421+
a_ptr,
422+
c_ptr,
423+
absmax_ptr,
424+
num_paired_elements,
425+
num_output_elements,
426+
QUANT_BLOCK: tl.constexpr,
427+
SPLIT_SIZE: tl.constexpr,
407428
):
408429
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
409430
block_start = pid * SPLIT_SIZE
@@ -422,7 +443,7 @@ def dequant_nf4_kernel(
422443

423444
out_block_start = pid * SPLIT_SIZE * 2
424445
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
425-
mask = offs < num_paired_elements * 2
446+
mask = offs < num_output_elements
426447
tl.store(c_ptr + offs, out_dq, mask)
427448

428449

@@ -439,15 +460,16 @@ def dequantize_4bit_impl(
439460
# Elements are in uint8 format, so interleaved
440461
# so total amount of data is 2 * elem_count
441462
number_of_paired_elements = A.numel()
463+
num_output_elements = out.numel()
442464
# we assume that split_size > quant_blocksize
443465

444466
SPLIT_SIZE = 256
445467
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
446468
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
447469
if quant_type == "fp4":
448-
dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
470+
dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE)
449471
else:
450-
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
472+
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE)
451473

452474

453475
def dequantize_4bit_impl_passing_code(
@@ -459,12 +481,15 @@ def dequantize_4bit_impl_passing_code(
459481
out: torch.Tensor,
460482
) -> None:
461483
number_of_paired_elements = A.numel()
484+
num_output_elements = out.numel()
462485
# we assume that split_size > quant_blocksize
463486

464487
SPLIT_SIZE = 256
465488
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
466489
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
467-
dequant_4bit_kernel[grid](A, out, code, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
490+
dequant_4bit_kernel[grid](
491+
A, out, code, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE
492+
)
468493

469494

470495
######################### Fallback dequantization functions #########################

bitsandbytes/backends/triton/ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def quantize_4bit(
8282
blocks = -(n // -(blocksize * 2))
8383

8484
absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
85-
out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
85+
# Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
86+
out = torch.empty((n - n // 2, 1), device=A.device, dtype=torch.uint8)
8687

8788
with torch_accelerator_module.device(A.device):
8889
kernels_4bit.quantize_4bit_blockwise_triton(

csrc/xpu_kernels.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,21 @@ inline float dDequantizeNF4(unsigned char val) {
9595

9696
template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE>
9797
SYCL_EXTERNAL void kDequantizeBlockwise<T, TILE_SIZE, NUM_PER_TH, DATA_TYPE>::operator()(sycl::nd_item<1> item) const {
98-
const int base_idx = item.get_group(0) * TILE_SIZE;
99-
size_t local_idx = item.get_local_id(0) * NUM_PER_TH;
98+
const int64_t base_idx = static_cast<int64_t>(item.get_group(0)) * TILE_SIZE;
99+
int64_t local_idx = static_cast<int64_t>(item.get_local_id(0)) * NUM_PER_TH;
100100
float local_abs_max = -FLT_MAX;
101-
int local_load_idx = 0;
102-
int local_store_idx = 0;
101+
int64_t local_load_idx = 0;
102+
int64_t local_store_idx = 0;
103103

104104
uint8_t qvals[NUM_PER_TH];
105105
T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];
106106

107107
if (DATA_TYPE > 0) {
108-
local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx);
109-
local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2);
108+
// Cast n to int64_t to avoid overflow for large n (same as CUDA)
109+
local_load_idx = sycl::min(static_cast<int64_t>(TILE_SIZE), (static_cast<int64_t>(n) + 1) / 2 - base_idx);
110+
local_store_idx = sycl::min(static_cast<int64_t>(TILE_SIZE * 2), static_cast<int64_t>(n) - base_idx * 2);
110111
} else {
111-
local_load_idx = sycl::min(TILE_SIZE, n - base_idx);
112+
local_load_idx = sycl::min(static_cast<int64_t>(TILE_SIZE), static_cast<int64_t>(n) - base_idx);
112113
local_store_idx = local_load_idx;
113114
}
114115

csrc/xpu_ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@ void dequantizeBlockwise(
1010
const int num_per_th = 4;
1111
const int tile_size = workgroup_size * num_per_th;
1212
if (DATA_TYPE > 0) {
13-
const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2);
13+
// Upcast to int64 to avoid overflow for large n (same as CUDA)
14+
const int workgroup_num = (static_cast<int64_t>(n) + tile_size * 2 - 1) / (tile_size * 2);
1415
sycl::range<1> local_range{(size_t)workgroup_size};
1516
sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
1617
kDequantizeBlockwise<T, tile_size, num_per_th, DATA_TYPE> kfn(code, A, absmax, out, blocksize / 2, n);
1718
sycl_kernel_submit<decltype(kfn), 1, 32>(
1819
sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
1920
);
2021
} else {
21-
const int workgroup_num = (n + tile_size - 1) / tile_size;
22+
// Upcast to int64 to avoid overflow for large n (same as CUDA)
23+
const int workgroup_num = (static_cast<int64_t>(n) + tile_size - 1) / tile_size;
2224
sycl::range<1> local_range{(size_t)workgroup_size};
2325
sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
2426
kDequantizeBlockwise<T, tile_size, num_per_th, DATA_TYPE> kfn(code, A, absmax, out, blocksize, n);

0 commit comments

Comments (0)