@@ -423,6 +423,90 @@ __global__ void kQuantizeBlockwise(
423423 }
424424}
425425
// Specialized kernel for blocksize=32 with 4-bit quantization (FP4/NF4).
// Each CUDA block processes 2 quantization blocks of 32 values so a full
// warp of 32 threads stays busy: threads 0-15 handle quantization block 0,
// threads 16-31 handle quantization block 1 (2 values per thread).
// Grid layout: gridDim.x = ceil(n / 64), blockDim.x = 32.
// `code`, `rand` and `rand_offset` are unused here; they are kept so the
// signature matches the generic kQuantizeBlockwise launcher.
template <typename T, int DATA_TYPE>
__global__ void kQuantizeBlockwise32(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
) {
    constexpr int BLOCK_SIZE = 32;        // size of each quantization block
    constexpr int NUM_PER_TH = 2;         // values per thread (packs into one byte for 4-bit)
    constexpr int THREADS = 32;           // total threads per CUDA block (one full warp)
    constexpr int THREADS_PER_BLOCK = 16; // threads handling each quantization block

    const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 quantization blocks per CUDA block

    T vals[NUM_PER_TH];
    unsigned char qvals[NUM_PER_TH / 2]; // 4-bit: 2 values per output byte
    float local_abs_max = 0.0f;

    const int block_id = threadIdx.x / THREADS_PER_BLOCK;        // 0 for threads 0-15, 1 for threads 16-31
    const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // thread id within the quantization block (0-15)

    typedef cub::BlockLoad<T, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
    // Logical warp size of 16: threads 0-15 and threads 16-31 reduce independently.
    typedef cub::WarpReduce<float, 16> WarpReduce;

    __shared__ typename LoadT::TempStorage loadt;
    __shared__ typename StoreChar::TempStorage storec;
    __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // one per logical warp
    __shared__ float smem_absmax_value[2];

    const int i = base_idx + block_id * BLOCK_SIZE;
    // Use a flag instead of an early return: BlockLoad/BlockStore/__syncthreads are
    // cooperative operations that require ALL 32 threads to participate.
    const bool block_valid = (i < n);

    // All 32 threads participate in the load (out-of-bounds elements read as 0).
    __syncthreads();
    LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);

    // Each thread computes the absolute max of its values.
    local_abs_max = -FLT_MAX;
#pragma unroll
    for (int j = 0; j < NUM_PER_TH; j++)
        local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    // Reduce within each logical warp of 16 threads independently. CUB's
    // WarpReduce::Reduce takes a functor instance (cub::Max()); only
    // local_thread_id == 0 of each logical warp holds the valid result.
    local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, cub::Max());

    if (local_thread_id == 0) {
        if (block_valid) {
            // NOTE(review): if every value in a valid block is 0, this is 1/0 = inf
            // and 0*inf = NaN feeds the quantizers below — confirm upstream callers
            // never pass an all-zero block, or clamp local_abs_max here.
            smem_absmax_value[block_id] = 1.0f / local_abs_max;
            absmax[blockIdx.x * 2 + block_id] = local_abs_max;
        } else {
            smem_absmax_value[block_id] = 0.0f;
        }
    }
    __syncthreads();

    // Broadcast the reciprocal absmax to all threads of this quantization block.
    local_abs_max = smem_absmax_value[block_id];

    switch (DATA_TYPE) {
    case FP4:
#pragma unroll
        for (int j = 0; j < NUM_PER_TH / 2; j++) {
            qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
            qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
        }
        break;
    case NF4:
#pragma unroll
        for (int j = 0; j < NUM_PER_TH / 2; j++) {
            qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
            qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
        }
        break;
    }

    // All 32 threads participate in the store (valid_items limits the actual writes).
    __syncthreads();
    StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
}
509+
426510template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
427511__global__ void
428512 kDequantizeBlockwise (float * code, unsigned char * A, float * absmax, T* out, const int blocksize, const int n) {
@@ -2440,9 +2524,24 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)
24422526
// Template instantiations for the blocksize=32 specialized kernel (4-bit only).
// Each invocation expands to one explicit instantiation of kQuantizeBlockwise32.
#define MAKE_kQuantizeBlockwise32(dtype, data_type_name)                                                     \
    template __global__ void kQuantizeBlockwise32<dtype, data_type_name>(                                    \
        float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out,                         \
        float* __restrict__ const rand, const int rand_offset, const int n                                   \
    );

// FP4 instantiations for blocksize=32
MAKE_kQuantizeBlockwise32(half, FP4)
MAKE_kQuantizeBlockwise32(float, FP4)
MAKE_kQuantizeBlockwise32(__nv_bfloat16, FP4)

// NF4 instantiations for blocksize=32
MAKE_kQuantizeBlockwise32(half, NF4)
MAKE_kQuantizeBlockwise32(float, NF4)
MAKE_kQuantizeBlockwise32(__nv_bfloat16, NF4)

template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
    float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(
    float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
0 commit comments