@@ -66,7 +66,8 @@ def quantize_fp4_blockwise_kernel(
6666
6767 packed_flat = tl .reshape (packed , (BLOCK_SIZE * SPLIT_NUM_BLOCKS ,))
6868 out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl .arange (0 , SPLIT_NUM_BLOCKS * BLOCK_SIZE )
69- out_mask = out_offsets < n_elements // 2
69+ # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
70+ out_mask = out_offsets < (n_elements - n_elements // 2 )
7071 tl .store (out_ptr + out_offsets , packed_flat , mask = out_mask )
7172
7273
@@ -148,7 +149,8 @@ def quantize_nf4_blockwise_kernel(
148149
149150 packed_flat = tl .reshape (packed , (BLOCK_SIZE * SPLIT_NUM_BLOCKS ,))
150151 out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl .arange (0 , SPLIT_NUM_BLOCKS * BLOCK_SIZE )
151- out_mask = out_offsets < n_elements // 2
152+ # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
153+ out_mask = out_offsets < (n_elements - n_elements // 2 )
152154 tl .store (out_ptr + out_offsets , packed_flat , mask = out_mask )
153155
154156
@@ -330,7 +332,14 @@ def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.const
330332# )
331333@triton .jit
332334def dequant_4bit_kernel (
333- a_ptr , c_ptr , quant_ptr , absmax_ptr , num_paired_elements , QUANT_BLOCK : tl .constexpr , SPLIT_SIZE : tl .constexpr
335+ a_ptr ,
336+ c_ptr ,
337+ quant_ptr ,
338+ absmax_ptr ,
339+ num_paired_elements ,
340+ num_output_elements ,
341+ QUANT_BLOCK : tl .constexpr ,
342+ SPLIT_SIZE : tl .constexpr ,
334343):
335344 pid = tl .program_id (axis = 0 ) # We use a 1D launch grid so axis is 0.
336345 block_start = pid * SPLIT_SIZE
@@ -350,7 +359,7 @@ def dequant_4bit_kernel(
350359
351360 out_block_start = pid * SPLIT_SIZE * 2
352361 offs = out_block_start + tl .arange (0 , SPLIT_SIZE * 2 )
353- mask = offs < num_paired_elements * 2
362+ mask = offs < num_output_elements
354363 tl .store (c_ptr + offs , out_dq , mask )
355364
356365
@@ -367,7 +376,13 @@ def dequant_4bit_kernel(
367376# )
368377@triton .jit
369378def dequant_fp4_kernel (
370- a_ptr , c_ptr , absmax_ptr , num_paired_elements , QUANT_BLOCK : tl .constexpr , SPLIT_SIZE : tl .constexpr
379+ a_ptr ,
380+ c_ptr ,
381+ absmax_ptr ,
382+ num_paired_elements ,
383+ num_output_elements ,
384+ QUANT_BLOCK : tl .constexpr ,
385+ SPLIT_SIZE : tl .constexpr ,
371386):
372387 pid = tl .program_id (axis = 0 ) # We use a 1D launch grid so axis is 0.
373388 block_start = pid * SPLIT_SIZE
@@ -386,7 +401,7 @@ def dequant_fp4_kernel(
386401
387402 out_block_start = pid * SPLIT_SIZE * 2
388403 offs = out_block_start + tl .arange (0 , SPLIT_SIZE * 2 )
389- mask = offs < num_paired_elements * 2
404+ mask = offs < num_output_elements
390405 tl .store (c_ptr + offs , out_dq , mask )
391406
392407
@@ -403,7 +418,13 @@ def dequant_fp4_kernel(
403418# )
404419@triton .jit
405420def dequant_nf4_kernel (
406- a_ptr , c_ptr , absmax_ptr , num_paired_elements , QUANT_BLOCK : tl .constexpr , SPLIT_SIZE : tl .constexpr
421+ a_ptr ,
422+ c_ptr ,
423+ absmax_ptr ,
424+ num_paired_elements ,
425+ num_output_elements ,
426+ QUANT_BLOCK : tl .constexpr ,
427+ SPLIT_SIZE : tl .constexpr ,
407428):
408429 pid = tl .program_id (axis = 0 ) # We use a 1D launch grid so axis is 0.
409430 block_start = pid * SPLIT_SIZE
@@ -422,7 +443,7 @@ def dequant_nf4_kernel(
422443
423444 out_block_start = pid * SPLIT_SIZE * 2
424445 offs = out_block_start + tl .arange (0 , SPLIT_SIZE * 2 )
425- mask = offs < num_paired_elements * 2
446+ mask = offs < num_output_elements
426447 tl .store (c_ptr + offs , out_dq , mask )
427448
428449
@@ -439,15 +460,16 @@ def dequantize_4bit_impl(
439460 # Elements are in uint8 format, so interleaved
440461 # so total amount of data is 2 * elem_count
441462 number_of_paired_elements = A .numel ()
463+ num_output_elements = out .numel ()
442464 # we assume that split_size > quant_blocksize
443465
444466 SPLIT_SIZE = 256
445467 # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
446468 grid = (triton .cdiv (number_of_paired_elements , SPLIT_SIZE ),)
447469 if quant_type == "fp4" :
448- dequant_fp4_kernel [grid ](A , out , absmax , number_of_paired_elements , blocksize , SPLIT_SIZE )
470+ dequant_fp4_kernel [grid ](A , out , absmax , number_of_paired_elements , num_output_elements , blocksize , SPLIT_SIZE )
449471 else :
450- dequant_nf4_kernel [grid ](A , out , absmax , number_of_paired_elements , blocksize , SPLIT_SIZE )
472+ dequant_nf4_kernel [grid ](A , out , absmax , number_of_paired_elements , num_output_elements , blocksize , SPLIT_SIZE )
451473
452474
453475def dequantize_4bit_impl_passing_code (
@@ -459,12 +481,15 @@ def dequantize_4bit_impl_passing_code(
459481 out : torch .Tensor ,
460482) -> None :
461483 number_of_paired_elements = A .numel ()
484+ num_output_elements = out .numel ()
462485 # we assume that split_size > quant_blocksize
463486
464487 SPLIT_SIZE = 256
465488 # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
466489 grid = (triton .cdiv (number_of_paired_elements , SPLIT_SIZE ),)
467- dequant_4bit_kernel [grid ](A , out , code , absmax , number_of_paired_elements , blocksize , SPLIT_SIZE )
490+ dequant_4bit_kernel [grid ](
491+ A , out , code , absmax , number_of_paired_elements , num_output_elements , blocksize , SPLIT_SIZE
492+ )
468493
469494
470495######################### Fallback dequantization functions #########################
0 commit comments