Commit dc9af4a

zianglih, timmoon10, and pre-commit-ci[bot] authored
Implement 4over6 NVFP4 recipe (#2972)
* Initial implementation
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Make 4over6 compile time for dequant
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Expand 1d fwd+bwd test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add gemm test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add more tests and fix offload
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Fix offload
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up arg
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add more test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add more tests
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor cuh kernel impl
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Further extract
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add recipe_id
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Fix failing unit tests
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor ref
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Update comments and docs
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Drop unnecessary test_sanity workaround
  The following tests passed:
  `NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py`
  `NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py`
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor `QuantizerRole`
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Allow separate recipe 4over6 config
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Support 2d
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor 2d
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up anti pattern
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Enforce 4over6 consistency
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Update comments
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Update docs
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Fix test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Drop test_fusible_ops
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Revert "Drop test_fusible_ops"
  This reverts commit 69f9ccc.
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor test_fusible_ops
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor ref and extend cpp test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Clean up cpp test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Minor comment
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Drop doc
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Explicit handle conditional smem buffer
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Further clean up
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* More templates
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Simplify cpp
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Drop write back lifting
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add MAE and dedicated fast math env var
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Harden cpp test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add warning and err fast math coverage
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Fold test case and clean up cpp test
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Initial 448 vs 256 implementation
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Use e4m3 max instead of boolean, more template
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add benchmark script and minor optimization
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Use standalone kernels
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Use cp async
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Add benchmark script
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Minor fix after rebase
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Naming consistency
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Remove 4over6 benchmark
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Refactor modes
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Relax tol for `test_layernorm_mlp` for `nvfp4_4over6`
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Minor fix recipe naming
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Remove gradient 4over6 quantization and partially allow SR/RHT
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Allow RHT in pytorch ref
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Update transformer_engine/pytorch/csrc/quantizer.cpp
  Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* Minor fix TODO lint
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Use standard nvfp4 for grad ref in test_fusible_ops.py since 4over6 is not applied to gradient quantizers
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Minor fix test-fusible_ops 4over6 helper
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Default to 256 for 4over6
  Signed-off-by: Ziang Li <ziangli@umich.edu>
* Reset RNG state for each TE ops test
  Adding tests affected RNG in unrelated tests.
  Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Remove loosened NVFP4 tols in layernorm MLP test.
  Make sure tensors are representable in quantized format.
  Signed-off-by: Tim Moon <tmoon@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9af70a8 commit dc9af4a

37 files changed

Lines changed: 2595 additions & 251 deletions

docs/envvars.rst

Lines changed: 24 additions & 0 deletions
@@ -287,6 +287,30 @@ Kernel Configuration
    :Default: ``0``
    :Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar.
 
+.. envvar:: NVTE_NVFP4_4OVER6
+
+   :Type: ``str`` (``none``, ``weights``, ``activations``, or ``all``)
+   :Default: ``none``
+   :Description: Enable 4over6 adaptive NVFP4 block scaling for weights, activations, or both in the ``NVFP4BlockScaling`` recipe. For each selected FP4 block, quantization compares map-to-4 and map-to-6 candidates and stores the candidate with lower configured error. ``none`` keeps standard NVFP4. Current 4over6 support targets RL and post-training scenarios; pre-training paths that combine 4over6 with RHT are not yet implemented.
+
+.. envvar:: NVTE_NVFP4_4OVER6_E4M3_USE_256
+
+   :Type: ``str`` (``none``, ``weights``, ``activations``, or ``all``)
+   :Default: ``all``
+   :Description: Select NVFP4 4over6 quantizers that use 256 instead of 448 as the global E4M3 scale bound. By default, all 4over6 quantizers use 256. Set the env var to ``none`` (or set ``NVFP4BlockScaling(nvfp4_4over6_e4m3_use_256="none")``) to use the standard NVFP4 448 bound for all 4over6 quantizers. This option is only meaningful for tensor roles that also enable :envvar:`NVTE_NVFP4_4OVER6`.
+
+.. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE
+
+   :Type: ``str`` (``MAE`` or ``MSE``)
+   :Default: ``MAE``
+   :Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe.
+
+.. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH
+
+   :Type: ``int`` (0 or 1)
+   :Default: ``0``
+   :Description: Allow the NVFP4 4over6 candidate error computation to use faster non-strict floating-point expressions. By default, 4over6 error comparison uses strict expressions; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path.
+
 Torch Compilation and Fusion
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
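The `NVTE_NVFP4_4OVER6` and `NVTE_NVFP4_4OVER6_ERR_MODE` descriptions above can be illustrated with a small host-side sketch of the per-block choice. This is not the kernel from this commit. The names `round_to_fp4`, `block_error`, `select_map_target`, and `ErrMode` are hypothetical, the block scale is kept in FP32 rather than rounded to E4M3, and ties in FP4 rounding are resolved toward the smaller magnitude for simplicity.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Magnitudes representable in FP4 E2M1.
constexpr std::array<float, 8> kFp4Magnitudes = {0.0f, 0.5f, 1.0f, 1.5f,
                                                 2.0f, 3.0f, 4.0f, 6.0f};

// Hypothetical stand-in for the MAE/MSE choice described in the docs.
enum class ErrMode { kMAE, kMSE };

// Round to the nearest FP4 magnitude, preserving sign. Ties keep the smaller
// magnitude; this is a simplification, not the hardware rounding rule.
inline float round_to_fp4(float x) {
  const float mag = std::fabs(x);
  float best = kFp4Magnitudes[0];
  for (float v : kFp4Magnitudes) {
    if (std::fabs(mag - v) < std::fabs(mag - best)) best = v;
  }
  return std::copysign(best, x);
}

// Accumulated input-domain error when the block maximum is mapped to `target`.
inline float block_error(const float* block, std::size_t n, float amax,
                         float target, ErrMode mode) {
  if (amax == 0.0f) return 0.0f;
  const float scale = amax / target;  // assumption: FP32 scale, no E4M3 rounding
  float err = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    const float diff = block[i] - round_to_fp4(block[i] / scale) * scale;
    err += (mode == ErrMode::kMAE) ? std::fabs(diff) : diff * diff;
  }
  return err;
}

// Returns 4.0f or 6.0f: the FP4 value the block maximum should map to.
// Map-to-6 (standard NVFP4) is kept unless map-to-4 has strictly lower error.
inline float select_map_target(const float* block, std::size_t n, ErrMode mode) {
  assert(block != nullptr && n > 0);
  float amax = 0.0f;
  for (std::size_t i = 0; i < n; ++i) amax = std::max(amax, std::fabs(block[i]));
  const float err6 = block_error(block, n, amax, 6.0f, mode);
  const float err4 = block_error(block, n, amax, 4.0f, mode);
  return err4 < err6 ? 4.0f : 6.0f;
}
```

For example, the block `{4, 3, 2, 1}` is exactly representable when its maximum maps to 4, so the sketch picks map-to-4, while `{6, 3, 1.5, 0.5}` is exact under map-to-6 and stays there.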

tests/cpp/operator/test_cast_nvfp4_transpose.cu

Lines changed: 520 additions & 96 deletions
Large diffs are not rendered by default.

tests/cpp/operator/test_dequantize_nvfp4.cu

Lines changed: 54 additions & 14 deletions
@@ -46,8 +46,9 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
                                   OType *output,
                                   size_t rows,
                                   size_t cols,
-                                  size_t scale_stride) {
-  constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
+                                  size_t scale_stride,
+                                  int e4m3_max) {
+  const float factor_inv = 1.0f / (6.0f * static_cast<float>(e4m3_max));
   constexpr size_t BLOCK_SIZE = 16;
   const size_t Mread = cols / BLOCK_SIZE;
   const size_t bytes_per_block = BLOCK_SIZE / 2;
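The hunk above turns the fixed `448` in `factor_inv` into a runtime `e4m3_max`. As a rough illustration of why the reference needs it, the sketch below dequantizes a single element under the usual NVFP4 two-level scaling, where the decoded value is the FP4 value times the block scale times `amax / (6 * e4m3_max)`. The rest of the reference function is not shown in this diff, so the formula's use here is an assumption. `decode_fp4_e2m1` and `dequantize_nvfp4_value` are hypothetical names, and `block_scale` stands for the block scale already converted from E4M3 to `float`.

```cpp
#include <cassert>
#include <cstdint>

// Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the magnitude.
inline float decode_fp4_e2m1(std::uint8_t code) {
  constexpr float kMagnitudes[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                    2.0f, 3.0f, 4.0f, 6.0f};
  const float mag = kMagnitudes[code & 0x7];
  return (code & 0x8) ? -mag : mag;
}

// Dequantize one element (sketch). `amax` is the per-tensor or per-row amax,
// and `e4m3_max` is the global E4M3 scale bound (448 or 256 in this commit).
inline float dequantize_nvfp4_value(std::uint8_t code, float block_scale,
                                    float amax, int e4m3_max) {
  assert(e4m3_max == 448 || e4m3_max == 256);
  const float factor_inv = 1.0f / (6.0f * static_cast<float>(e4m3_max));
  return decode_fp4_e2m1(code) * block_scale * amax * factor_inv;
}
```

With a mismatched bound, every decoded value would be off by a constant factor of 448/256, which is why the tests thread the same `e4m3_max` through both the tensor and the reference.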
@@ -86,11 +87,18 @@ float compute_amax(test::Tensor &t, size_t rows, size_t cols) {
   return amax;
 }
 
+struct NVFP4DequantizeTestConfig {
+  NVTENVFP44Over6Mode mode = kNVTENVFP44Over6Disabled;
+  int e4m3_max = 448;
+};
+
 // Quantize a high-precision input to NVFP4, then dequantize and compare
 // against a CPU reference computed from the quantized data.
 template <typename OutputType>
 void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
-                                  const bool row_scaled_nvfp4) {
+                                  const bool row_scaled_nvfp4,
+                                  const NVTENVFP44Over6Mode mode,
+                                  const int e4m3_max) {
   using namespace test;
   DType otype = TypeInfo<OutputType>::dtype;
 
@@ -105,6 +113,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
 
   // Configure quantized tensor amax
   size_t amax_size = 1;
+  quantized.set_nvfp4_e4m3_max(e4m3_max);
+  ASSERT_EQ(quantized.nvfp4_e4m3_max(), e4m3_max);
   if (row_scaled_nvfp4) {
     quantized.set_row_scaled_nvfp4(true);
     amax_size = rows;
@@ -116,7 +126,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
 
   // Quantize
   if (rows > 0 && cols > 0) {
-    nvte_quantize(input.data(), quantized.data(), 0);
+    QuantizationConfigWrapper quant_config;
+    quant_config.set_nvfp4_4over6_mode(mode);
+    nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0);
     cudaDeviceSynchronize();
     auto err = cudaGetLastError();
     ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
@@ -146,7 +158,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
       std::make_unique<OutputType[]>(rows * cols);
   compute_ref_dequantize_nvfp4<OutputType>(
       fp4_data, scales, amax_vals, ref_output.get(),
-      rows, cols, scale_stride);
+      rows, cols, scale_stride, e4m3_max);
 
   // Compare results from TE and reference impls
   auto [atol, rtol] = getTolerances(otype);
@@ -156,7 +168,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
 // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path.
 template <typename OutputType>
 void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
-                                           const bool row_scaled_nvfp4) {
+                                           const bool row_scaled_nvfp4,
+                                           const NVTENVFP44Over6Mode mode,
+                                           const int e4m3_max) {
   using namespace test;
   DType otype = TypeInfo<OutputType>::dtype;
 
@@ -165,6 +179,8 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
 
   Tensor quantized_compact("quantized_compact", std::vector<size_t>{rows, cols},
                            DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
+  quantized_compact.set_nvfp4_e4m3_max(e4m3_max);
+  ASSERT_EQ(quantized_compact.nvfp4_e4m3_max(), e4m3_max);
   if (row_scaled_nvfp4) {
     quantized_compact.set_row_scaled_nvfp4(true);
   } else if (rows > 0 && cols > 0) {
@@ -174,7 +190,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
   }
 
   if (rows > 0 && cols > 0) {
-    nvte_quantize(input.data(), quantized_compact.data(), 0);
+    QuantizationConfigWrapper quant_config;
+    quant_config.set_nvfp4_4over6_mode(mode);
+    nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0);
     cudaDeviceSynchronize();
   }
 
@@ -186,6 +204,8 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
   // Create tensor with same FP4 data but swizzled scales
   Tensor quantized_swizzled("quantized_swizzled", std::vector<size_t>{rows, cols},
                             DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
+  quantized_swizzled.set_nvfp4_e4m3_max(e4m3_max);
+  ASSERT_EQ(quantized_swizzled.nvfp4_e4m3_max(), e4m3_max);
   if (row_scaled_nvfp4) {
     quantized_swizzled.set_row_scaled_nvfp4(true);
   } else {
@@ -260,7 +280,8 @@ std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
 class DequantizeNVFP4TestSuite : public ::testing::TestWithParam
     <std::tuple<std::pair<size_t, size_t>,
                 transformer_engine::DType,
-                bool>> {};
+                bool,
+                NVFP4DequantizeTestConfig>> {};
 
 TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
 {
@@ -271,10 +292,12 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
   const auto tensor_size = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
   const bool row_scaled_nvfp4 = std::get<2>(GetParam());
+  const NVFP4DequantizeTestConfig config = std::get<3>(GetParam());
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
     performTest_dequantize_nvfp4<OutputType>(
-      tensor_size.first, tensor_size.second, row_scaled_nvfp4);
+      tensor_size.first, tensor_size.second, row_scaled_nvfp4, config.mode,
+      config.e4m3_max);
   );
 }
 
@@ -284,21 +307,29 @@ INSTANTIATE_TEST_SUITE_P(
   ::testing::Combine(
     ::testing::ValuesIn(nvfp4_tensor_dims),
     ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
-    ::testing::Bool()),
+    ::testing::Bool(),
+    ::testing::Values(NVFP4DequantizeTestConfig{},
+                      NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 448},
+                      NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 256})),
   [](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info)
   {
+    const NVFP4DequantizeTestConfig config = std::get<3>(info.param);
+    const bool use_4over6 = config.mode != kNVTENVFP44Over6Disabled;
     std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
                        std::to_string(std::get<0>(info.param).second) + "X" +
                        test::typeName(std::get<1>(info.param)) + "X" +
-                       (std::get<2>(info.param) ? "RowScaled" : "PerTensor");
+                       (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
+                       (use_4over6 ? "FourOverSix" : "Default") + "X" +
+                       (config.e4m3_max == 256 ? "E4M3Max256" : "E4M3Max448");
     return name;
   }
 );
 
 class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam
     <std::tuple<std::pair<size_t, size_t>,
                 transformer_engine::DType,
-                bool>> {};
+                bool,
+                NVFP4DequantizeTestConfig>> {};
 
 TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
 {
@@ -309,10 +340,12 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
   const auto tensor_size = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
   const bool row_scaled_nvfp4 = std::get<2>(GetParam());
+  const NVFP4DequantizeTestConfig config = std::get<3>(GetParam());
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
     performTest_dequantize_nvfp4_swizzled<OutputType>(
-      tensor_size.first, tensor_size.second, row_scaled_nvfp4);
+      tensor_size.first, tensor_size.second, row_scaled_nvfp4, config.mode,
+      config.e4m3_max);
   );
 }
 
@@ -322,13 +355,20 @@ INSTANTIATE_TEST_SUITE_P(
   ::testing::Combine(
     ::testing::ValuesIn(nvfp4_tensor_dims),
     ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
-    ::testing::Bool()),
+    ::testing::Bool(),
+    ::testing::Values(NVFP4DequantizeTestConfig{},
+                      NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 448},
+                      NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 256})),
   [](const testing::TestParamInfo<DequantizeNVFP4SwizzledTestSuite::ParamType>& info)
   {
+    const NVFP4DequantizeTestConfig config = std::get<3>(info.param);
+    const bool use_4over6 = config.mode != kNVTENVFP44Over6Disabled;
     std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
                        std::to_string(std::get<0>(info.param).second) + "X" +
                        test::typeName(std::get<1>(info.param)) + "X" +
                        (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
+                       (use_4over6 ? "FourOverSix" : "Default") + "X" +
+                       (config.e4m3_max == 256 ? "E4M3Max256" : "E4M3Max448") + "X" +
                        "Swizzled";
     return name;
   }
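The test configurations above pair a 4over6 mode with an E4M3 bound of 448 or 256. According to the `docs/envvars.rst` change in this commit, that pairing is selected per tensor role through the `none`/`weights`/`activations`/`all` strings. The sketch below shows one way such a resolution could look; it is not the recipe code from this commit. `TensorRole`, `RoleConfig`, `selector_covers`, and `resolve_role_config` are hypothetical names, and only the weight and activation roles are modeled because the commit message notes that gradient quantizers do not use 4over6.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical roles; gradients are left out since 4over6 does not apply to them.
enum class TensorRole { kWeight, kActivation };

// Hypothetical resolved settings for one quantizer.
struct RoleConfig {
  bool use_4over6;
  int e4m3_max;
};

// True if a selector string ("none", "weights", "activations", "all")
// covers the given role.
inline bool selector_covers(const std::string& selector, TensorRole role) {
  if (selector == "all") return true;
  if (selector == "none") return false;
  if (selector == "weights") return role == TensorRole::kWeight;
  if (selector == "activations") return role == TensorRole::kActivation;
  throw std::invalid_argument("unknown selector: " + selector);
}

// The 256 bound only takes effect for roles that also enable 4over6;
// every other quantizer keeps the standard NVFP4 bound of 448.
inline RoleConfig resolve_role_config(const std::string& four_over_six,
                                      const std::string& use_256,
                                      TensorRole role) {
  const bool use_4over6 = selector_covers(four_over_six, role);
  const bool bound_256 = use_4over6 && selector_covers(use_256, role);
  return RoleConfig{use_4over6, bound_256 ? 256 : 448};
}
```

Under the documented defaults (`none` and `all`), this resolves every role to standard NVFP4 with the 448 bound, which matches the default-constructed `NVFP4DequantizeTestConfig`.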

tests/cpp/test_common.cu

Lines changed: 12 additions & 0 deletions
@@ -440,6 +440,18 @@ void Tensor::set_row_scaled_nvfp4(bool row_scaled_nvfp4) {
   }
 }
 
+void Tensor::set_nvfp4_e4m3_max(int nvfp4_e4m3_max) {
+  NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
+             "NVFP4 E4M3 max is only supported for NVFP4 tensors.");
+  tensor_.set_nvfp4_e4m3_max(nvfp4_e4m3_max);
+}
+
+int Tensor::nvfp4_e4m3_max() const {
+  NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
+             "NVFP4 E4M3 max is only supported for NVFP4 tensors.");
+  return tensor_.get_nvfp4_e4m3_max();
+}
+
 void Tensor::to_cpu() {
   if (data_rowwise_) { data_rowwise_->to_cpu(); }
   if (data_columnwise_) { data_columnwise_->to_cpu(); }

tests/cpp/test_common.h

Lines changed: 3 additions & 0 deletions
@@ -293,10 +293,13 @@ class Tensor {
     return columnwise_;
   }
 
+  int nvfp4_e4m3_max() const;
+
   void set_tensor_amax_nullptr();
 
   void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales);
   void set_row_scaled_nvfp4(bool row_scaled_nvfp4);
+  void set_nvfp4_e4m3_max(int nvfp4_e4m3_max);
 
   void to_cpu();
   void from_cpu();
