@@ -320,6 +320,110 @@ __global__ void kDequantizeNVFP4(
320320 output[element_idx + 1 ] = (T)val1;
321321}
322322
// ============================================================================
// Block-diagonal Hadamard rotation kernel (Had16)
// Applies a 16x16 normalized Hadamard transform to each consecutive
// 16-element chunk of `data`, in place, via the Fast Walsh-Hadamard
// Transform (FWHT): 4 butterfly stages (stride 8, 4, 2, 1), then
// normalization by 1/sqrt(16) = 0.25 so the transform is orthonormal.
// Works on FP16/BF16/FP32 tensors; arithmetic is done in float.
//
// Launch requirements:
//   - blockDim.x must be a multiple of 32 so every 16-element Hadamard
//     group lies entirely inside one warp (shuffles never cross warps).
//   - n should be a multiple of 16; lanes past n participate in the
//     shuffles with 0.0f but never store.
// requires SM70+ (*_sync warp shuffles)
// ============================================================================
template <typename T>
__global__ void kHadamardRotate16(
    T* __restrict__ data,
    const int n
) {
    // One thread per element; 16 consecutive lanes form one Hadamard block.
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // NOTE: no early return. Every lane of the warp must reach each
    // __shfl_xor_sync below, because the participation mask names the full
    // warp; letting tail lanes exit first is undefined behavior on Volta+.
    // Out-of-range lanes carry 0.0f and skip the final store instead.
    const bool in_bounds = (tid < n);
    float val = in_bounds ? (float)data[tid] : 0.0f;

    // Fast Walsh-Hadamard Transform: 4 butterfly stages within the warp.
    // (tid & stride) selects the subtract leg of each butterfly.
#pragma unroll
    for (int stride = 8; stride >= 1; stride >>= 1) {
        float other = __shfl_xor_sync(0xFFFFFFFF, val, stride);
        val = (tid & stride) ? (other - val) : (val + other);
    }

    // Normalize by 1/sqrt(16) = 0.25 to keep the transform orthonormal.
    val *= 0.25f;

    if (in_bounds)
        data[tid] = (T)val;
}
358+
// ============================================================================
// Fused Hadamard rotation + NVFP4 quantization kernel
// Applies the 16-point FWHT rotation (see kHadamardRotate16) and two-level
// NVFP4 quantization (per-16-element E4M3 block scale + E2M1 element codes)
// in a single pass over the input.
//
// Outputs:
//   output       — packed FP4 codes, 2 per byte (even lane = low nibble,
//                  odd lane = high nibble), n/2 bytes
//   block_scales — one E4M3 scale per 16-element block, n/16 bytes
//
// Launch requirements:
//   - blockDim.x must be a multiple of 32 (warp shuffles must not cross a
//     warp boundary).
//   - n should be a multiple of 16; lanes past n participate in the
//     shuffles with 0.0f but never store.
// requires SM70+ (*_sync warp shuffles)
// ============================================================================
template <typename T>
__global__ void kFusedHadamardQuantizeNVFP4(
    const T* __restrict__ input,
    unsigned char* __restrict__ output,        // Packed FP4: n/2 bytes
    unsigned char* __restrict__ block_scales,  // E4M3 scales: n/16 bytes
    const float tensor_scale,
    const int n
) {
    // One thread per element for the transform; pairs of threads then pack
    // two 4-bit codes into one output byte.
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // NOTE: no early return. Every lane of the warp must reach each
    // __shfl_xor_sync below, because the participation mask names the full
    // warp; letting tail lanes exit first is undefined behavior on Volta+.
    // Out-of-range lanes carry 0.0f (which also leaves the block absmax
    // unaffected) and skip all stores.
    const bool in_bounds = (tid < n);
    float val = in_bounds ? (float)input[tid] : 0.0f;

    // Hadamard rotation: FWHT, 4 butterfly stages within the warp.
#pragma unroll
    for (int stride = 8; stride >= 1; stride >>= 1) {
        float other = __shfl_xor_sync(0xFFFFFFFF, val, stride);
        val = (tid & stride) ? (other - val) : (val + other);
    }
    val *= 0.25f;  // 1/sqrt(16): keep the transform orthonormal

    // First-level scaling: divide by the global tensor scale (guard /0).
    float inv_tensor_scale = (tensor_scale > 0.0f) ? (1.0f / tensor_scale) : 0.0f;
    float scaled_val = val * inv_tensor_scale;

    // Block absmax over the 16 lanes of this Hadamard block.
    float local_max = fabsf(scaled_val);
#pragma unroll
    for (int offset = 8; offset >= 1; offset >>= 1) {
        float other = __shfl_xor_sync(0xFFFFFFFF, local_max, offset);
        local_max = fmaxf(local_max, other);
    }

    // Second-level scale: absmax / 6.0 (6.0 is the largest E2M1 magnitude),
    // round-tripped through E4M3 so encode and decode agree exactly.
    float block_scale_f32 = local_max / 6.0f;
    unsigned char block_scale_e4m3 = dFloatToE4M3(block_scale_f32);
    float block_scale_deq = dE4M3ToFloat(block_scale_e4m3);
    float inv_block_scale = (block_scale_deq > 0.0f) ? (1.0f / block_scale_deq) : 0.0f;

    // First lane of each 16-lane group owns the block-scale store.
    if (in_bounds && (tid & 15) == 0)
        block_scales[tid / 16] = block_scale_e4m3;

    // Quantize to E2M1.
    unsigned char q = dQuantizeNVFP4(scaled_val * inv_block_scale);

    // Pack pairs: fetch the partner lane's code, even lane writes the byte
    // (self in the low nibble, odd partner in the high nibble).
    unsigned char partner_q = __shfl_xor_sync(0xFFFFFFFF, q, 1);
    if (in_bounds && (tid & 1) == 0)
        output[tid / 2] = ((partner_q & 0x0F) << 4) | (q & 0x0F);
}
426+
323427__device__ unsigned char dQuantizeNF4 (float x) {
324428
325429 // the values for this tree was generated by test_normal_map_tree
@@ -2792,6 +2896,25 @@ template __global__ void kDequantizeNVFP4<float>(
27922896 const float tensor_scale, float * __restrict__ output, const int n
27932897);
27942898
// Explicit instantiations for the Hadamard rotation and fused
// Hadamard + NVFP4 quantization kernels (FP16 / BF16 / FP32), emitted
// through helper macros in the style used elsewhere in this file.
#define MAKE_kHadamardRotate16(gtype) \
    template __global__ void kHadamardRotate16<gtype>(gtype* __restrict__ data, const int n);

MAKE_kHadamardRotate16(half)
MAKE_kHadamardRotate16(__nv_bfloat16)
MAKE_kHadamardRotate16(float)

#define MAKE_kFusedHadamardQuantizeNVFP4(gtype) \
    template __global__ void kFusedHadamardQuantizeNVFP4<gtype>( \
        const gtype* __restrict__ input, unsigned char* __restrict__ output, \
        unsigned char* __restrict__ block_scales, const float tensor_scale, const int n \
    );

MAKE_kFusedHadamardQuantizeNVFP4(half)
MAKE_kFusedHadamardQuantizeNVFP4(__nv_bfloat16)
MAKE_kFusedHadamardQuantizeNVFP4(float)
2917+
27952918#define MAKE_OptimizerStatic8bit2StateBlockwise (oname, gtype, block_size, num_per_thread ) \
27962919 template __global__ void kOptimizerStatic8bit2StateBlockwise <gtype, oname, block_size, num_per_thread>( \
27972920 gtype * p, gtype* __restrict__ const g, unsigned char * state1, unsigned char * state2, const float beta1, \
0 commit comments