@@ -656,23 +656,22 @@ template void percentileClipping(half* g, float* gnorm_vec, int step, const int
656656// ---- Device helpers ----
657657
658658__device__ __forceinline__ float warp_reduce_absmax_kbit (float val) {
659- #pragma unroll
659+ #pragma unroll
660660 for (int offset = 16 ; offset > 0 ; offset >>= 1 )
661661 val = fmaxf (val, __shfl_down_sync (0xFFFFFFFF , val, offset));
662662 return __shfl_sync (0xFFFFFFFF , val, 0 );
663663}
664664
665- template <int K>
666- __device__ __forceinline__ void pack_kbit_warp (unsigned char qval, unsigned int * packed_words) {
667- #pragma unroll
665+ template <int K> __device__ __forceinline__ void pack_kbit_warp (unsigned char qval, unsigned int * packed_words) {
666+ #pragma unroll
668667 for (int bit = 0 ; bit < K; bit++)
669668 packed_words[bit] = __ballot_sync (0xFFFFFFFF , (qval >> bit) & 1 );
670669}
671670
672671template <int K>
673672__device__ __forceinline__ unsigned char unpack_kbit_warp (const unsigned int * packed_words, int lane_id) {
674673 unsigned char val = 0 ;
675- #pragma unroll
674+ #pragma unroll
676675 for (int bit = 0 ; bit < K; bit++)
677676 val |= ((packed_words[bit] >> lane_id) & 1 ) << bit;
678677 return val;
@@ -682,25 +681,24 @@ __device__ __forceinline__ unsigned char unpack_kbit_warp(const unsigned int* pa
682681
683682template <typename T, int K>
684683__global__ void kQuantizeBlockwise_kbit (
685- const float * __restrict__ codebook,
686- const T* __restrict__ A,
687- float * __restrict__ absmax,
688- unsigned int * __restrict__ packed_out,
689- const int n
684+ const float * __restrict__ codebook, const T* __restrict__ A, float * __restrict__ absmax,
685+ unsigned int * __restrict__ packed_out, const int n
690686) {
691687 const int warp_id = (blockIdx .x * blockDim .x + threadIdx .x ) / 32 ;
692688 const int lane_id = threadIdx .x % 32 ;
693689 const int block_start = warp_id * 32 ;
694- if (block_start >= n) return ;
690+ if (block_start >= n)
691+ return ;
695692 float val = (block_start + lane_id < n) ? (float )A[block_start + lane_id] : 0 .0f ;
696693 float amax = warp_reduce_absmax_kbit (fabsf (val));
697694 float amax_safe = fmaxf (amax, 1e-8f );
698- if (lane_id == 0 ) absmax[warp_id] = amax;
695+ if (lane_id == 0 )
696+ absmax[warp_id] = amax;
699697 float normalized = val / amax_safe;
700698 float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0 .0f ;
701699 unsigned char best_idx = 0 ;
702700 float best_dist = 1e10f;
703- #pragma unroll
701+ #pragma unroll
704702 for (int i = 0 ; i < (1 << K); i++) {
705703 float cb_val = __shfl_sync (0xFFFFFFFF , cb, i);
706704 float dist = fabsf (normalized - cb_val);
@@ -722,7 +720,8 @@ __global__ void kQuantizeBlockwise_kbit(
722720constexpr int E4M4_BIAS = 11 ;
723721
724722__device__ __forceinline__ float decode_e4m4_absmax (unsigned char raw) {
725- if (raw == 0 ) return 0 .0f ;
723+ if (raw == 0 )
724+ return 0 .0f ;
726725 int e = raw >> 4 ;
727726 int m = raw & 0xF ;
728727 if (e == 0 ) {
@@ -738,13 +737,11 @@ __device__ __forceinline__ float decode_e4m4_absmax(unsigned char raw) {
738737
739738// Template helper: convert ABSMAX_T to float.
740739// Specialization for unsigned char uses E4M4 decode.
741- template <typename ABSMAX_T>
742- __device__ __forceinline__ float load_absmax (const ABSMAX_T* absmax, int idx) {
740+ template <typename ABSMAX_T> __device__ __forceinline__ float load_absmax (const ABSMAX_T* absmax, int idx) {
743741 return (float )absmax[idx];
744742}
745743
746- template <>
747- __device__ __forceinline__ float load_absmax<unsigned char >(const unsigned char * absmax, int idx) {
744+ template <> __device__ __forceinline__ float load_absmax<unsigned char >(const unsigned char * absmax, int idx) {
748745 return decode_e4m4_absmax (absmax[idx]);
749746}
750747
@@ -755,30 +752,29 @@ __device__ __forceinline__ float load_absmax<unsigned char>(const unsigned char*
755752// Templated on T (output type) and ABSMAX_T (absmax format).
756753template <typename T, int K, int BLOCKS_PER_WARP, typename ABSMAX_T>
757754__global__ void kDequantizeBlockwise_kbit_vec (
758- const unsigned int * __restrict__ packed_in,
759- const float * __restrict__ codebook,
760- const ABSMAX_T* __restrict__ absmax,
761- T* __restrict__ out,
762- const int n
755+ const unsigned int * __restrict__ packed_in, const float * __restrict__ codebook, const ABSMAX_T* __restrict__ absmax,
756+ T* __restrict__ out, const int n
763757) {
764758 const int warp_id = (blockIdx .x * blockDim .x + threadIdx .x ) / 32 ;
765759 const int lane_id = threadIdx .x % 32 ;
766760 const int base_block = warp_id * BLOCKS_PER_WARP;
767761
768- if (base_block * 32 >= n) return ;
762+ if (base_block * 32 >= n)
763+ return ;
769764
770765 // Load codebook into lane registers (one-time, amortized across BLOCKS_PER_WARP blocks)
771766 float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0 .0f ;
772767
773- #pragma unroll
768+ #pragma unroll
774769 for (int b = 0 ; b < BLOCKS_PER_WARP; b++) {
775770 const int block_id = base_block + b;
776771 const int block_start = block_id * 32 ;
777- if (block_start >= n) break ;
772+ if (block_start >= n)
773+ break ;
778774
779775 float amax = load_absmax (absmax, block_id);
780776 unsigned int packed[K];
781- #pragma unroll
777+ #pragma unroll
782778 for (int bit = 0 ; bit < K; bit++) {
783779 unsigned int word = (lane_id == bit) ? packed_in[block_id * K + bit] : 0 ;
784780 packed[bit] = __shfl_sync (0xFFFFFFFF , word, bit);
@@ -794,14 +790,12 @@ __global__ void kDequantizeBlockwise_kbit_vec(
794790// ---- Launch wrappers ----
795791
796792#define KBIT_WARPS_PER_BLOCK 8
797- #define KBIT_THREADS_PER_BLOCK (KBIT_WARPS_PER_BLOCK * 32 ) // 256
793+ #define KBIT_THREADS_PER_BLOCK (KBIT_WARPS_PER_BLOCK * 32 ) // 256
798794
799795// ---- Production kernel launchers (Stage 4-5) ----
800796
801797template <typename T, int K>
802- void quantizeBlockwise_kbit (
803- const float * codebook, const T* A, float * absmax, unsigned int * packed_out, int n
804- ) {
798+ void quantizeBlockwise_kbit (const float * codebook, const T* A, float * absmax, unsigned int * packed_out, int n) {
805799 int num_blocks_quant = (n + 31 ) / 32 ;
806800 int num_cuda_blocks = (num_blocks_quant + KBIT_WARPS_PER_BLOCK - 1 ) / KBIT_WARPS_PER_BLOCK;
807801 kQuantizeBlockwise_kbit <T, K><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK>>> (codebook, A, absmax, packed_out, n);
@@ -811,23 +805,21 @@ void quantizeBlockwise_kbit(
811805// Generic dequant launcher: supports all output types and absmax formats.
812806template <typename T, int K, typename ABSMAX_T>
813807void dequantizeBlockwise_kbit (
814- const unsigned int * packed_in, const float * codebook, const ABSMAX_T* absmax,
815- T* out, int n, cudaStream_t stream
808+ const unsigned int * packed_in, const float * codebook, const ABSMAX_T* absmax, T* out, int n, cudaStream_t stream
816809) {
817- constexpr int BPW = 4 ; // blocks per warp
810+ constexpr int BPW = 4 ; // blocks per warp
818811 int num_blocks_quant = (n + 31 ) / 32 ;
819812 int num_warps = (num_blocks_quant + BPW - 1 ) / BPW;
820813 int num_cuda_blocks = (num_warps + KBIT_WARPS_PER_BLOCK - 1 ) / KBIT_WARPS_PER_BLOCK;
821- kDequantizeBlockwise_kbit_vec <T, K, BPW, ABSMAX_T><<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0 , stream>>> (
822- packed_in, codebook, absmax, out, n);
814+ kDequantizeBlockwise_kbit_vec <T, K, BPW, ABSMAX_T>
815+ <<<num_cuda_blocks, KBIT_THREADS_PER_BLOCK, 0 , stream>>> ( packed_in, codebook, absmax, out, n);
823816 CUDA_CHECK_RETURN (cudaPeekAtLastError ());
824817}
825818
826819// ---- Template instantiations ----
827820
828- #define INSTANTIATE_KBIT_QUANT (T, K ) \
829- template void quantizeBlockwise_kbit<T, K>( \
830- const float *, const T*, float *, unsigned int *, int );
821+ #define INSTANTIATE_KBIT_QUANT (T, K ) \
822+ template void quantizeBlockwise_kbit<T, K>(const float *, const T*, float *, unsigned int *, int );
831823
832824INSTANTIATE_KBIT_QUANT (half, 2 )
833825INSTANTIATE_KBIT_QUANT(half, 3 )
@@ -843,9 +835,10 @@ INSTANTIATE_KBIT_QUANT(float, 4)
843835INSTANTIATE_KBIT_QUANT(float , 5 )
844836
845837// Dequant instantiations: all output types × absmax types × K values
846- #define INSTANTIATE_KBIT_DEQUANT (T, K, ABSMAX_T ) \
847- template void dequantizeBlockwise_kbit<T, K, ABSMAX_T>( \
848- const unsigned int *, const float *, const ABSMAX_T*, T*, int , cudaStream_t);
838+ #define INSTANTIATE_KBIT_DEQUANT (T, K, ABSMAX_T ) \
839+ template void dequantizeBlockwise_kbit<T, K, ABSMAX_T>( \
840+ const unsigned int *, const float *, const ABSMAX_T*, T*, int , cudaStream_t \
841+ );
849842
850843// uint8 E4M4 absmax (default)
851844INSTANTIATE_KBIT_DEQUANT (half, 2 , unsigned char )
0 commit comments