@@ -54,12 +54,16 @@ std::vector<InputType> create_transpose(const InputType* const input, const size
5454}
5555
5656// Compute the global encode scale factor for a given global amax
57- float compute_global_encode_scaling_factor_FP4 (const float global_amax) {
57+ float compute_global_encode_scaling_factor_FP4 (const float global_amax, const bool use_fast_math ) {
5858 constexpr float fp8_max = 448 .0f ; // 448.0f;
5959 constexpr float fp4_max = 6 .0f ; // 6.0f;
6060 float global_encode_scale = fp8_max * fp4_max / global_amax;
61- // If scale is infinity, return max value of float32
62- global_encode_scale = fminf (global_encode_scale, Numeric_Traits<float >::maxNorm);
61+ // If scale is infinity, return the max normalized value
62+ const float max_norm_clamp = use_fast_math
63+ ? Numeric_Traits<bf16 >::maxNorm
64+ : Numeric_Traits<float >::maxNorm;
65+
66+ global_encode_scale = fminf (global_encode_scale, max_norm_clamp);
6367 // If global amax is 0 or infinity, return 1
6468 if (global_amax == 0 .0f || global_encode_scale == 0 .0f ) {
6569 return 1 .0f ;
@@ -76,10 +80,11 @@ void quantize_nvfp4_1d(float (*OP)(const float),
7680 const size_t rows,
7781 const size_t cols,
7882 const size_t scales_stride,
79- const float global_amax) {
83+ const float global_amax,
84+ const bool use_fast_math) {
8085
8186 // Compute a global encoding/decoding scaling factor for all S_dec_b
82- const float S_enc = compute_global_encode_scaling_factor_FP4 (global_amax);
87+ const float S_enc = compute_global_encode_scaling_factor_FP4 (global_amax, use_fast_math );
8388
8489 constexpr size_t block_size_X = 16 ;
8590 const size_t blocks_X = divide_round_up (cols, block_size_X);
@@ -114,14 +119,20 @@ void quantize_nvfp4_1d(float (*OP)(const float),
114119 const float S_dec_b = block_amax / 6 .0f ;
115120
116121 // Scale & Store per-block decoding scaling factor
117- const float S_dec_b_fp8 = S_dec_b * S_enc;
122+ const fp8e4m3 S_dec_b_fp8 = static_cast <fp8e4m3>(S_dec_b * S_enc);
123+ const float S_dec_b_fp32 = static_cast <float >(S_dec_b_fp8);
118124
119125 // Compute "correct" per-block encoding scaling factor
120- const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0 .f : S_enc / S_dec_b_fp8 ;
126+ const float S_enc_b_fp8 = S_dec_b_fp32 == 0 . f ? 0 .f : S_enc / S_dec_b_fp32 ;
121127
122128 const size_t scale_idx = i * scales_stride + block_X;
123- scales[scale_idx] = static_cast <fp8e4m3>(S_dec_b_fp8);
124- const float scale_reciprocal = S_enc_b_fp8;
129+ scales[scale_idx] = S_dec_b_fp8;
130+
131+ float scale_reciprocal = S_enc_b_fp8;
132+ if (use_fast_math) {
133+ // Numerical truncation to match GPU implementation, if mixed precision FMA instruction is used
134+ scale_reciprocal = static_cast <float >(static_cast <bf16 >(scale_reciprocal));
135+ }
125136
126137 for (size_t j = j_min; j < j_max; j += 2 ) {
127138 const int idx_pair = (i * cols + j) / 2 ;
@@ -136,7 +147,7 @@ void quantize_nvfp4_1d(float (*OP)(const float),
136147 fp4e2m1x2 casted_to_e2m1_pair (scaled_elt_pair);
137148 output[idx_pair] = casted_to_e2m1_pair;
138149
139- // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
150+ const double2 truncated_pair = cvt_fp4x2_to_double2 (casted_to_e2m1_pair);
140151 }
141152 }
142153 }
@@ -149,9 +160,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float),
149160 const size_t rows,
150161 const size_t cols,
151162 const float global_amax,
152- std::vector<std::vector<fp8e4m3>>& math_scales) {
163+ std::vector<std::vector<fp8e4m3>>& math_scales,
164+ const bool use_fast_math) {
153165
154- const float S_enc = compute_global_encode_scaling_factor_FP4 (global_amax);
166+ const float S_enc = compute_global_encode_scaling_factor_FP4 (global_amax, use_fast_math );
155167 constexpr size_t block_size_Y = 16 ;
156168 constexpr size_t block_size_X = 16 ;
157169 const size_t blocks_Y = divide_round_up (rows, block_size_Y);
@@ -195,13 +207,14 @@ void quantize_nvfp4_2d(float (*OP)(const float),
195207 const size_t rows,
196208 const size_t cols,
197209 const size_t scales_stride,
198- const float global_amax) {
210+ const float global_amax,
211+ const bool use_fast_math) {
199212
200213 // Step 1: Compute mathematical 8x8 scaling factors
201214 std::vector<std::vector<fp8e4m3>> math_scales;
202- compute_2d_mathematical_scales (OP, input, rows, cols, global_amax, math_scales);
215+ compute_2d_mathematical_scales (OP, input, rows, cols, global_amax, math_scales, use_fast_math );
203216
204- const float S_enc = compute_global_encode_scaling_factor_FP4 (global_amax);
217+ const float S_enc = compute_global_encode_scaling_factor_FP4 (global_amax, use_fast_math );
205218 constexpr size_t block_size_Y = 16 ;
206219 constexpr size_t block_size_X = 16 ;
207220 const size_t blocks_Y = divide_round_up (rows, block_size_Y);
@@ -282,11 +295,12 @@ void quantize_nvfp4(float (*OP)(const float),
282295 const size_t cols,
283296 const size_t scales_stride,
284297 const float global_amax,
298+ const bool use_fast_math,
285299 const bool use_2d_quantization = false) {
286300 if (use_2d_quantization) {
287- quantize_nvfp4_2d (OP, input, output, scales, rows, cols, scales_stride, global_amax);
301+ quantize_nvfp4_2d (OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math );
288302 } else {
289- quantize_nvfp4_1d (OP, input, output, scales, rows, cols, scales_stride, global_amax);
303+ quantize_nvfp4_1d (OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math );
290304 }
291305}
292306
@@ -302,14 +316,15 @@ void compute_ref(float (*OP)(const float),
302316 const size_t cols,
303317 const size_t scales_stride,
304318 const size_t scales_stride_t,
319+ const bool use_fast_math,
305320 const bool use_2d_quantization = false)
306321{
307322 std::vector<InputType> input_t = create_transpose (input, rows, cols);
308323
309324 if (use_2d_quantization) {
310325 // Step 1: Compute mathematical 8×8 scaling factors
311326 std::vector<std::vector<fp8e4m3>> math_scales;
312- compute_2d_mathematical_scales (OP, input, rows, cols, global_amax, math_scales);
327+ compute_2d_mathematical_scales (OP, input, rows, cols, global_amax, math_scales, use_fast_math );
313328
314329 constexpr size_t block_size_Y = 16 ;
315330 constexpr size_t block_size_X = 16 ;
@@ -336,19 +351,25 @@ void compute_ref(float (*OP)(const float),
336351
337352 // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
338353 // (This part processes the actual FP4 data using the mathematical scaling factors)
339- quantize_nvfp4_2d (OP, input, output, nullptr , rows, cols, scales_stride, global_amax); // scales already filled
340- quantize_nvfp4_2d (OP, input_t .data (), output_t , nullptr , cols, rows, scales_stride_t , global_amax); // scales_t already filled
354+ quantize_nvfp4_2d (OP, input, output, nullptr , rows, cols, scales_stride, global_amax,
355+ use_fast_math); // scales already filled
356+ quantize_nvfp4_2d (OP, input_t .data (), output_t , nullptr , cols, rows, scales_stride_t , global_amax,
357+ use_fast_math); // scales_t already filled
341358
342359 } else {
343- quantize_nvfp4 (OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization);
344- quantize_nvfp4 (OP, input_t .data (), output_t , scales_t , cols, rows, scales_stride_t , global_amax, use_2d_quantization);
360+ quantize_nvfp4 (OP, input, output, scales, rows, cols, scales_stride, global_amax,
361+ use_fast_math, use_2d_quantization);
362+ quantize_nvfp4 (OP, input_t .data (), output_t , scales_t , cols, rows, scales_stride_t , global_amax,
363+ use_fast_math, use_2d_quantization);
345364 }
346365}
347366
348367void compare_nvfp4_tensors (const std::string& name,
349368 const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
350369 const int rows, const int cols,
351370 double atol = 1e-5 , double rtol = 1e-8 ) {
371+ constexpr int max_mismatches_to_print = 3 ;
372+
352373 std::vector<std::string> mismatch_messages;
353374 size_t total_mismatches = 0 ;
354375
@@ -362,29 +383,16 @@ void compare_nvfp4_tensors(const std::string& name,
362383 const double t = (k == 0 ? test_data_pair.x : test_data_pair.y );
363384 const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y );
364385
365- bool mismatch = fabs (t - r) > atol && (r == 0 || fabs ((t - r) / r) > rtol);
366- /* For Float32 the floating point comparison is enough to error out */
367- bool assertion = false ;
368- if (mismatch && !assertion) {
369- /* Check if it is just a failure of round to nearest choosing different
370- side of the real value */
371- const double mean = (t + r) / 2 ;
372- const double mean_p = mean >= 0 ? mean * (1 + 1e-6 ) : mean * (1 - 1e-6 );
373- const double mean_m = mean >= 0 ? mean * (1 - 1e-6 ) : mean * (1 + 1e-6 );
374- const double cast_mean_p = static_cast <double >(static_cast <fp4e2m1>(mean_p));
375- const double cast_mean_m = static_cast <double >(static_cast <fp4e2m1>(mean_m));
376- assertion = !(cast_mean_m == std::min (t,r) && cast_mean_p == std::max (t,r));
377- }
378- if (assertion) {
386+ const bool mismatch = fabs (t - r) > (atol + fabs (r) * rtol);
387+ if (mismatch) {
379388 total_mismatches++;
380- std::string msg = " Mismatch at place (" + std::to_string (idx + k) + " ): " +
381- std::to_string (t) + " vs " + std::to_string (r) +
382- " (abs_diff: " + std::to_string (fabs (t - r)) +
383- " , rel_diff: " + std::to_string (r == 0 ? 0.0 : fabs ((t - r) / r)) + " )" ;
384- mismatch_messages.push_back (msg);
385-
386389 // Optional: limit number of detailed messages to avoid overwhelming output
387- if (mismatch_messages.size () <= 100 ) {
390+ if (total_mismatches <= max_mismatches_to_print) {
391+ std::string msg = " Mismatch at place (" + std::to_string (idx + k) + " ): " +
392+ std::to_string (t) + " vs " + std::to_string (r) +
393+ " (abs_diff: " + std::to_string (fabs (t - r)) +
394+ " , rel_diff: " + std::to_string (r == 0 ? 0.0 : fabs ((t - r) / r)) + " )" ;
395+ mismatch_messages.push_back (msg);
388396 std::cout << " Error in tensor " << name << " : " << msg << std::endl;
389397 }
390398 }
@@ -400,8 +408,9 @@ void compare_nvfp4_tensors(const std::string& name,
400408 std::cout << " STATUS: FAILED for output" << std::endl;
401409 std::cout << " Total mismatches found: " << total_mismatches << std::endl;
402410 std::cout << " Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << " %" << std::endl;
403- if (mismatch_messages.size () > 100 ) {
404- std::cout << " ... and " << (mismatch_messages.size () - 100 ) << " more mismatches (showing first 100)" << std::endl;
411+ if (mismatch_messages.size () > max_mismatches_to_print) {
412+ std::cout << " ... and " << (mismatch_messages.size () - max_mismatches_to_print)
413+ << " more mismatches (showing first " << max_mismatches_to_print << " )" << std::endl;
405414 }
406415 std::cout << " ============================" << std::endl;
407416
@@ -519,7 +528,8 @@ void compareResults_nvfp4(const Tensor &test,
519528
520529template <typename InputType>
521530void performTest (float (*OP)(const float ),
522- const std::vector<size_t>& shape) {
531+ const std::vector<size_t>& shape,
532+ const bool use_fast_math) {
523533 using namespace test ;
524534
525535 DType itype = TypeInfo<InputType>::dtype;
@@ -580,15 +590,16 @@ void performTest(float (*OP)(const float),
580590 cols,
581591 scales_stride,
582592 scales_stride_t ,
593+ use_fast_math,
583594 use_2d_quantization);
584-
585- QuantizationConfigWrapper quant_config;
586-
587595 // Initialize stochastic rounding
588596 Tensor rng_state (" rng_state" , std::vector<size_t >{2 }, DType::kInt64 );
589597 rng_state.rowwise_cpu_dptr <int64_t >()[0 ] = 123 ; // rng_seed
590598 rng_state.rowwise_cpu_dptr <int64_t >()[1 ] = 321 ; // rng_sequence
591599 rng_state.from_cpu ();
600+
601+ QuantizationConfigWrapper quant_config;
602+ quant_config.set_use_fast_math (use_fast_math);
592603 quant_config.set_stochastic_rounding (false );
593604 quant_config.set_rng_state (rng_state.data ());
594605
@@ -619,8 +630,8 @@ void performTest(float (*OP)(const float),
619630 }
620631 ASSERT_EQ (err, cudaSuccess) << cudaGetErrorString (err);
621632
622- const double atol = 0.05 ;
623- const double rtol = 0.1 ;
633+ const double atol = 1.0E-6 ;
634+ const double rtol = 1.0E-6 ;
624635
625636 // Set dump_data=true to enable dumping tensor data to files for analysis
626637 compareResults_nvfp4 (output, ref_output.get (), ref_output_t .get (), rows, cols, atol, rtol, true , false );
@@ -666,12 +677,18 @@ std::vector<ActivationType> Activation_types = {
666677 ActivationType::Identity
667678};
668679
680+ std::vector<bool > use_fast_nvfp4_scaling_vec = {
681+ false ,
682+ true
683+ };
684+
669685} // namespace
670686
671687class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
672688 <std::tuple<ActivationType,
673689 std::vector<size_t >,
674- transformer_engine::DType>> {};
690+ transformer_engine::DType,
691+ bool >> {};
675692
676693TEST_P (FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
677694 // Skip tests for pre-Blackwell architectures
@@ -685,6 +702,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
685702 const ActivationType Act_type = std::get<0 >(GetParam ());
686703 const auto tensor_dims = std::get<1 >(GetParam ());
687704 const DType input_type = std::get<2 >(GetParam ());
705+ const bool use_fast_math = std::get<3 >(GetParam ());
688706
689707 // Skip tests if the input tensor is 1D
690708 if (tensor_dims.size () < 2 ) {
@@ -702,7 +720,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
702720 }
703721
704722 TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY (input_type, InputType,
705- performTest<InputType>(OP, tensor_dims);
723+ performTest<InputType>(OP, tensor_dims, use_fast_math );
706724 );
707725}
708726
@@ -724,13 +742,17 @@ INSTANTIATE_TEST_SUITE_P(
724742 ::testing::Combine (
725743 ::testing::ValuesIn (Activation_types),
726744 ::testing::ValuesIn(tensor_dims),
727- ::testing::Values(DType::kBFloat16 )),
745+ ::testing::Values(DType::kBFloat16 ),
746+ ::testing::ValuesIn(use_fast_nvfp4_scaling_vec)),
728747 [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
729748 std::string name = to_string (std::get<0 >(info.param ));
730749 const auto & shape = std::get<1 >(info.param );
731750 for ( const auto & s: shape) {
732751 name += " X" + std::to_string (s);
733752 }
734753 name += " X" + test::typeName (std::get<2 >(info.param ));
754+ if (std::get<3 >(info.param )) {
755+ name += " X_FAST_SCALING" ;
756+ }
735757 return name;
736758 });
0 commit comments