@@ -443,6 +443,94 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
443443 }
444444}
445445
446+ // Specialized kernel for blocksize=64 with 4-bit quantization
447+ // Works on both warp32 and warp64 hardware
448+ // Processes 2 blocks of 64 values per thread block using 64 threads
449+ // Uses logical warps of 32: threads 0-31 handle block 0, threads 32-63 handle block 1
450+ // - warp32: 2 hardware warps, each reduces naturally
451+ // - warp64: 1 hardware warp split into 2 logical warps of 32
452+ template <typename T, int DATA_TYPE>
453+ __global__ void kQuantizeBlockwise64 (
454+ float * code, T* __restrict__ const A, float * absmax, unsigned char * out, float * __restrict__ const rand,
455+ const int rand_offset, const int n
456+ ) {
457+ constexpr int BLOCK_SIZE = 64 ; // Size of each quantization block
458+ constexpr int NUM_PER_TH = 2 ; // Values per thread (for 4-bit packing)
459+ constexpr int THREADS = 64 ; // Total threads per HIP block
460+ constexpr int THREADS_PER_BLOCK = 32 ; // Threads handling each quantization block
461+
462+ const int base_idx = blockIdx.x * BLOCK_SIZE * 2 ; // 2 quantization blocks per HIP block
463+
464+ T vals[NUM_PER_TH];
465+ unsigned char qvals[NUM_PER_TH / 2 ]; // For 4-bit: 2 values per byte
466+ float local_abs_max = 0 .0f ;
467+
468+ const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-31, 1 for threads 32-63
469+ const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the quantization block (0-31)
470+
471+ typedef hipcub::BlockLoad<T, THREADS, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
472+ typedef hipcub::BlockStore<unsigned char , THREADS, NUM_PER_TH / 2 , hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
473+ // Logical warp size of 32: on warp32 this matches hardware warps,
474+ // on warp64 this splits the single hardware warp into two independent reductions
475+ typedef hipcub::WarpReduce<float , 32 > WarpReduce;
476+
477+ __shared__ typename LoadT::TempStorage loadt;
478+ __shared__ typename StoreChar::TempStorage storec;
479+ __shared__ typename WarpReduce::TempStorage warp_reduce[2 ]; // One per logical warp
480+ __shared__ float smem_absmax_value[2 ];
481+
482+ const int i = base_idx + block_id * BLOCK_SIZE;
483+ // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
484+ // operations that require ALL 64 threads to participate
485+ const bool block_valid = (i < n);
486+
487+ // All 64 threads participate in the load (out-of-bounds threads get 0.0f)
488+ __syncthreads ();
489+ LoadT (loadt).Load (&(A[base_idx]), vals, min (BLOCK_SIZE * 2 , n - base_idx), (T)0 .0f );
490+
491+ // Each thread computes max of its values
492+ local_abs_max = -FLT_MAX;
493+ #pragma unroll NUM_PER_TH
494+ for (int j = 0 ; j < NUM_PER_TH; j++)
495+ local_abs_max = fmaxf (local_abs_max, fabsf ((float )vals[j]));
496+
497+ // Reduce within each logical warp of 32 threads independently
498+ local_abs_max = WarpReduce (warp_reduce[block_id]).Reduce (local_abs_max, hipcub::Max ());
499+
500+ if (local_thread_id == 0 ) {
501+ if (block_valid) {
502+ smem_absmax_value[block_id] = 1 .0f / local_abs_max;
503+ absmax[blockIdx.x * 2 + block_id] = local_abs_max;
504+ } else {
505+ smem_absmax_value[block_id] = 0 .0f ;
506+ }
507+ }
508+ __syncthreads ();
509+
510+ local_abs_max = smem_absmax_value[block_id];
511+
512+ switch (DATA_TYPE) {
513+ case FP4:
514+ #pragma unroll NUM_PER_TH
515+ for (int j = 0 ; j < NUM_PER_TH / 2 ; j++) {
516+ qvals[j] = dQuantizeFP4 (((float )vals[2 * j]) * local_abs_max) << 4 ;
517+ qvals[j] |= dQuantizeFP4 (((float )vals[2 * j + 1 ]) * local_abs_max);
518+ }
519+ break ;
520+ case NF4:
521+ #pragma unroll NUM_PER_TH
522+ for (int j = 0 ; j < NUM_PER_TH / 2 ; j++) {
523+ qvals[j] = dQuantizeNF4 (((float )vals[2 * j]) * local_abs_max) << 4 ;
524+ qvals[j] |= dQuantizeNF4 (((float )vals[2 * j + 1 ]) * local_abs_max);
525+ }
526+ break ;
527+ }
528+
529+ // All 64 threads participate in the store (valid_items limits the actual writes)
530+ __syncthreads ();
531+ StoreChar (storec).Store (&(out[base_idx / 2 ]), qvals, min ((BLOCK_SIZE * 2 + 1 ) / 2 , (n - base_idx + 1 ) / 2 ));
532+ }
533+
446534template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
447535__global__ void kDequantizeBlockwise (float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n)
448536{
@@ -2566,6 +2654,20 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
25662654 MAKE_kQuantizeBlockwise (hip_bfloat16, 64 , 2 , 0 , NF4)
25672655#endif
25682656
2657+ // Specialized blocksize=64 4-bit quantization kernel instantiations for ROCm
2658+ #define MAKE_kQuantizeBlockwise64 (dtype, data_type_name ) \
2659+ template __global__ void kQuantizeBlockwise64 <dtype, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2660+
2661+ // FP4 instantiations
2662+ MAKE_kQuantizeBlockwise64 (half, FP4)
2663+ MAKE_kQuantizeBlockwise64(float , FP4)
2664+ MAKE_kQuantizeBlockwise64(hip_bfloat16, FP4)
2665+
2666+ // NF4 instantiations
2667+ MAKE_kQuantizeBlockwise64(half, NF4)
2668+ MAKE_kQuantizeBlockwise64(float , NF4)
2669+ MAKE_kQuantizeBlockwise64(hip_bfloat16, NF4)
2670+
25692671template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
25702672template __global__ void kDequantizeBlockwise <half, 512 , 64 , 8 , General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
25712673template __global__ void kDequantizeBlockwise <half, 512 , 64 , 8 , NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
0 commit comments