Skip to content

Commit dd6b88c

Browse files
TimDettmers and claude committed
feat: Add Hadamard rotation and fused rotate+quantize NVFP4 kernels
- kHadamardRotate16: block-diagonal 16x16 Hadamard via FWHT (4 butterfly stages with warp shuffles), normalized by 1/sqrt(16)
- kFusedHadamardQuantizeNVFP4: single-kernel Had16 rotation + NVFP4 quantization (rotation, block scale computation, E2M1 encoding, packing)
- Host launchers, template instantiations, extern "C" symbols for all

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4d8db39 commit dd6b88c

File tree

5 files changed

+241
-0
lines changed

5 files changed

+241
-0
lines changed

csrc/kernels.cu

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,110 @@ __global__ void kDequantizeNVFP4(
320320
output[element_idx + 1] = (T)val1;
321321
}
322322

323+
// ============================================================================
// Block-diagonal Hadamard rotation kernel (Had16)
// Applies a 16x16 normalized Hadamard transform to each consecutive
// 16-element chunk using the Fast Walsh-Hadamard Transform (FWHT).
// 4 butterfly stages: stride 8, 4, 2, 1. Normalization by 1/4 = 1/sqrt(16).
// In-place operation on FP16/BF16/FP32 tensors.
//
// Launch: 1D grid; blockDim.x must be a multiple of 32 so a 16-element
// Hadamard block never straddles a warp boundary.
// NOTE(review): n is presumably a multiple of 16; a partial tail block is
// transformed against zero padding (defined behavior, but not a true
// 16-point Hadamard of the stored data) — confirm with callers.
// ============================================================================
template <typename T>
__global__ void kHadamardRotate16(
    T* __restrict__ data,
    const int n
) {
    // Each thread handles one element; 16 threads form one Hadamard block.
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // All 32 lanes of the warp must reach every __shfl_xor_sync below,
    // because the participation mask is the full 0xFFFFFFFF. An early
    // `if (tid >= n) return;` would let tail lanes exit first, which is
    // undefined behavior. Instead, out-of-range lanes carry a neutral
    // 0.0f through the butterfly and simply skip the final store.
    const bool active = (tid < n);
    float val = active ? (float)data[tid] : 0.0f;

    // Fast Walsh-Hadamard Transform: 4 butterfly stages (stride 8, 4, 2, 1).
    // Lanes within the same 16-element group exchange values via shuffles.
#pragma unroll
    for (int stride = 8; stride >= 1; stride >>= 1) {
        float other = __shfl_xor_sync(0xFFFFFFFF, val, stride);
        // Butterfly: lower lane of each pair adds, upper lane subtracts.
        int bit = tid & stride;
        val = bit ? (other - val) : (val + other);
    }

    // Normalize by 1/sqrt(16) = 0.25 to make the transform orthogonal.
    val *= 0.25f;

    if (active) data[tid] = (T)val;
}
358+
359+
// ============================================================================
// Fused Hadamard rotation + NVFP4 quantization kernel
// Combines Had16 rotation with two-level NVFP4 quantization in a single
// kernel: rotation, per-16-element E4M3 block scale, E2M1 encoding, packing.
//
// Launch: 1D grid; blockDim.x must be a multiple of 32 (warp shuffles).
// Output sizes: output = ceil(n/2) bytes, block_scales = ceil(n/16) bytes.
// NOTE(review): n is presumably a multiple of 16; a partial tail block is
// processed against zero padding — confirm with callers.
// ============================================================================
template <typename T>
__global__ void kFusedHadamardQuantizeNVFP4(
    const T* __restrict__ input,
    unsigned char* __restrict__ output,       // Packed FP4: n/2 bytes
    unsigned char* __restrict__ block_scales, // E4M3 scales: n/16 bytes
    const float tensor_scale,
    const int n
) {
    // Each thread handles 1 element for the Hadamard transform,
    // then pairs of threads pack 2 elements into 1 byte.
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // All 32 lanes must stay alive through every __shfl_xor_sync below
    // (full 0xFFFFFFFF mask): an early return for tid >= n would be
    // undefined behavior and would feed garbage into the tail block's
    // butterfly/absmax/packing exchanges. Out-of-range lanes carry a
    // neutral 0.0f and skip all stores.
    const bool active = (tid < n);
    float val = active ? (float)input[tid] : 0.0f;

    // Apply Hadamard rotation (FWHT, 4 butterfly stages: stride 8, 4, 2, 1).
#pragma unroll
    for (int stride = 8; stride >= 1; stride >>= 1) {
        float other = __shfl_xor_sync(0xFFFFFFFF, val, stride);
        int bit = tid & stride;
        val = bit ? (other - val) : (val + other);
    }
    val *= 0.25f; // Normalize by 1/sqrt(16)

    // First quantization level: divide by the global tensor scale.
    float inv_tensor_scale = (tensor_scale > 0.0f) ? (1.0f / tensor_scale) : 0.0f;
    float scaled_val = val * inv_tensor_scale;

    // Block absmax via a shuffle butterfly over each 16-lane group.
    float local_max = fabsf(scaled_val);
#pragma unroll
    for (int offset = 8; offset >= 1; offset >>= 1) {
        float other = __shfl_xor_sync(0xFFFFFFFF, local_max, offset);
        local_max = fmaxf(local_max, other);
    }

    // Second level: block scale = absmax / 6.0 (6.0 being the E2M1 maximum
    // magnitude), encoded to E4M3 and decoded back so quantization uses
    // exactly the value the dequantizer will reconstruct.
    float block_scale_f32 = local_max / 6.0f;
    unsigned char block_scale_e4m3 = dFloatToE4M3(block_scale_f32);
    float block_scale_deq = dE4M3ToFloat(block_scale_e4m3);
    float inv_block_scale = (block_scale_deq > 0.0f) ? (1.0f / block_scale_deq) : 0.0f;

    // First in-range thread of each 16-lane group stores the block scale.
    int lane_in_block = tid & 15;
    if (active && lane_in_block == 0) {
        block_scales[tid / 16] = block_scale_e4m3;
    }

    // Quantize to E2M1 (4 bits).
    unsigned char q = dQuantizeNVFP4(scaled_val * inv_block_scale);

    // Pack pairs of values: even thread packs itself as the low nibble and
    // its odd partner as the high nibble. The shuffle itself involves every
    // lane; only in-range even threads write.
    unsigned char partner_q = __shfl_xor_sync(0xFFFFFFFF, q, 1);

    if (active && (tid & 1) == 0) {
        unsigned char packed = ((partner_q & 0x0F) << 4) | (q & 0x0F);
        output[tid / 2] = packed;
    }
}
426+
323427
__device__ unsigned char dQuantizeNF4(float x) {
324428

325429
// the values for this tree was generated by test_normal_map_tree
@@ -2792,6 +2896,25 @@ template __global__ void kDequantizeNVFP4<float>(
27922896
const float tensor_scale, float* __restrict__ output, const int n
27932897
);
27942898

2899+
// Hadamard rotation kernel instantiations (FP16 / BF16 / FP32)
template __global__ void kHadamardRotate16<half>(half* __restrict__ data, const int n);
template __global__ void kHadamardRotate16<__nv_bfloat16>(__nv_bfloat16* __restrict__ data, const int n);
template __global__ void kHadamardRotate16<float>(float* __restrict__ data, const int n);

// Fused Hadamard + NVFP4 quantize kernel instantiations (FP16 / BF16 / FP32)
template __global__ void kFusedHadamardQuantizeNVFP4<half>(
    const half* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
template __global__ void kFusedHadamardQuantizeNVFP4<__nv_bfloat16>(
    const __nv_bfloat16* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
template __global__ void kFusedHadamardQuantizeNVFP4<float>(
    const float* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
2917+
27952918
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
27962919
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>( \
27972920
gtype * p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, \

csrc/kernels.cuh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@ __global__ void kDequantizeNVFP4(
3737
const float tensor_scale, T* __restrict__ output, const int n
3838
);
3939

40+
// In-place 16x16 block-diagonal Hadamard rotation (FWHT over each
// consecutive 16-element chunk of `data`, n elements total).
template <typename T>
__global__ void kHadamardRotate16(T* __restrict__ data, const int n);

// Fused Had16 rotation + two-level NVFP4 quantization: writes packed E2M1
// nibbles to `output` and per-16-element E4M3 scales to `block_scales`.
template <typename T>
__global__ void kFusedHadamardQuantizeNVFP4(
    const T* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
48+
4049
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
4150
__global__ void kPreconditionOptimizer32bit2State(
4251
T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,

csrc/ops.cu

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,44 @@ template void dequantizeNVFP4<float>(
142142
float tensor_scale, float* output, const int n, cudaStream_t stream
143143
);
144144

145+
// Host launcher for the in-place Had16 rotation over n elements.
// Uses 256 threads per block (a multiple of 32, as the kernel's warp
// shuffles require) on the default stream.
template <typename T>
void hadamardRotate16(T* data, const int n) {
    // A zero-block launch (<<<0, ...>>>) fails with
    // cudaErrorInvalidConfiguration, so treat empty input as a no-op.
    if (n <= 0) return;
    const int threads_per_block = 256;
    const int num_blocks = (n + threads_per_block - 1) / threads_per_block;
    kHadamardRotate16<T><<<num_blocks, threads_per_block>>>(data, n);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
152+
153+
// Host launcher for the fused Had16 rotation + NVFP4 quantization.
// `output` must hold ceil(n/2) bytes and `block_scales` ceil(n/16) bytes.
// Uses 256 threads per block on the default stream.
template <typename T>
void fusedHadamardQuantizeNVFP4(
    const T* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
) {
    // A zero-block launch (<<<0, ...>>>) fails with
    // cudaErrorInvalidConfiguration, so treat empty input as a no-op.
    if (n <= 0) return;
    const int threads_per_block = 256;
    const int num_blocks = (n + threads_per_block - 1) / threads_per_block;
    kFusedHadamardQuantizeNVFP4<T><<<num_blocks, threads_per_block>>>(
        input, output, block_scales, tensor_scale, n
    );
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
165+
166+
// Hadamard and fused launcher instantiations (FP16 / BF16 / FP32)
template void hadamardRotate16<half>(half* data, const int n);
template void hadamardRotate16<__nv_bfloat16>(__nv_bfloat16* data, const int n);
template void hadamardRotate16<float>(float* data, const int n);
template void fusedHadamardQuantizeNVFP4<half>(
    const half* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
);
template void fusedHadamardQuantizeNVFP4<__nv_bfloat16>(
    const __nv_bfloat16* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
);
template void fusedHadamardQuantizeNVFP4<float>(
    const float* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
);
182+
145183
template <typename T, int OPTIMIZER>
146184
void optimizer32bit(
147185
T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, const float beta1,

csrc/ops.cuh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,15 @@ void dequantizeNVFP4(
131131
float tensor_scale, T* output, const int n, cudaStream_t stream
132132
);
133133

134+
// Host launcher: in-place Had16 rotation of n elements (default stream).
template <typename T>
void hadamardRotate16(T* data, const int n);

// Host launcher: fused Had16 rotation + NVFP4 quantization; `output` holds
// packed E2M1 nibbles, `block_scales` per-16-element E4M3 scales.
template <typename T>
void fusedHadamardQuantizeNVFP4(
    const T* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
);
142+
134143
template <typename T, int OPTIMIZER>
135144
void optimizer32bit(
136145
T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2,

csrc/pythonInterface.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,37 @@ void quantizeNVFP4_fp32(
224224
quantizeNVFP4<float>(input, output, block_scales, tensor_scale, n);
225225
}
226226

227+
// Hadamard rotation wrapper functions: non-template, per-dtype entry points
// that forward to the templated host launchers.
void hadamardRotate16_fp16(half* data, const int n) {
    hadamardRotate16<half>(data, n);
}
void hadamardRotate16_bf16(__nv_bfloat16* data, const int n) {
    hadamardRotate16<__nv_bfloat16>(data, n);
}
void hadamardRotate16_fp32(float* data, const int n) {
    hadamardRotate16<float>(data, n);
}

// Fused Hadamard + NVFP4 quantize wrapper functions (per dtype).
void fusedHadamardQuantizeNVFP4_fp16(
    const half* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
) {
    fusedHadamardQuantizeNVFP4<half>(input, output, block_scales, tensor_scale, n);
}
void fusedHadamardQuantizeNVFP4_bf16(
    const __nv_bfloat16* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
) {
    fusedHadamardQuantizeNVFP4<__nv_bfloat16>(input, output, block_scales, tensor_scale, n);
}
void fusedHadamardQuantizeNVFP4_fp32(
    const float* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
) {
    fusedHadamardQuantizeNVFP4<float>(input, output, block_scales, tensor_scale, n);
}
257+
227258
// NVFP4 dequantize wrapper functions
228259
void dequantizeNVFP4_fp16(
229260
const unsigned char* input, const unsigned char* block_scales,
@@ -532,6 +563,37 @@ void cdequantize_blockwise_bf16_nf4(
532563
dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream);
533564
}
534565

566+
// Hadamard rotation extern "C" wrappers: stable, unmangled symbols for the
// Python ctypes interface. NOTE(review): presumably these sit inside an
// extern "C" block elsewhere in this file — the block itself is not visible
// here; confirm before relying on the C linkage.
void chadamard_rotate16_fp16(half* data, const int n) {
    hadamardRotate16_fp16(data, n);
}
void chadamard_rotate16_bf16(__nv_bfloat16* data, const int n) {
    hadamardRotate16_bf16(data, n);
}
void chadamard_rotate16_fp32(float* data, const int n) {
    hadamardRotate16_fp32(data, n);
}

// Fused Hadamard + NVFP4 quantize extern "C" wrappers (per dtype).
void cfused_hadamard_quantize_nvfp4_fp16(
    const half* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
) {
    fusedHadamardQuantizeNVFP4_fp16(input, output, block_scales, tensor_scale, n);
}
void cfused_hadamard_quantize_nvfp4_bf16(
    const __nv_bfloat16* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
) {
    fusedHadamardQuantizeNVFP4_bf16(input, output, block_scales, tensor_scale, n);
}
void cfused_hadamard_quantize_nvfp4_fp32(
    const float* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
) {
    fusedHadamardQuantizeNVFP4_fp32(input, output, block_scales, tensor_scale, n);
}
596+
535597
// NVFP4 quantize extern "C" wrappers
536598
void cquantize_nvfp4_fp16(
537599
const half* input, unsigned char* output, unsigned char* block_scales,

0 commit comments

Comments
 (0)