Skip to content

Commit fbb16f4

Browse files
[Common] Tuned NVFP4 cast kernel (NVIDIA#2412)
* Implemented persistent nvfp4 kernel Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix FP4 guard in ptx Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * Fix Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * Fix in ptx. reduxf32 guard Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * Fixes per PR review Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes per PR review. Added parameter to turn off the persistency Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Modified reference CPU implementation in C++ unit tests to match GPU (numerical truncation). Tightened the numerical tolerance Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * Disabled persistency by default, as non-persistent kernel is more performant when inputs are large Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use the tuned kernel also for the rowwise only quantization Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * Fixed typo Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * Addressed comments from the PR review Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * Resolved conflicts Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Macros renaming Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> --------- Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 27fc168 commit fbb16f4

5 files changed

Lines changed: 1184 additions & 54 deletions

File tree

tests/cpp/operator/test_cast_nvfp4_transpose.cu

Lines changed: 76 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,16 @@ std::vector<InputType> create_transpose(const InputType* const input, const size
5454
}
5555

5656
// Compute the global encode scale factor for a given global amax
57-
float compute_global_encode_scaling_factor_FP4(const float global_amax) {
57+
float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) {
5858
constexpr float fp8_max = 448.0f; // 448.0f;
5959
constexpr float fp4_max = 6.0f; // 6.0f;
6060
float global_encode_scale = fp8_max * fp4_max / global_amax;
61-
// If scale is infinity, return max value of float32
62-
global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
61+
// If scale is infinity, return the max normalized value
62+
const float max_norm_clamp = use_fast_math
63+
? Numeric_Traits<bf16>::maxNorm
64+
: Numeric_Traits<float>::maxNorm;
65+
66+
global_encode_scale = fminf(global_encode_scale, max_norm_clamp);
6367
// If global amax is 0 or infinity, return 1
6468
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
6569
return 1.0f;
@@ -76,10 +80,11 @@ void quantize_nvfp4_1d(float (*OP)(const float),
7680
const size_t rows,
7781
const size_t cols,
7882
const size_t scales_stride,
79-
const float global_amax) {
83+
const float global_amax,
84+
const bool use_fast_math) {
8085

8186
// Compute a global encoding/decoding scaling factor for all S_dec_b
82-
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
87+
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
8388

8489
constexpr size_t block_size_X = 16;
8590
const size_t blocks_X = divide_round_up(cols, block_size_X);
@@ -114,14 +119,20 @@ void quantize_nvfp4_1d(float (*OP)(const float),
114119
const float S_dec_b = block_amax / 6.0f;
115120

116121
// Scale & Store per-block decoding scaling factor
117-
const float S_dec_b_fp8 = S_dec_b * S_enc;
122+
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
123+
const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);
118124

119125
// Compute "correct" per-block encoding scaling factor
120-
const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
126+
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
121127

122128
const size_t scale_idx = i * scales_stride + block_X;
123-
scales[scale_idx] = static_cast<fp8e4m3>(S_dec_b_fp8);
124-
const float scale_reciprocal = S_enc_b_fp8;
129+
scales[scale_idx] = S_dec_b_fp8;
130+
131+
float scale_reciprocal = S_enc_b_fp8;
132+
if (use_fast_math) {
133+
// Numerical truncation to match GPU implementation, if mixed precision FMA instruction is used
134+
scale_reciprocal = static_cast<float>(static_cast<bf16>(scale_reciprocal));
135+
}
125136

126137
for (size_t j = j_min; j < j_max; j += 2) {
127138
const int idx_pair = (i * cols + j) / 2;
@@ -136,7 +147,7 @@ void quantize_nvfp4_1d(float (*OP)(const float),
136147
fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
137148
output[idx_pair] = casted_to_e2m1_pair;
138149

139-
// const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
150+
const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
140151
}
141152
}
142153
}
@@ -149,9 +160,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float),
149160
const size_t rows,
150161
const size_t cols,
151162
const float global_amax,
152-
std::vector<std::vector<fp8e4m3>>& math_scales) {
163+
std::vector<std::vector<fp8e4m3>>& math_scales,
164+
const bool use_fast_math) {
153165

154-
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
166+
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
155167
constexpr size_t block_size_Y = 16;
156168
constexpr size_t block_size_X = 16;
157169
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
@@ -195,13 +207,14 @@ void quantize_nvfp4_2d(float (*OP)(const float),
195207
const size_t rows,
196208
const size_t cols,
197209
const size_t scales_stride,
198-
const float global_amax) {
210+
const float global_amax,
211+
const bool use_fast_math) {
199212

200213
// Step 1: Compute mathematical 8x8 scaling factors
201214
std::vector<std::vector<fp8e4m3>> math_scales;
202-
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
215+
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
203216

204-
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
217+
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
205218
constexpr size_t block_size_Y = 16;
206219
constexpr size_t block_size_X = 16;
207220
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
@@ -282,11 +295,12 @@ void quantize_nvfp4(float (*OP)(const float),
282295
const size_t cols,
283296
const size_t scales_stride,
284297
const float global_amax,
298+
const bool use_fast_math,
285299
const bool use_2d_quantization = false) {
286300
if (use_2d_quantization) {
287-
quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
301+
quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
288302
} else {
289-
quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
303+
quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
290304
}
291305
}
292306

@@ -302,14 +316,15 @@ void compute_ref(float (*OP)(const float),
302316
const size_t cols,
303317
const size_t scales_stride,
304318
const size_t scales_stride_t,
319+
const bool use_fast_math,
305320
const bool use_2d_quantization = false)
306321
{
307322
std::vector<InputType> input_t = create_transpose(input, rows, cols);
308323

309324
if (use_2d_quantization) {
310325
// Step 1: Compute mathematical 8×8 scaling factors
311326
std::vector<std::vector<fp8e4m3>> math_scales;
312-
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
327+
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
313328

314329
constexpr size_t block_size_Y = 16;
315330
constexpr size_t block_size_X = 16;
@@ -336,19 +351,25 @@ void compute_ref(float (*OP)(const float),
336351

337352
// Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
338353
// (This part processes the actual FP4 data using the mathematical scaling factors)
339-
quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled
340-
quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled
354+
quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax,
355+
use_fast_math); // scales already filled
356+
quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax,
357+
use_fast_math); // scales_t already filled
341358

342359
} else {
343-
quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization);
344-
quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization);
360+
quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax,
361+
use_fast_math, use_2d_quantization);
362+
quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax,
363+
use_fast_math, use_2d_quantization);
345364
}
346365
}
347366

348367
void compare_nvfp4_tensors(const std::string& name,
349368
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
350369
const int rows, const int cols,
351370
double atol = 1e-5, double rtol = 1e-8) {
371+
constexpr int max_mismatches_to_print = 3;
372+
352373
std::vector<std::string> mismatch_messages;
353374
size_t total_mismatches = 0;
354375

@@ -362,29 +383,16 @@ void compare_nvfp4_tensors(const std::string& name,
362383
const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
363384
const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
364385

365-
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
366-
/* For Float32 the floating point comparison is enough to error out */
367-
bool assertion = false;
368-
if (mismatch && !assertion) {
369-
/* Check if it is just a failure of round to nearest choosing different
370-
side of the real value */
371-
const double mean = (t + r) / 2;
372-
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
373-
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
374-
const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p));
375-
const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m));
376-
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
377-
}
378-
if (assertion) {
386+
const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol);
387+
if (mismatch) {
379388
total_mismatches++;
380-
std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
381-
std::to_string(t) + " vs " + std::to_string(r) +
382-
" (abs_diff: " + std::to_string(fabs(t - r)) +
383-
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
384-
mismatch_messages.push_back(msg);
385-
386389
// Optional: limit number of detailed messages to avoid overwhelming output
387-
if (mismatch_messages.size() <= 100) {
390+
if (total_mismatches <= max_mismatches_to_print) {
391+
std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
392+
std::to_string(t) + " vs " + std::to_string(r) +
393+
" (abs_diff: " + std::to_string(fabs(t - r)) +
394+
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
395+
mismatch_messages.push_back(msg);
388396
std::cout << "Error in tensor " << name << ": " << msg << std::endl;
389397
}
390398
}
@@ -400,8 +408,9 @@ void compare_nvfp4_tensors(const std::string& name,
400408
std::cout << "STATUS: FAILED for output" << std::endl;
401409
std::cout << "Total mismatches found: " << total_mismatches << std::endl;
402410
std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
403-
if (mismatch_messages.size() > 100) {
404-
std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
411+
if (mismatch_messages.size() > max_mismatches_to_print) {
412+
std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print)
413+
<< " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl;
405414
}
406415
std::cout << "============================" << std::endl;
407416

@@ -519,7 +528,8 @@ void compareResults_nvfp4(const Tensor &test,
519528

520529
template <typename InputType>
521530
void performTest(float (*OP)(const float),
522-
const std::vector<size_t>& shape) {
531+
const std::vector<size_t>& shape,
532+
const bool use_fast_math) {
523533
using namespace test;
524534

525535
DType itype = TypeInfo<InputType>::dtype;
@@ -580,15 +590,16 @@ void performTest(float (*OP)(const float),
580590
cols,
581591
scales_stride,
582592
scales_stride_t,
593+
use_fast_math,
583594
use_2d_quantization);
584-
585-
QuantizationConfigWrapper quant_config;
586-
587595
// Initialize stochastic rounding
588596
Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
589597
rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed
590598
rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321; // rng_sequence
591599
rng_state.from_cpu();
600+
601+
QuantizationConfigWrapper quant_config;
602+
quant_config.set_use_fast_math(use_fast_math);
592603
quant_config.set_stochastic_rounding(false);
593604
quant_config.set_rng_state(rng_state.data());
594605

@@ -619,8 +630,8 @@ void performTest(float (*OP)(const float),
619630
}
620631
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
621632

622-
const double atol = 0.05;
623-
const double rtol = 0.1;
633+
const double atol = 1.0E-6;
634+
const double rtol = 1.0E-6;
624635

625636
// Set dump_data=true to enable dumping tensor data to files for analysis
626637
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);
@@ -666,12 +677,18 @@ std::vector<ActivationType> Activation_types = {
666677
ActivationType::Identity
667678
};
668679

680+
std::vector<bool> use_fast_nvfp4_scaling_vec = {
681+
false,
682+
true
683+
};
684+
669685
} // namespace
670686

671687
class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
672688
<std::tuple<ActivationType,
673689
std::vector<size_t>,
674-
transformer_engine::DType>> {};
690+
transformer_engine::DType,
691+
bool>> {};
675692

676693
TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
677694
// Skip tests for pre-Blackwell architectures
@@ -685,6 +702,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
685702
const ActivationType Act_type = std::get<0>(GetParam());
686703
const auto tensor_dims = std::get<1>(GetParam());
687704
const DType input_type = std::get<2>(GetParam());
705+
const bool use_fast_math = std::get<3>(GetParam());
688706

689707
// Skip tests if the input tensor is 1D
690708
if (tensor_dims.size() < 2) {
@@ -702,7 +720,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
702720
}
703721

704722
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
705-
performTest<InputType>(OP, tensor_dims);
723+
performTest<InputType>(OP, tensor_dims, use_fast_math);
706724
);
707725
}
708726

@@ -724,13 +742,17 @@ INSTANTIATE_TEST_SUITE_P(
724742
::testing::Combine(
725743
::testing::ValuesIn(Activation_types),
726744
::testing::ValuesIn(tensor_dims),
727-
::testing::Values(DType::kBFloat16)),
745+
::testing::Values(DType::kBFloat16),
746+
::testing::ValuesIn(use_fast_nvfp4_scaling_vec)),
728747
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
729748
std::string name = to_string(std::get<0>(info.param));
730749
const auto& shape = std::get<1>(info.param);
731750
for ( const auto& s: shape) {
732751
name += "X" + std::to_string(s);
733752
}
734753
name += "X" + test::typeName(std::get<2>(info.param));
754+
if (std::get<3>(info.param)) {
755+
name += "X_FAST_SCALING";
756+
}
735757
return name;
736758
});

transformer_engine/common/cast/core/common.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) {
3535
return cols % alignment_requirement == 0;
3636
}
3737

38+
__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
39+
size_t addr = reinterpret_cast<size_t>(p);
40+
addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
41+
return reinterpret_cast<unsigned char *>(addr);
42+
}
43+
3844
namespace kernel {
3945

4046
constexpr size_t THREADS_PER_BLOCK = 256;

transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "../../util/ptx.cuh"
2222
#include "../../utils.cuh"
2323
#include "core_nvfp4.cuh"
24+
#include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh"
2425

2526
namespace transformer_engine {
2627
namespace dispatch {
@@ -1159,13 +1160,19 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
11591160
#if FP4_TYPE_SUPPORTED
11601161
using namespace quantize_transpose_kernel;
11611162
using namespace ptx;
1163+
11621164
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
11631165

11641166
// If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
11651167
// return the transposed data.
11661168
// TODO(Frank): Is there a better way to do this?
11671169
bool return_transpose = output->has_columnwise_data();
11681170

1171+
if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
1172+
quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
1173+
return;
1174+
}
1175+
11691176
constexpr bool COMPUTE_ACTIVATIONS = false;
11701177
using ParamOP = Empty;
11711178
constexpr float (*OP)(float, const ParamOP &) = nullptr;

0 commit comments

Comments
 (0)