Skip to content

Commit 4d8db39

Browse files
TimDettmers and claude
committed
feat: Add NVFP4 (E2M1) quantize/dequantize CUDA kernels
Implements two-level block-scaled NVFP4 quantization: - E2M1 quantize/dequantize device functions with decision-tree and LUT - E4M3 float conversion helpers for block scale factors - kQuantizeNVFP4: FP16/BF16/FP32 -> packed FP4 + E4M3 block scales - kDequantizeNVFP4: packed FP4 + scales -> FP16/BF16/FP32 - Host launchers, template instantiations, extern C symbols - NVFP4=3 added to DataType_t enum Block size fixed at 16 (hardware requirement). Two-level scaling: FP32 tensor_scale + unsigned E4M3 per-block scale. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0ee29e5 commit 4d8db39

File tree

6 files changed

+389
-0
lines changed

6 files changed

+389
-0
lines changed

csrc/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ typedef enum DataType_t {
44
General8bit = 0,
55
FP4 = 1,
66
NF4 = 2,
7+
NVFP4 = 3,
78
} DataType_t;

csrc/kernels.cu

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,205 @@ __device__ unsigned char dQuantizeFP4(float x) {
121121

122122
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
123123

124+
// ============================================================================
// NVFP4 (E2M1) device functions
// E2M1 format: 1 sign + 2 exponent (bias=1) + 1 mantissa
// Representable magnitudes: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}
// ============================================================================

// E2M1 dequantization LUT - maps 3-bit unsigned magnitude code to float
__device__ static float nvfp4_dequant_lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

// Dequantize a 4-bit E2M1 code to float.
// Bit layout: [sign(1) | exponent(2) | mantissa(1)] — bit 3 carries the sign,
// bits 0-2 index the magnitude LUT above.
__device__ __forceinline__ float dDequantizeNVFP4(unsigned char val) {
    const float magnitude = nvfp4_dequant_lut[val & 0x07];
    return (val & 0x08) ? -magnitude : magnitude;
}
141+
142+
// Quantize a float to a 4-bit E2M1 code using round-to-nearest.
// Input should be pre-scaled so that the representable range [-6, 6] is
// appropriate. The magnitude code equals the number of decision boundaries
// (midpoints between adjacent representable values) strictly below |x|,
// which reproduces the original descending decision tree exactly.
__device__ unsigned char dQuantizeNVFP4(float x) {
    // Midpoints between adjacent E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}
    const float midpoints[7] = {0.25f, 0.75f, 1.25f, 1.75f, 2.5f, 3.5f, 5.0f};

    const unsigned char sign_bit = (x < 0.0f) ? 0x08 : 0x00;
    const float ax = fabsf(x);

    unsigned char code = 0;
#pragma unroll
    for (int i = 0; i < 7; ++i)
        code += (ax > midpoints[i]) ? 1 : 0;

    return (unsigned char)(code | sign_bit);
}
169+
170+
// Convert positive float to unsigned E4M3 (8-bit: 4 exponent bits, bias=7, 3 mantissa bits)
// Range: [0, 448]. Used for NVFP4 block scale factors.
// Rounds to nearest; values at or above the max normal clamp to 0x7E so the
// NaN encoding (exp field 15, mant 7) is never produced.
__device__ unsigned char dFloatToE4M3(float x) {
    if (x <= 0.0f) return 0;
    if (x >= 448.0f) return 0x7E; // Max normal (exp field=15, mant=6). exp=15 mant=7 is NaN.

    unsigned int bits = __float_as_uint(x);
    int fp32_exp = ((bits >> 23) & 0xFF) - 127; // Unbiased FP32 exponent
    int e4m3_exp = fp32_exp + 7;                // E4M3 bias is 7

    if (e4m3_exp <= 0) {
        // Subnormal in E4M3: value = mantissa/8 * 2^(-6) = mantissa / 512
        int mant = __float2int_rn(x * 512.0f); // 512 = 8 * 2^6
        if (mant <= 0) return 0;
        // BUGFIX: rounding can carry into the smallest normal. Inputs just
        // below 2^-6 round to mant == 8, which is exactly 2^-6 and must be
        // encoded as exp=1, mant=0 (0x08); the previous clamp to 7 returned
        // 7/512 and under-rounded those inputs.
        if (mant >= 8) return 0x08;
        return (unsigned char)mant;
    }

    // Normal: extract top 3 mantissa bits with round-to-nearest
    unsigned int fp32_mant = bits & 0x7FFFFF;
    unsigned int mant_3bit = (fp32_mant + (1 << 19)) >> 20;

    // Mantissa rounding may carry into the exponent
    if (mant_3bit >= 8) {
        mant_3bit = 0;
        e4m3_exp++;
    }

    if (e4m3_exp > 15) return 0x7E;
    if (e4m3_exp == 15 && mant_3bit >= 7) return 0x7E; // Clamp, don't produce NaN

    return (unsigned char)((e4m3_exp << 3) | mant_3bit);
}
202+
203+
// Convert unsigned E4M3 byte to float.
// Exp field 0 covers both zero (mant 0) and subnormals, so no separate
// zero check is needed.
__device__ float dE4M3ToFloat(unsigned char val) {
    const int exp_field = (val >> 3) & 0x0F;
    const int mant_field = val & 0x07;

    // Subnormal (and zero): value = mant/8 * 2^(1-7) = mant / 512
    if (exp_field == 0)
        return (float)mant_field * (1.0f / 512.0f);

    // Normal: value = (1 + mant/8) * 2^(exp-7)
    return (1.0f + 0.125f * (float)mant_field) * exp2f((float)(exp_field - 7));
}
218+
219+
// ============================================================================
// NVFP4 quantization kernel
// Two-level scaling: FP32 tensor_scale + E4M3 block_scale (per 16 elements)
// Input: T* tensor, float tensor_scale (precomputed)
// Output: packed uint8 (2 values per byte), uint8 block_scales (E4M3)
//
// Launch contract: 1D grid; blockDim.x must be a multiple of 8 so each
// aligned 8-thread group owns one 16-element quantization block; at least
// ceil(n/2) threads in total.
// ============================================================================
template <typename T>
__global__ void kQuantizeNVFP4(
    const T* __restrict__ input,
    unsigned char* __restrict__ output,       // Packed FP4: ceil(n/2) bytes
    unsigned char* __restrict__ block_scales, // E4M3 scales: ceil(n/16) bytes
    const float tensor_scale,
    const int n
) {
    // Each thread handles 2 consecutive elements (packs into 1 byte);
    // 8 threads per 16-element quantization block.
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int element_idx = tid * 2;
    const bool active = (element_idx < n);

    // BUGFIX: no early return here. Every lane of the warp must reach the
    // __shfl_xor_sync below, because the full 0xFFFFFFFF participant mask is
    // only valid when all named lanes execute the intrinsic. The previous
    // version returned early for out-of-range threads, which is undefined
    // behavior for tail warps where element_idx >= n on some lanes.

    const float inv_tensor_scale = (tensor_scale > 0.0f) ? (1.0f / tensor_scale) : 0.0f;

    // Load 2 elements, divide by tensor_scale; out-of-range lanes contribute 0
    float val0 = active ? (float)input[element_idx] * inv_tensor_scale : 0.0f;
    float val1 = (element_idx + 1 < n) ? (float)input[element_idx + 1] * inv_tensor_scale : 0.0f;

    // Compute per-thread absmax
    float local_max = fmaxf(fabsf(val0), fabsf(val1));

    // Warp-shuffle reduction within each 8-thread quantization block.
    // Threads 0-7 handle block 0, 8-15 handle block 1, etc.
    // XOR offsets 4, 2, 1 stay within each aligned 8-thread group.
#pragma unroll
    for (int offset = 4; offset >= 1; offset >>= 1) {
        float other = __shfl_xor_sync(0xFFFFFFFF, local_max, offset);
        local_max = fmaxf(local_max, other);
    }

    // Compute E4M3 block scale: block_absmax / 6.0 (E2M1 max magnitude)
    float block_scale_f32 = local_max / 6.0f;
    unsigned char block_scale_e4m3 = dFloatToE4M3(block_scale_f32);
    float block_scale_deq = dE4M3ToFloat(block_scale_e4m3);

    // Avoid division by zero for all-zero blocks
    float inv_block_scale = (block_scale_deq > 0.0f) ? (1.0f / block_scale_deq) : 0.0f;

    // First thread of each 8-thread group stores the block scale. Lane 0 of a
    // group has the smallest element_idx of the group, so if it is inactive
    // the whole group is out of range and no scale slot exists for it.
    const int lane_in_block = threadIdx.x & 7;
    if (active && lane_in_block == 0)
        block_scales[element_idx / 16] = block_scale_e4m3;

    // Safe to exit now: no further warp-collective operations follow.
    if (!active) return;

    // Quantize values to E2M1 and pack:
    // low nibble = first element, high nibble = second element
    unsigned char q0 = dQuantizeNVFP4(val0 * inv_block_scale);
    unsigned char q1 = dQuantizeNVFP4(val1 * inv_block_scale);
    output[element_idx / 2] = (unsigned char)(((q1 & 0x0F) << 4) | (q0 & 0x0F));
}
283+
284+
// ============================================================================
// NVFP4 dequantization kernel
// Reverses the two-level scaling: unpacks FP4, multiplies by block_scale * tensor_scale
// One thread per packed byte (2 output elements).
// ============================================================================
template <typename T>
__global__ void kDequantizeNVFP4(
    const unsigned char* __restrict__ input,        // Packed FP4: ceil(n/2) bytes
    const unsigned char* __restrict__ block_scales, // E4M3 scales: ceil(n/16) bytes
    const float tensor_scale,
    T* __restrict__ output,
    const int n
) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int element_idx = 2 * tid;
    if (element_idx >= n)
        return;

    // Combined per-element scale: E4M3 block scale times the tensor scale
    const float scale = dE4M3ToFloat(block_scales[element_idx / 16]) * tensor_scale;

    // Unpack the two 4-bit codes (low nibble = first element) and dequantize.
    // The first store needs no extra bounds check: element_idx < n holds here.
    const unsigned char packed = input[element_idx / 2];
    output[element_idx] = (T)(dDequantizeNVFP4(packed & 0x0F) * scale);
    if (element_idx + 1 < n)
        output[element_idx + 1] = (T)(dDequantizeNVFP4((packed >> 4) & 0x0F) * scale);
}
322+
124323
__device__ unsigned char dQuantizeNF4(float x) {
125324

126325
// the values for this tree was generated by test_normal_map_tree
@@ -2567,6 +2766,32 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(
25672766
float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n
25682767
);
25692768

2769+
// NVFP4 kernel template instantiations
// Explicit instantiations for the three element types the host launchers in
// ops.cu are compiled against (half, bfloat16, float).
template __global__ void kQuantizeNVFP4<half>(
    const half* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
template __global__ void kQuantizeNVFP4<__nv_bfloat16>(
    const __nv_bfloat16* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
template __global__ void kQuantizeNVFP4<float>(
    const float* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
template __global__ void kDequantizeNVFP4<half>(
    const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
    const float tensor_scale, half* __restrict__ output, const int n
);
template __global__ void kDequantizeNVFP4<__nv_bfloat16>(
    const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
    const float tensor_scale, __nv_bfloat16* __restrict__ output, const int n
);
template __global__ void kDequantizeNVFP4<float>(
    const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
    const float tensor_scale, float* __restrict__ output, const int n
);
2794+
25702795
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
25712796
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>( \
25722797
gtype * p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, \

csrc/kernels.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE
2626
__global__ void
2727
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);
2828

29+
// NVFP4 (E2M1) two-level block-scaled quantization (block size 16, per the
// commit description). input holds n elements; output receives ceil(n/2)
// packed FP4 bytes; block_scales receives ceil(n/16) unsigned E4M3 bytes.
// tensor_scale is the precomputed first-level FP32 scale.
template <typename T>
__global__ void kQuantizeNVFP4(
    const T* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
// Inverse of kQuantizeNVFP4: unpacks the FP4 codes and applies
// block_scale * tensor_scale to reconstruct n values of type T.
template <typename T>
__global__ void kDequantizeNVFP4(
    const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
    const float tensor_scale, T* __restrict__ output, const int n
);
39+
2940
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
3041
__global__ void kPreconditionOptimizer32bit2State(
3142
T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,

csrc/ops.cu

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,67 @@ void dequantizeBlockwise(
8181
CUDA_CHECK_RETURN(cudaPeekAtLastError());
8282
}
8383

84+
// ============================================================================
// NVFP4 quantize/dequantize host-side launchers
// ============================================================================

// Launch kQuantizeNVFP4: quantizes n elements of input into ceil(n/2) packed
// FP4 bytes plus ceil(n/16) E4M3 block scales.
// NOTE(review): launches on the default stream, unlike dequantizeNVFP4 which
// takes an explicit cudaStream_t — consider adding a stream parameter for API
// symmetry and async overlap.
template <typename T>
void quantizeNVFP4(
    const T* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
) {
    // Guard: n == 0 would compute num_blocks == 0, and a zero-block launch
    // fails with cudaErrorInvalidConfiguration.
    if (n <= 0) return;

    // Each thread handles 2 elements, so we need ceil(n/2) threads
    const int threads_per_block = 256; // multiple of 8, as the kernel's shuffle groups require
    const int num_threads = (n + 1) / 2;
    const int num_blocks = (num_threads + threads_per_block - 1) / threads_per_block;

    kQuantizeNVFP4<T><<<num_blocks, threads_per_block>>>(
        input, output, block_scales, tensor_scale, n
    );
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
103+
104+
// Launch kDequantizeNVFP4 on the given stream: reconstructs n elements of
// type T from ceil(n/2) packed FP4 bytes and ceil(n/16) E4M3 block scales.
template <typename T>
void dequantizeNVFP4(
    const unsigned char* input, const unsigned char* block_scales,
    float tensor_scale, T* output, const int n, cudaStream_t stream
) {
    // Guard: n == 0 would compute num_blocks == 0, and a zero-block launch
    // fails with cudaErrorInvalidConfiguration.
    if (n <= 0) return;

    // Each thread handles 2 elements, so we need ceil(n/2) threads
    const int threads_per_block = 256;
    const int num_threads = (n + 1) / 2;
    const int num_blocks = (num_threads + threads_per_block - 1) / threads_per_block;

    kDequantizeNVFP4<T><<<num_blocks, threads_per_block, 0, stream>>>(
        input, block_scales, tensor_scale, output, n
    );
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
118+
119+
// NVFP4 template instantiations
// Explicit instantiations of the host launchers for half, bfloat16 and float,
// matching the kernel instantiations in kernels.cu.
template void quantizeNVFP4<half>(
    const half* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
);
template void quantizeNVFP4<__nv_bfloat16>(
    const __nv_bfloat16* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
);
template void quantizeNVFP4<float>(
    const float* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
);
template void dequantizeNVFP4<half>(
    const unsigned char* input, const unsigned char* block_scales,
    float tensor_scale, half* output, const int n, cudaStream_t stream
);
template void dequantizeNVFP4<__nv_bfloat16>(
    const unsigned char* input, const unsigned char* block_scales,
    float tensor_scale, __nv_bfloat16* output, const int n, cudaStream_t stream
);
template void dequantizeNVFP4<float>(
    const unsigned char* input, const unsigned char* block_scales,
    float tensor_scale, float* output, const int n, cudaStream_t stream
);
144+
84145
template <typename T, int OPTIMIZER>
85146
void optimizer32bit(
86147
T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, const float beta1,

csrc/ops.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,17 @@ void dequantizeBlockwise(
120120
float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, cudaStream_t stream
121121
);
122122

123+
// NVFP4 host launchers (defined in ops.cu).
// quantizeNVFP4: input (n elements of T) -> output (ceil(n/2) packed FP4
// bytes) + block_scales (ceil(n/16) unsigned E4M3 bytes); runs on the
// default stream.
template <typename T>
void quantizeNVFP4(
    const T* input, unsigned char* output, unsigned char* block_scales,
    float tensor_scale, const int n
);
// dequantizeNVFP4: inverse transform, launched on the supplied stream.
template <typename T>
void dequantizeNVFP4(
    const unsigned char* input, const unsigned char* block_scales,
    float tensor_scale, T* output, const int n, cudaStream_t stream
);
133+
123134
template <typename T, int OPTIMIZER>
124135
void optimizer32bit(
125136
T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2,

0 commit comments

Comments
 (0)