@@ -374,69 +374,72 @@ __global__ void kQuantizeBlockwise(
374374 }
375375}
376376
377- // Unified small-blocksize kernel for 4-bit quantization
378- // Processes 2 blocks of BNB_WARP_SIZE values per thread block
379- // On CUDA (warp=32): blocksize=32, 32 threads, WarpReduce<16>
380- // On HIP (warp=64): blocksize=64, 64 threads, WarpReduce<32>
381- // On HIP (warp=32): blocksize=32, 32 threads, WarpReduce<16>
382- template <typename T, int DATA_TYPE>
377+ // Small-blocksize kernel for 4-bit quantization, parameterized on quantization
378+ // block size (QBLOCK_SIZE). Always launches exactly BNB_WARP_SIZE threads so
379+ // every lane in the wavefront is productive. Multiple quantization blocks are
380+ // packed into one wavefront when QBLOCK_SIZE < BNB_WARP_SIZE * NUM_PER_TH:
381+ //
382+ // CDNA (64), QBLOCK_SIZE=32 -> 4 quant blocks per wavefront
383+ // CDNA (64), QBLOCK_SIZE=64 -> 2 quant blocks per wavefront
384+ // CUDA/RDNA (32), QBLOCK_SIZE=32 -> 2 quant blocks per wavefront
385+ //
386+ // Uses logical-warp WarpReduce<THREADS_PER_QB> so each quantization block's
387+ // threads reduce independently via warp shuffles.
388+ template <typename T, int QBLOCK_SIZE, int DATA_TYPE>
383389__global__ void kQuantizeBlockwiseSmall (
384390 float * code, T* __restrict__ const A, float * absmax, unsigned char * out, float * __restrict__ const rand,
385391 const int rand_offset, const int n
386392) {
387- constexpr int BLOCK_SIZE = BNB_WARP_SIZE; // Size of each quantization block
388- constexpr int NUM_PER_TH = 2 ; // Values per thread (for 4-bit packing)
389- constexpr int THREADS = BNB_WARP_SIZE; // Total threads (one full warp)
390- constexpr int THREADS_PER_BLOCK = BNB_WARP_SIZE / 2 ; // Half-warp per quantization block
393+ static_assert (QBLOCK_SIZE <= BNB_WARP_SIZE * 2 , " QBLOCK_SIZE too large for one warp" );
391394
392- const int base_idx = blockIdx .x * BLOCK_SIZE * 2 ; // 2 blocks per thread block
395+ constexpr int NUM_PER_TH = 2 ;
396+ constexpr int THREADS = BNB_WARP_SIZE;
397+ constexpr int THREADS_PER_QB = QBLOCK_SIZE / NUM_PER_TH;
398+ constexpr int NUM_QB = THREADS / THREADS_PER_QB;
399+ constexpr int TOTAL_VALUES = QBLOCK_SIZE * NUM_QB;
400+
401+ const int base_idx = blockIdx .x * TOTAL_VALUES;
393402
394403 T vals[NUM_PER_TH];
395- unsigned char qvals[NUM_PER_TH / 2 ]; // For 4-bit: 2 values per byte
404+ unsigned char qvals[NUM_PER_TH / 2 ];
396405 float local_abs_max = 0 .0f ;
397406
398- const int block_id = threadIdx .x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31
399- const int local_thread_id = threadIdx .x % THREADS_PER_BLOCK; // Thread ID within the block (0-15)
407+ const int qb_id = threadIdx .x / THREADS_PER_QB;
408+ const int local_tid = threadIdx .x % THREADS_PER_QB;
400409
401410 typedef bnb_cub::BlockLoad<T, THREADS, NUM_PER_TH, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
402411 typedef bnb_cub::BlockStore<unsigned char , THREADS, NUM_PER_TH / 2 , bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
403- typedef bnb_cub::WarpReduce<float , THREADS_PER_BLOCK>
404- WarpReduce; // Half-warp logical reduction: each half reduces independently
412+ typedef bnb_cub::WarpReduce<float , THREADS_PER_QB> WarpReduce;
405413
406414 __shared__ typename LoadT::TempStorage loadt;
407415 __shared__ typename StoreChar::TempStorage storec;
408- __shared__ typename WarpReduce::TempStorage warp_reduce[2 ]; // One per logical warp
409- __shared__ float smem_absmax_value[2 ];
416+ __shared__ typename WarpReduce::TempStorage warp_reduce[NUM_QB];
417+ __shared__ float smem_absmax_value[NUM_QB ];
410418
411- const int i = base_idx + block_id * BLOCK_SIZE;
412- // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
413- // operations that require ALL 32 threads to participate
414- const bool block_valid = (i < n);
419+ const int qi = base_idx + qb_id * QBLOCK_SIZE;
420+ const bool qb_valid = (qi < n);
415421
416- // All 32 threads participate in the load (out-of-bounds threads get 0.0f)
417422 __syncthreads ();
418- LoadT (loadt).Load (&(A[base_idx]), vals, min (BLOCK_SIZE * 2 , n - base_idx), (T)0 .0f );
423+ LoadT (loadt).Load (&(A[base_idx]), vals, min (TOTAL_VALUES , n - base_idx), (T)0 .0f );
419424
420- // Each thread computes max of its values
421425 local_abs_max = -FLT_MAX;
422426#pragma unroll NUM_PER_TH
423427 for (int j = 0 ; j < NUM_PER_TH; j++)
424428 local_abs_max = fmaxf (local_abs_max, fabsf ((float )vals[j]));
425429
426- // Reduce within each logical warp of 16 threads independently
427- local_abs_max = WarpReduce (warp_reduce[block_id]).Reduce (local_abs_max, BNB_MAX_OP);
430+ local_abs_max = WarpReduce (warp_reduce[qb_id]).Reduce (local_abs_max, BNB_MAX_OP);
428431
429- if (local_thread_id == 0 ) {
430- if (block_valid ) {
431- smem_absmax_value[block_id ] = 1 .0f / local_abs_max;
432- absmax[blockIdx .x * 2 + block_id ] = local_abs_max;
432+ if (local_tid == 0 ) {
433+ if (qb_valid ) {
434+ smem_absmax_value[qb_id ] = 1 .0f / local_abs_max;
435+ absmax[blockIdx .x * NUM_QB + qb_id ] = local_abs_max;
433436 } else {
434- smem_absmax_value[block_id ] = 0 .0f ;
437+ smem_absmax_value[qb_id ] = 0 .0f ;
435438 }
436439 }
437440 __syncthreads ();
438441
439- local_abs_max = smem_absmax_value[block_id ];
442+ local_abs_max = smem_absmax_value[qb_id ];
440443
441444 switch (DATA_TYPE) {
442445 case FP4:
@@ -455,9 +458,8 @@ __global__ void kQuantizeBlockwiseSmall(
455458 break ;
456459 }
457460
458- // All 32 threads participate in the store (valid_items limits the actual writes)
459461 __syncthreads ();
460- StoreChar (storec).Store (&(out[base_idx / 2 ]), qvals, min ((BLOCK_SIZE * 2 + 1 ) / 2 , (n - base_idx + 1 ) / 2 ));
462+ StoreChar (storec).Store (&(out[base_idx / 2 ]), qvals, min ((TOTAL_VALUES + 1 ) / 2 , (n - base_idx + 1 ) / 2 ));
461463}
462464
463465template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
@@ -1446,15 +1448,15 @@ __global__ void kgemm_4bit_inference_naive(
14461448) {
14471449
14481450 // per threadblock:
1449- // load step-by-step in chunks of [32 ,warps]: 1x32 * [32 ,warps] -> [1,warps]
1450- // 4 warps -> 4 loads per iter
1451- // 1x32 * 32x4 -> 1x4 outputs per thread block
1451+ // load step-by-step in chunks of [warp_size ,warps]: 1xwarp_size * [warp_size ,warps] -> [1,warps]
1452+ // THREADS/BNB_WARP_SIZE warps -> that many loads per iter
1453+ // 1xwarp_size * warp_size x warps -> 1 x warps outputs per thread block
14521454 typedef bnb_cub::WarpReduce<float > WarpReduce;
1453- __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32 ];
1455+ __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / BNB_WARP_SIZE ];
14541456
1455- const int warp_idx = threadIdx .x / 32 ;
1456- const int warp_lane = threadIdx .x % 32 ;
1457- const int row_B = (THREADS / 32 ) * blockIdx .x + warp_idx;
1457+ const int warp_idx = threadIdx .x / BNB_WARP_SIZE ;
1458+ const int warp_lane = threadIdx .x % BNB_WARP_SIZE ;
1459+ const int row_B = (THREADS / BNB_WARP_SIZE ) * blockIdx .x + warp_idx;
14581460 const int offset_B = ldb * row_B;
14591461 const int num_values_8bit = num_values_4bit / 2 ;
14601462 float local_C = 0 .0f ;
@@ -1473,7 +1475,7 @@ __global__ void kgemm_4bit_inference_naive(
14731475
14741476 // A: [1, K]
14751477 // B: [N, K]
1476- for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
1478+ for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE * num_values_4bit) {
14771479 const int inner_idx_halved = inner_idx / 2 ;
14781480
14791481 // Since blocksize will always be a power-of-2, we avoid more expensive
@@ -1766,22 +1768,28 @@ MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4)
17661768MAKE_kQuantizeBlockwise(bnb_bfloat16, 128 , 2 , 0 , NF4)
17671769MAKE_kQuantizeBlockwise(bnb_bfloat16, 64 , 2 , 0 , NF4)
17681770
// Explicit template instantiations for kQuantizeBlockwiseSmall (4-bit only).
// One instantiation per (element type, quantization block size, 4-bit format)
// combination that the host-side launcher may dispatch to.
#define MAKE_kQuantizeBlockwiseSmall(dtype, qblock_size, data_type_name)                                               \
    template __global__ void kQuantizeBlockwiseSmall<dtype, qblock_size, data_type_name>(                              \
        float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,   \
        const int rand_offset, const int n                                                                             \
    );

// QBLOCK_SIZE=32 instantiations
MAKE_kQuantizeBlockwiseSmall(half, 32, FP4)
MAKE_kQuantizeBlockwiseSmall(float, 32, FP4)
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, FP4)
MAKE_kQuantizeBlockwiseSmall(half, 32, NF4)
MAKE_kQuantizeBlockwiseSmall(float, 32, NF4)
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, NF4)

// QBLOCK_SIZE=64 instantiations (used on HIP for blocksize=64)
MAKE_kQuantizeBlockwiseSmall(half, 64, FP4)
MAKE_kQuantizeBlockwiseSmall(float, 64, FP4)
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, FP4)
MAKE_kQuantizeBlockwiseSmall(half, 64, NF4)
MAKE_kQuantizeBlockwiseSmall(float, 64, NF4)
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, NF4)
17851793
17861794 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
17871795 float * code, unsigned char * A, float * absmax, half* out, const int blocksize, const int n
0 commit comments