@@ -128,9 +128,7 @@ __device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_
128128// ============================================================================
129129
130130// E2M1 dequantization LUT - maps 3-bit unsigned magnitude code to float
131- __device__ static float nvfp4_dequant_lut[8 ] = {
132- 0 .0f , 0 .5f , 1 .0f , 1 .5f , 2 .0f , 3 .0f , 4 .0f , 6 .0f
133- };
131+ __device__ static float nvfp4_dequant_lut[8 ] = {0 .0f , 0 .5f , 1 .0f , 1 .5f , 2 .0f , 3 .0f , 4 .0f , 6 .0f };
134132
135133// Dequantize a 4-bit E2M1 code to float
136134// Bit layout: [sign(1) | exponent(2) | mantissa(1)]
@@ -170,8 +168,10 @@ __device__ unsigned char dQuantizeNVFP4(float x) {
170168// Convert positive float to unsigned E4M3 (8-bit: 4 exponent bits, bias=7, 3 mantissa bits)
171169// Range: [0, 448]. Used for NVFP4 block scale factors.
172170__device__ unsigned char dFloatToE4M3 (float x) {
173- if (x <= 0 .0f ) return 0 ;
174- if (x >= 448 .0f ) return 0x7E ; // Max normal (exp=14, mant=6). exp=15 mant=7 is NaN.
171+ if (x <= 0 .0f )
172+ return 0 ;
173+ if (x >= 448 .0f )
174+ return 0x7E ; // Max normal (exp=14, mant=6). exp=15 mant=7 is NaN.
175175
176176 unsigned int bits = __float_as_uint (x);
177177 int fp32_exp = ((bits >> 23 ) & 0xFF ) - 127 ; // Unbiased FP32 exponent
@@ -180,8 +180,10 @@ __device__ unsigned char dFloatToE4M3(float x) {
180180 if (e4m3_exp <= 0 ) {
181181 // Subnormal in E4M3: value = mantissa/8 * 2^(-6)
182182 int mant = __float2int_rn (x * 512 .0f ); // 512 = 8 * 2^6
183- if (mant <= 0 ) return 0 ;
184- if (mant > 7 ) mant = 7 ;
183+ if (mant <= 0 )
184+ return 0 ;
185+ if (mant > 7 )
186+ mant = 7 ;
185187 return (unsigned char )mant;
186188 }
187189
@@ -194,15 +196,18 @@ __device__ unsigned char dFloatToE4M3(float x) {
194196 e4m3_exp++;
195197 }
196198
197- if (e4m3_exp > 15 ) return 0x7E ;
198- if (e4m3_exp == 15 && mant_3bit >= 7 ) return 0x7E ; // Clamp, don't produce NaN
199+ if (e4m3_exp > 15 )
200+ return 0x7E ;
201+ if (e4m3_exp == 15 && mant_3bit >= 7 )
202+ return 0x7E ; // Clamp, don't produce NaN
199203
200204 return (unsigned char )((e4m3_exp << 3 ) | mant_3bit);
201205}
202206
203207// Convert unsigned E4M3 byte to float
204208__device__ float dE4M3ToFloat (unsigned char val) {
205- if (val == 0 ) return 0 .0f ;
209+ if (val == 0 )
210+ return 0 .0f ;
206211
207212 int exp = (val >> 3 ) & 0x0F ;
208213 int mant = val & 0x07 ;
@@ -225,17 +230,17 @@ __device__ float dE4M3ToFloat(unsigned char val) {
225230template <typename T>
226231__global__ void kQuantizeNVFP4 (
227232 const T* __restrict__ input,
228- unsigned char * __restrict__ output, // Packed FP4: n/2 bytes
233+ unsigned char * __restrict__ output, // Packed FP4: n/2 bytes
229234 unsigned char * __restrict__ block_scales, // E4M3 scales: n/16 bytes
230- const float tensor_scale,
231- const int n
235+ const float tensor_scale, const int n
232236) {
233237 // Each thread handles 2 consecutive elements (packs into 1 byte)
234238 // 8 threads per 16-element quantization block
235239 const int tid = blockIdx .x * blockDim .x + threadIdx .x ;
236240 const int element_idx = tid * 2 ;
237241
238- if (element_idx >= n) return ;
242+ if (element_idx >= n)
243+ return ;
239244
240245 const float inv_tensor_scale = (tensor_scale > 0 .0f ) ? (1 .0f / tensor_scale) : 0 .0f ;
241246
@@ -246,10 +251,10 @@ __global__ void kQuantizeNVFP4(
246251 // Compute per-thread absmax
247252 float local_max = fmaxf (fabsf (val0), fabsf (val1));
248253
249- // Warp-shuffle reduction within 8-thread quantization block
250- // Threads 0-7 handle block 0, 8-15 handle block 1, etc.
251- // XOR offsets 4, 2, 1 stay within each 8-thread group
252- #pragma unroll
254+ // Warp-shuffle reduction within 8-thread quantization block
255+ // Threads 0-7 handle block 0, 8-15 handle block 1, etc.
256+ // XOR offsets 4, 2, 1 stay within each 8-thread group
257+ #pragma unroll
253258 for (int offset = 4 ; offset >= 1 ; offset >>= 1 ) {
254259 float other = __shfl_xor_sync (0xFFFFFFFF , local_max, offset);
255260 local_max = fmaxf (local_max, other);
@@ -287,16 +292,15 @@ __global__ void kQuantizeNVFP4(
287292// ============================================================================
288293template <typename T>
289294__global__ void kDequantizeNVFP4 (
290- const unsigned char * __restrict__ input, // Packed FP4: n/2 bytes
295+ const unsigned char * __restrict__ input, // Packed FP4: n/2 bytes
291296 const unsigned char * __restrict__ block_scales, // E4M3 scales: n/16 bytes
292- const float tensor_scale,
293- T* __restrict__ output,
294- const int n
297+ const float tensor_scale, T* __restrict__ output, const int n
295298) {
296299 const int tid = blockIdx .x * blockDim .x + threadIdx .x ;
297300 const int element_idx = tid * 2 ;
298301
299- if (element_idx >= n) return ;
302+ if (element_idx >= n)
303+ return ;
300304
301305 // Load and unpack
302306 unsigned char packed = input[element_idx / 2 ];
@@ -327,22 +331,19 @@ __global__ void kDequantizeNVFP4(
327331// 4 butterfly stages: stride 8, 4, 2, 1. Normalization by 1/4 = 1/sqrt(16).
328332// In-place operation on FP16/BF16/FP32 tensors.
329333// ============================================================================
330- template <typename T>
331- __global__ void kHadamardRotate16 (
332- T* __restrict__ data,
333- const int n
334- ) {
334+ template <typename T> __global__ void kHadamardRotate16 (T* __restrict__ data, const int n) {
335335 // Each thread handles one element.
336336 // 16 threads form one Hadamard block.
337337 const int tid = blockIdx .x * blockDim .x + threadIdx .x ;
338- if (tid >= n) return ;
338+ if (tid >= n)
339+ return ;
339340
340341 float val = (float )data[tid];
341342
342- // Fast Walsh-Hadamard Transform: 4 butterfly stages
343- // Threads within the same 16-element group exchange via warp shuffles
344- // lane_in_block: position 0-15 within the 16-element Hadamard block
345- #pragma unroll
343+ // Fast Walsh-Hadamard Transform: 4 butterfly stages
344+ // Threads within the same 16-element group exchange via warp shuffles
345+ // lane_in_block: position 0-15 within the 16-element Hadamard block
346+ #pragma unroll
346347 for (int stride = 8 ; stride >= 1 ; stride >>= 1 ) {
347348 float other = __shfl_xor_sync (0xFFFFFFFF , val, stride);
348349 // Butterfly: if bit is 0, add; if bit is 1, subtract
@@ -365,20 +366,20 @@ template <typename T>
365366__global__ void kFusedHadamardQuantizeNVFP4 (
366367 const T* __restrict__ input,
367368 unsigned char * __restrict__ output, // Packed FP4: n/2 bytes
368- unsigned char * __restrict__ block_scales, // E4M3 scales: n/16 bytes
369- const float tensor_scale,
370- const int n
369+ unsigned char * __restrict__ block_scales, // E4M3 scales: n/16 bytes
370+ const float tensor_scale, const int n
371371) {
372372 // Each thread handles 1 element for the Hadamard transform,
373373 // then pairs of threads pack 2 elements into 1 byte.
374374 const int tid = blockIdx .x * blockDim .x + threadIdx .x ;
375- if (tid >= n) return ;
375+ if (tid >= n)
376+ return ;
376377
377378 // Load and convert to float
378379 float val = (float )input[tid];
379380
380- // Apply Hadamard rotation (FWHT, 4 butterfly stages)
381- #pragma unroll
381+ // Apply Hadamard rotation (FWHT, 4 butterfly stages)
382+ #pragma unroll
382383 for (int stride = 8 ; stride >= 1 ; stride >>= 1 ) {
383384 float other = __shfl_xor_sync (0xFFFFFFFF , val, stride);
384385 int bit = tid & stride;
@@ -392,7 +393,7 @@ __global__ void kFusedHadamardQuantizeNVFP4(
392393
393394 // Compute block absmax via warp shuffle (16 threads per Hadamard block)
394395 float local_max = fabsf (scaled_val);
395- #pragma unroll
396+ #pragma unroll
396397 for (int offset = 8 ; offset >= 1 ; offset >>= 1 ) {
397398 float other = __shfl_xor_sync (0xFFFFFFFF , local_max, offset);
398399 local_max = fmaxf (local_max, other);
@@ -2872,28 +2873,28 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(
28722873
28732874// NVFP4 kernel template instantiations
28742875template __global__ void kQuantizeNVFP4 <half>(
2875- const half* __restrict__ input, unsigned char * __restrict__ output,
2876- unsigned char * __restrict__ block_scales, const float tensor_scale, const int n
2876+ const half* __restrict__ input, unsigned char * __restrict__ output, unsigned char * __restrict__ block_scales,
2877+ const float tensor_scale, const int n
28772878);
28782879template __global__ void kQuantizeNVFP4 <__nv_bfloat16>(
28792880 const __nv_bfloat16* __restrict__ input, unsigned char * __restrict__ output,
28802881 unsigned char * __restrict__ block_scales, const float tensor_scale, const int n
28812882);
28822883template __global__ void kQuantizeNVFP4 <float >(
2883- const float * __restrict__ input, unsigned char * __restrict__ output,
2884- unsigned char * __restrict__ block_scales, const float tensor_scale, const int n
2884+ const float * __restrict__ input, unsigned char * __restrict__ output, unsigned char * __restrict__ block_scales,
2885+ const float tensor_scale, const int n
28852886);
28862887template __global__ void kDequantizeNVFP4 <half>(
2887- const unsigned char * __restrict__ input, const unsigned char * __restrict__ block_scales,
2888- const float tensor_scale, half* __restrict__ output, const int n
2888+ const unsigned char * __restrict__ input, const unsigned char * __restrict__ block_scales, const float tensor_scale,
2889+ half* __restrict__ output, const int n
28892890);
28902891template __global__ void kDequantizeNVFP4 <__nv_bfloat16>(
2891- const unsigned char * __restrict__ input, const unsigned char * __restrict__ block_scales,
2892- const float tensor_scale, __nv_bfloat16* __restrict__ output, const int n
2892+ const unsigned char * __restrict__ input, const unsigned char * __restrict__ block_scales, const float tensor_scale,
2893+ __nv_bfloat16* __restrict__ output, const int n
28932894);
28942895template __global__ void kDequantizeNVFP4 <float >(
2895- const unsigned char * __restrict__ input, const unsigned char * __restrict__ block_scales,
2896- const float tensor_scale, float * __restrict__ output, const int n
2896+ const unsigned char * __restrict__ input, const unsigned char * __restrict__ block_scales, const float tensor_scale,
2897+ float * __restrict__ output, const int n
28972898);
28982899
28992900// Hadamard rotation kernel instantiations
@@ -2903,16 +2904,16 @@ template __global__ void kHadamardRotate16<float>(float* __restrict__ data, cons
29032904
29042905// Fused Hadamard + NVFP4 quantize kernel instantiations
29052906template __global__ void kFusedHadamardQuantizeNVFP4 <half>(
2906- const half* __restrict__ input, unsigned char * __restrict__ output,
2907- unsigned char * __restrict__ block_scales, const float tensor_scale, const int n
2907+ const half* __restrict__ input, unsigned char * __restrict__ output, unsigned char * __restrict__ block_scales,
2908+ const float tensor_scale, const int n
29082909);
29092910template __global__ void kFusedHadamardQuantizeNVFP4 <__nv_bfloat16>(
29102911 const __nv_bfloat16* __restrict__ input, unsigned char * __restrict__ output,
29112912 unsigned char * __restrict__ block_scales, const float tensor_scale, const int n
29122913);
29132914template __global__ void kFusedHadamardQuantizeNVFP4 <float >(
2914- const float * __restrict__ input, unsigned char * __restrict__ output,
2915- unsigned char * __restrict__ block_scales, const float tensor_scale, const int n
2915+ const float * __restrict__ input, unsigned char * __restrict__ output, unsigned char * __restrict__ block_scales,
2916+ const float tensor_scale, const int n
29162917);
29172918
29182919#define MAKE_OptimizerStatic8bit2StateBlockwise (oname, gtype, block_size, num_per_thread ) \
0 commit comments