@@ -121,6 +121,205 @@ __device__ unsigned char dQuantizeFP4(float x) {
121121
// Dequantize a 4-bit NF4 code to float via the precomputed LUT; only the low
// nibble of `val` is used, so callers may pass an unmasked packed byte.
__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
123123
124+ // ============================================================================
125+ // NVFP4 (E2M1) device functions
126+ // E2M1 format: 1 sign + 2 exponent (bias=1) + 1 mantissa
127+ // Representable magnitudes: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}
128+ // ============================================================================
129+
// E2M1 dequantization LUT - maps 3-bit unsigned magnitude code to float.
// Entries are the eight non-negative E2M1 magnitudes: subnormals {0, 0.5},
// then normals (1 + m/2) * 2^(e-1) for exponent field e in {1, 2, 3}.
__device__ static float nvfp4_dequant_lut[8] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f
};
134+
// Dequantize a 4-bit E2M1 code to float.
// Bit layout: [sign(1) | exponent(2) | mantissa(1)].
// Decoded arithmetically: exponent field 0 is subnormal (0.5 * m), otherwise
// (1 + m/2) * 2^(e-1). This reproduces the LUT {0, 0.5, 1, 1.5, 2, 3, 4, 6}
// exactly for every magnitude code.
__device__ __forceinline__ float dDequantizeNVFP4(unsigned char val) {
    const int ebits = (val >> 1) & 0x3;
    const float frac = (float)(val & 0x1);
    const float mag = (ebits == 0)
        ? 0.5f * frac
        : (1.0f + 0.5f * frac) * (float)(1 << (ebits - 1));
    return (val & 0x08) ? -mag : mag;
}
141+
// Quantize a float to a 4-bit E2M1 code using round-to-nearest.
// Input should be pre-scaled so the representable range [-6, 6] is appropriate.
// Bit layout of the result: [sign(1) | exponent(2) | mantissa(1)].
__device__ unsigned char dQuantizeNVFP4(float x) {
    const unsigned char sign_bit = (x < 0.0f) ? 0x08 : 0x00;
    const float mag = fabsf(x);

    // Midpoints between adjacent representable magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
    // The magnitude code is simply the number of midpoints strictly below `mag`
    // (NaN compares false everywhere and therefore maps to code 0).
    const float midpoints[7] = {0.25f, 0.75f, 1.25f, 1.75f, 2.5f, 3.5f, 5.0f};

    unsigned char code = 0;
#pragma unroll
    for (int i = 0; i < 7; ++i)
        code += (mag > midpoints[i]) ? 1 : 0;

    return sign_bit | code;
}
169+
// Convert a non-negative float to unsigned E4M3 (8-bit: 4 exponent bits with
// bias 7, 3 mantissa bits), using round-to-nearest. Range: [0, 448]; values at
// or above 448 clamp to the max normal 0x7E (NaN input also falls through to
// the clamp, since its comparisons are false and its huge exponent overflows).
// Used for NVFP4 block scale factors.
__device__ unsigned char dFloatToE4M3(float x) {
    if (x <= 0.0f) return 0;
    if (x >= 448.0f) return 0x7E; // Max normal (exp field 15, mant 6). exp=15, mant=7 is NaN.

    unsigned int bits = __float_as_uint(x);
    int fp32_exp = ((bits >> 23) & 0xFF) - 127; // Unbiased FP32 exponent
    int e4m3_exp = fp32_exp + 7;                // E4M3 bias is 7

    if (e4m3_exp <= 0) {
        // Subnormal in E4M3: value = mantissa/8 * 2^(-6), so mantissa = round(x * 512)
        int mant = __float2int_rn(x * 512.0f); // 512 = 8 * 2^6
        if (mant <= 0) return 0;
        // Rounding may carry out of the subnormal range: mant == 8 is exactly
        // the smallest normal 2^-6 = 0x08 (exp=1, mant=0). Clamping to 7 here
        // (the previous behavior) would lose a full ULP.
        if (mant >= 8) return 0x08;
        return (unsigned char)mant;
    }

    // Normal: extract top 3 mantissa bits with round-to-nearest
    unsigned int fp32_mant = bits & 0x7FFFFF;
    unsigned int mant_3bit = (fp32_mant + (1 << 19)) >> 20;

    // Rounding carried into the hidden bit: bump the exponent instead.
    if (mant_3bit >= 8) {
        mant_3bit = 0;
        e4m3_exp++;
    }

    if (e4m3_exp > 15) return 0x7E;
    if (e4m3_exp == 15 && mant_3bit >= 7) return 0x7E; // Clamp, don't produce NaN

    return (unsigned char)((e4m3_exp << 3) | mant_3bit);
}
202+
// Convert an unsigned E4M3 byte (4 exponent bits, bias 7, 3 mantissa bits)
// back to float. Exact for every code: both branches are an integer
// significand scaled by a power of two.
__device__ float dE4M3ToFloat(unsigned char val) {
    if (val == 0) return 0.0f;

    const int e = (val >> 3) & 0x0F;
    const int m = val & 0x07;

    // Subnormal (e == 0): m/8 * 2^(1-7)      = m * 2^-9
    // Normal:             (1 + m/8) * 2^(e-7) = (8 + m) * 2^(e-10)
    return (e == 0) ? ldexpf((float)m, -9)
                    : ldexpf((float)(8 + m), e - 10);
}
218+
// ============================================================================
// NVFP4 quantization kernel
// Two-level scaling: FP32 tensor_scale + E4M3 block_scale (per 16 elements)
// Input: T* tensor, float tensor_scale (precomputed)
// Output: packed uint8 (2 values per byte), uint8 block_scales (E4M3)
//
// Launch contract: one thread per 2 elements, i.e. at least ceil(n/2) threads.
// blockDim.x must be a multiple of 32 so every launched warp is fully
// populated (all 32 lanes participate in the full-mask shuffle below) and so
// each aligned 8-lane group maps to one 16-element quantization block.
// ============================================================================
template <typename T>
__global__ void kQuantizeNVFP4(
    const T* __restrict__ input,
    unsigned char* __restrict__ output,       // Packed FP4: ceil(n/2) bytes
    unsigned char* __restrict__ block_scales, // E4M3 scales: ceil(n/16) bytes
    const float tensor_scale,
    const int n
) {
    // Each thread handles 2 consecutive elements (packs into 1 byte);
    // 8 threads cover one 16-element quantization block.
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int element_idx = tid * 2;

    // IMPORTANT: no early return here. Tail threads past n must still reach
    // the __shfl_xor_sync below — exiting lanes that are named in the shuffle
    // mask is undefined behavior. Out-of-range lanes load 0 and skip stores.
    const bool in_range = (element_idx < n);

    const float inv_tensor_scale = (tensor_scale > 0.0f) ? (1.0f / tensor_scale) : 0.0f;

    // Load 2 elements, divide by tensor_scale; out-of-range slots contribute 0
    // so they never affect the block absmax.
    float val0 = in_range ? (float)input[element_idx] * inv_tensor_scale : 0.0f;
    float val1 = (element_idx + 1 < n) ? (float)input[element_idx + 1] * inv_tensor_scale : 0.0f;

    // Per-thread absmax
    float local_max = fmaxf(fabsf(val0), fabsf(val1));

    // Warp-shuffle reduction within each aligned 8-thread quantization group.
    // XOR offsets 4, 2, 1 stay inside an aligned 8-lane group, so every group
    // converges to its own absmax.
#pragma unroll
    for (int offset = 4; offset >= 1; offset >>= 1)
        local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset));

    // E4M3 block scale: block_absmax / 6.0 (6.0 is the E2M1 max magnitude).
    // Quantize with the *dequantized* scale so encode/decode round-trip.
    float block_scale_f32 = local_max / 6.0f;
    unsigned char block_scale_e4m3 = dFloatToE4M3(block_scale_f32);
    float block_scale_deq = dE4M3ToFloat(block_scale_e4m3);

    // Avoid division by zero for all-zero blocks
    float inv_block_scale = (block_scale_deq > 0.0f) ? (1.0f / block_scale_deq) : 0.0f;

    if (in_range) {
        // First lane of each 8-thread group owns the block-scale slot.
        // element_idx < n guarantees element_idx/16 < ceil(n/16).
        if ((threadIdx.x & 7) == 0)
            block_scales[element_idx / 16] = block_scale_e4m3;

        unsigned char q0 = dQuantizeNVFP4(val0 * inv_block_scale);
        unsigned char q1 = dQuantizeNVFP4(val1 * inv_block_scale);

        // Pack: low nibble = first element, high nibble = second element
        output[element_idx / 2] = (unsigned char)(((q1 & 0x0F) << 4) | (q0 & 0x0F));
    }
}
283+
// ============================================================================
// NVFP4 dequantization kernel
// Reverses the two-level scaling: unpacks FP4, multiplies by
// block_scale * tensor_scale. One thread per 2 elements (1 packed byte);
// launch with at least ceil(n/2) threads.
// ============================================================================
template <typename T>
__global__ void kDequantizeNVFP4(
    const unsigned char* __restrict__ input,        // Packed FP4: ceil(n/2) bytes
    const unsigned char* __restrict__ block_scales, // E4M3 scales: ceil(n/16) bytes
    const float tensor_scale,
    T* __restrict__ output,
    const int n
) {
    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const int base = thread_id * 2;

    // Safe to exit early here: this kernel has no warp-wide operations.
    if (base >= n)
        return;

    // Unpack the byte holding elements base and base+1
    // (low nibble = first element, high nibble = second).
    const unsigned char packed = input[base >> 1];
    const unsigned char lo = packed & 0x0F;
    const unsigned char hi = (packed >> 4) & 0x0F;

    // Combined scale: per-block E4M3 scale (16 elements per block) times the
    // global tensor scale.
    const float scale = dE4M3ToFloat(block_scales[base >> 4]) * tensor_scale;

    output[base] = (T)(dDequantizeNVFP4(lo) * scale);
    if (base + 1 < n)
        output[base + 1] = (T)(dDequantizeNVFP4(hi) * scale);
}
322+
124323__device__ unsigned char dQuantizeNF4 (float x) {
125324
126325 // the values for this tree was generated by test_normal_map_tree
@@ -2567,6 +2766,32 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(
25672766 float * code, unsigned char * A, float * absmax, __nv_bfloat16* out, const int blocksize, const int n
25682767);
25692768
// NVFP4 kernel template instantiations.
// One quantize/dequantize pair per supported element type (half, bfloat16,
// float); explicit instantiation emits the kernels in this translation unit
// so they can be launched from separately compiled host code.
template __global__ void kQuantizeNVFP4<half>(
    const half* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
template __global__ void kQuantizeNVFP4<__nv_bfloat16>(
    const __nv_bfloat16* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
template __global__ void kQuantizeNVFP4<float>(
    const float* __restrict__ input, unsigned char* __restrict__ output,
    unsigned char* __restrict__ block_scales, const float tensor_scale, const int n
);
template __global__ void kDequantizeNVFP4<half>(
    const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
    const float tensor_scale, half* __restrict__ output, const int n
);
template __global__ void kDequantizeNVFP4<__nv_bfloat16>(
    const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
    const float tensor_scale, __nv_bfloat16* __restrict__ output, const int n
);
template __global__ void kDequantizeNVFP4<float>(
    const unsigned char* __restrict__ input, const unsigned char* __restrict__ block_scales,
    const float tensor_scale, float* __restrict__ output, const int n
);
2794+
25702795#define MAKE_OptimizerStatic8bit2StateBlockwise (oname, gtype, block_size, num_per_thread ) \
25712796 template __global__ void kOptimizerStatic8bit2StateBlockwise <gtype, oname, block_size, num_per_thread>( \
25722797 gtype * p, gtype* __restrict__ const g, unsigned char * state1, unsigned char * state2, const float beta1, \
0 commit comments