Skip to content

Commit 3f64073

Browse files
Enable NVFP4 grouped MLP GLU RHT amax path (#3073)
* Enable NVFP4 grouped MLP GLU RHT amax path Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address NVFP4 GLU RHT amax review comments Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * Deduplicate grouped NVFP4 quantize helper Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Route precomputed amax through NVFP4 quantize Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * Handle empty NVFP4 precomputed amax reduction Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * Fix NVFP4 amax quantize binding signature Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> --------- Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fc92624 commit 3f64073

7 files changed

Lines changed: 350 additions & 87 deletions

File tree

transformer_engine/pytorch/csrc/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ class NVFP4Quantizer : public Quantizer {
380380

381381
void quantize(const TensorWrapper& input, TensorWrapper& out,
382382
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
383+
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
384+
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
383385

384386
/*! @brief Quantize to NVFP4, skipping local amax computation
385387
*
@@ -392,8 +394,6 @@ class NVFP4Quantizer : public Quantizer {
392394
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
393395

394396
private:
395-
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
396-
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
397397
void quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out,
398398
TensorWrapper& rht_output_t_cpp,
399399
QuantizationConfigWrapper& quant_config,

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,23 @@ py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector
335335
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
336336
std::optional<at::Tensor> noop_flag);
337337

338+
py::object nvfp4_quantize_with_amax(const at::Tensor &tensor, py::handle quantizer,
339+
const at::Tensor &rowwise_amax,
340+
const at::Tensor &columnwise_amax);
341+
338342
py::object dequantize(const py::handle &input, DType otype);
339343

340344
py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors,
341345
std::optional<at::Tensor> first_dims,
342346
std::optional<at::Tensor> tensor_offsets);
343347

348+
py::object nvfp4_group_quantize_with_amax(const at::Tensor &tensor, py::handle quantizer,
349+
const size_t num_tensors,
350+
std::optional<at::Tensor> first_dims,
351+
const at::Tensor &rowwise_amax,
352+
const at::Tensor &columnwise_amax,
353+
std::optional<at::Tensor> tensor_offsets);
354+
344355
py::object group_dequantize(const py::handle &input, DType otype);
345356

346357
py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 152 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
3131
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
3232
}
3333

34+
void allreduce_nvfp4_amax_tensors(NVFP4Quantizer *nvfp4_quantizer_cpp,
35+
std::vector<at::Tensor> &&amax_tensors) {
36+
if (!nvfp4_quantizer_cpp->with_amax_reduction || amax_tensors.empty()) {
37+
return;
38+
}
39+
c10d::AllreduceCoalescedOptions opts;
40+
opts.reduceOp = c10d::ReduceOp::MAX;
41+
NVTE_SCOPED_GIL_RELEASE({
42+
nvfp4_quantizer_cpp->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait();
43+
});
44+
}
45+
3446
} // namespace
3547

3648
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
@@ -71,6 +83,51 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
7183
return output_py;
7284
}
7385

86+
py::object nvfp4_quantize_with_amax(const at::Tensor &tensor, py::handle quantizer,
87+
const at::Tensor &rowwise_amax,
88+
const at::Tensor &columnwise_amax) {
89+
using namespace transformer_engine::pytorch::detail;
90+
init_extension();
91+
92+
NVTE_CHECK(tensor.dim() >= 2, "Tensor must be at least 2D");
93+
NVTE_CHECK(rowwise_amax.is_cuda() && columnwise_amax.is_cuda(),
94+
"Precomputed amax tensors must be CUDA tensors.");
95+
NVTE_CHECK(
96+
rowwise_amax.scalar_type() == at::kFloat && columnwise_amax.scalar_type() == at::kFloat,
97+
"Precomputed amax tensors must be float32.");
98+
NVTE_CHECK(rowwise_amax.numel() == 1 && columnwise_amax.numel() == 1,
99+
"nvfp4_quantize_with_amax expects scalar rowwise and columnwise amaxes.");
100+
101+
auto quantizer_cpp = convert_quantizer(quantizer);
102+
NVTE_CHECK(IsNVFP4Quantizers(quantizer.ptr()),
103+
"nvfp4_quantize_with_amax only supports NVFP4 quantizers.");
104+
NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
105+
106+
auto input_contiguous = tensor.contiguous();
107+
auto input_cpp = makeTransformerEngineTensor(input_contiguous);
108+
109+
const auto shape = get_tensor_shape(input_cpp);
110+
const auto fake_dtype = input_cpp.dtype();
111+
auto [output_cpp, output_py] = quantizer_cpp->create_tensor(shape, fake_dtype);
112+
113+
if (output_cpp.get_amax().data_ptr != nullptr) {
114+
output_cpp.set_amax(rowwise_amax.data_ptr(), DType::kFloat32, getTensorShape(rowwise_amax));
115+
output_py.attr("_amax_rowwise") = py::cast(rowwise_amax);
116+
}
117+
if (output_cpp.get_columnwise_amax().data_ptr != nullptr) {
118+
output_cpp.set_columnwise_amax(columnwise_amax.data_ptr(), DType::kFloat32,
119+
getTensorShape(columnwise_amax));
120+
output_py.attr("_amax_columnwise") = py::cast(columnwise_amax);
121+
}
122+
123+
nvfp4_quantizer_cpp->quantize_impl(input_cpp, output_cpp, std::nullopt, false);
124+
if (quantizer_cpp->optimize_for_gemm && !output_cpp.get_with_gemm_swizzled_scales()) {
125+
inplace_swizzle_scale_for_gemm(output_py);
126+
}
127+
128+
return output_py;
129+
}
130+
74131
py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector<size_t> &shape,
75132
at::ScalarType dtype, at::Device device, bool pin_memory) {
76133
auto quantizer_cpp = convert_quantizer(quantizer);
@@ -84,14 +141,19 @@ namespace {
84141
// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy)
85142
void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor,
86143
GroupedTensorWrapper &grouped_output_tensor,
87-
NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream) {
144+
NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream,
145+
bool compute_amax) {
88146
size_t num_tensors = grouped_input_tensor.num_tensors();
89147

90148
// assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet
91149
NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization,
92150
"2D scaling grouped quant kernel is not ready yet");
93151
NVTE_CHECK(nvfp4_quantizer_cpp->nvfp4_4over6_mode == kNVTENVFP44Over6Disabled,
94152
"NVFP4 4over6 quantization is not supported for grouped quantization.");
153+
NVTE_CHECK(nvfp4_quantizer_cpp->with_rht,
154+
"graph safe grouped quant kernel for non-RHT path is not ready yet");
155+
NVTE_CHECK(nvfp4_quantizer_cpp->with_post_rht_amax,
156+
"grouped NVFP4 RHT quantization expects post-RHT amax buffers.");
95157

96158
auto quant_config_cpp = QuantizationConfigWrapper();
97159

@@ -122,37 +184,24 @@ void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor,
122184
quant_config_cpp.set_use_fast_math(true);
123185
}
124186

125-
// so far, only the RHT path has grouped kernel support
126-
// grouped kernels for non-RHT path will be added later
127-
128-
if (nvfp4_quantizer_cpp->with_rht) {
129-
// post-RHT amax or not
130-
if (nvfp4_quantizer_cpp->with_post_rht_amax) {
131-
NVTE_SCOPED_GIL_RELEASE({
132-
nvte_group_hadamard_transform_amax_graph_safe(
133-
grouped_input_tensor.data(), grouped_output_tensor.data(), 0,
134-
nvfp4_quantizer_cpp->rht_matrix_random_sign_mask_t, stream);
135-
});
136-
} else {
137-
NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet");
138-
}
139-
140-
// RHT cast fusion
141-
auto tile_scheduler_workspace_torch =
142-
at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
143-
auto nvte_tile_scheduler_workspace =
144-
makeTransformerEngineTensor(tile_scheduler_workspace_torch);
145-
146-
auto rht_matrix_nvte = makeTransformerEngineTensor(nvfp4_quantizer_cpp->rht_matrix);
187+
if (compute_amax) {
147188
NVTE_SCOPED_GIL_RELEASE({
148-
nvte_group_hadamard_transform_cast_fusion_graph_safe(
149-
grouped_input_tensor.data(), grouped_output_tensor.data(), rht_matrix_nvte.data(),
150-
quant_config_cpp, nvte_tile_scheduler_workspace.data(), stream);
189+
nvte_group_hadamard_transform_amax_graph_safe(
190+
grouped_input_tensor.data(), grouped_output_tensor.data(), 0,
191+
nvfp4_quantizer_cpp->rht_matrix_random_sign_mask_t, stream);
151192
});
152-
153-
} else {
154-
NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet");
155193
}
194+
195+
// RHT cast fusion
196+
auto tile_scheduler_workspace_torch = at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
197+
auto nvte_tile_scheduler_workspace = makeTransformerEngineTensor(tile_scheduler_workspace_torch);
198+
199+
auto rht_matrix_nvte = makeTransformerEngineTensor(nvfp4_quantizer_cpp->rht_matrix);
200+
NVTE_SCOPED_GIL_RELEASE({
201+
nvte_group_hadamard_transform_cast_fusion_graph_safe(
202+
grouped_input_tensor.data(), grouped_output_tensor.data(), rht_matrix_nvte.data(),
203+
quant_config_cpp, nvte_tile_scheduler_workspace.data(), stream);
204+
});
156205
}
157206

158207
} // namespace
@@ -214,7 +263,7 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
214263
// NVFP4 grouped quantization
215264
NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
216265
group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp,
217-
nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream());
266+
nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream(), true);
218267
break;
219268
}
220269
case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE: {
@@ -234,6 +283,79 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
234283
return py::reinterpret_borrow<py::object>(grouped_output_py);
235284
}
236285

286+
py::object nvfp4_group_quantize_with_amax(const at::Tensor &tensor, py::handle quantizer,
287+
const size_t num_tensors,
288+
std::optional<at::Tensor> first_dims,
289+
const at::Tensor &rowwise_amax,
290+
const at::Tensor &columnwise_amax,
291+
std::optional<at::Tensor> tensor_offsets) {
292+
using namespace transformer_engine::pytorch::detail;
293+
init_extension();
294+
295+
NVTE_CHECK(tensor.dim() == 2, "Tensor must be 2D");
296+
NVTE_CHECK(rowwise_amax.is_cuda() && columnwise_amax.is_cuda(),
297+
"Precomputed amax tensors must be CUDA tensors.");
298+
NVTE_CHECK(
299+
rowwise_amax.scalar_type() == at::kFloat && columnwise_amax.scalar_type() == at::kFloat,
300+
"Precomputed amax tensors must be float32.");
301+
NVTE_CHECK(rowwise_amax.numel() == static_cast<int64_t>(num_tensors),
302+
"Rowwise amax must contain one value per group.");
303+
NVTE_CHECK(columnwise_amax.numel() == static_cast<int64_t>(num_tensors),
304+
"Columnwise amax must contain one value per group.");
305+
306+
std::vector<size_t> logical_shape;
307+
for (const auto &d : tensor.sizes()) {
308+
logical_shape.push_back(d);
309+
}
310+
const auto logical_first_dim = logical_shape[0];
311+
const auto logical_last_dim = logical_shape[1];
312+
313+
bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0;
314+
315+
auto quantizer_cpp = convert_quantizer(quantizer);
316+
NVTE_CHECK(IsNVFP4Quantizers(quantizer.ptr()),
317+
"nvfp4_group_quantize_with_amax only supports NVFP4 quantizers.");
318+
NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
319+
320+
auto grouped_input_tensor = GroupedTensorWrapper(num_tensors, logical_shape);
321+
grouped_input_tensor.set_rowwise_data(
322+
tensor.data_ptr(), GetTransformerEngineDType(tensor.scalar_type()), getTensorShape(tensor));
323+
324+
auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor(
325+
num_tensors, logical_shape, GetTransformerEngineDType(tensor.scalar_type()),
326+
py::reinterpret_borrow<py::object>(quantizer), first_dims, tensor_offsets, logical_first_dim,
327+
logical_last_dim);
328+
329+
if (grouped_output_tensor_cpp.get_amax().data_ptr != nullptr) {
330+
grouped_output_tensor_cpp.set_amax(rowwise_amax.data_ptr(), DType::kFloat32,
331+
getTensorShape(rowwise_amax));
332+
grouped_output_py.attr("amax") = py::cast(rowwise_amax);
333+
}
334+
if (grouped_output_tensor_cpp.get_columnwise_amax().data_ptr != nullptr) {
335+
grouped_output_tensor_cpp.set_columnwise_amax(columnwise_amax.data_ptr(), DType::kFloat32,
336+
getTensorShape(columnwise_amax));
337+
grouped_output_py.attr("columnwise_amax") = py::cast(columnwise_amax);
338+
}
339+
340+
std::vector<at::Tensor> amax_tensors;
341+
if (grouped_output_tensor_cpp.get_amax().data_ptr != nullptr) {
342+
amax_tensors.push_back(rowwise_amax);
343+
}
344+
if (grouped_output_tensor_cpp.get_columnwise_amax().data_ptr != nullptr) {
345+
amax_tensors.push_back(columnwise_amax);
346+
}
347+
allreduce_nvfp4_amax_tensors(nvfp4_quantizer_cpp, std::move(amax_tensors));
348+
349+
if (empty_input_buffer) {
350+
return py::reinterpret_borrow<py::object>(grouped_output_py);
351+
}
352+
353+
group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp, nvfp4_quantizer_cpp,
354+
at::cuda::getCurrentCUDAStream(), false);
355+
356+
return py::reinterpret_borrow<py::object>(grouped_output_py);
357+
}
358+
237359
py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
238360
const size_t num_tensors, std::optional<at::Tensor> first_dims,
239361
std::optional<at::Tensor> tensor_offsets) {

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,15 @@ void init_router_bindings(pybind11::module &m) {
164164
py::arg("grad_aux_loss"), "Fused aux loss bwd");
165165
}
166166

167+
void bind_quantize_with_amax_extensions(py::module_ &m) {
168+
m.def("nvfp4_quantize_with_amax", nvfp4_quantize_with_amax, py::arg("tensor"),
169+
py::arg("quantizer"), py::arg("rowwise_amax"), py::arg("columnwise_amax"));
170+
m.def("nvfp4_group_quantize_with_amax", nvfp4_group_quantize_with_amax, py::arg("tensor"),
171+
py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims"),
172+
py::arg("rowwise_amax"), py::arg("columnwise_amax"),
173+
py::arg("tensor_offsets") = py::none());
174+
}
175+
167176
} // namespace transformer_engine::pytorch
168177

169178
#include "common/util/pybind_helper.h"
@@ -195,6 +204,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
195204
m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"),
196205
py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims"),
197206
py::arg("tensor_offsets") = py::none());
207+
transformer_engine::pytorch::bind_quantize_with_amax_extensions(m);
198208
m.def("group_dequantize", transformer_engine::pytorch::group_dequantize,
199209
"Dequantize group tensor", py::arg("input"), py::arg("otype"));
200210
m.def("bgrad_group_quantize", transformer_engine::pytorch::bgrad_group_quantize,

transformer_engine/pytorch/csrc/quantizer.cpp

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,8 +2331,40 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper(
23312331
void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out,
23322332
const std::optional<TensorWrapper>& noop_flag,
23332333
bool compute_amax) {
2334+
auto reduce_amaxes = [&]() {
2335+
if (!this->with_amax_reduction) {
2336+
return;
2337+
}
2338+
2339+
std::vector<at::Tensor> amax_tensors;
2340+
auto make_amax_tensor = [](void* data_ptr) {
2341+
NVTE_CHECK(data_ptr != nullptr, "Could not find amax pointer for NVFP4 amax reduction.");
2342+
return at::from_blob(
2343+
data_ptr, std::vector<int64_t>{1},
2344+
[](void*) {}, // deleter doing nothing since it doesn't own the data
2345+
at::device(at::kCUDA).dtype(torch::kFloat32));
2346+
};
2347+
if (rowwise_usage) {
2348+
amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr));
2349+
}
2350+
if (columnwise_usage) {
2351+
amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr));
2352+
}
2353+
if (amax_tensors.empty()) {
2354+
return;
2355+
}
2356+
2357+
c10d::AllreduceCoalescedOptions opts;
2358+
opts.reduceOp = c10d::ReduceOp::MAX;
2359+
NVTE_SCOPED_GIL_RELEASE(
2360+
{ this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); });
2361+
};
2362+
23342363
// Nothing to be done if input is empty
23352364
if (input.numel() == 0) {
2365+
if (!compute_amax) {
2366+
reduce_amaxes();
2367+
}
23362368
return;
23372369
}
23382370

@@ -2431,10 +2463,12 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
24312463
// We need:
24322464
// 1. Rowwise amax = amax for input
24332465
// 2. Columnwise amax = amax for RHT(input.t)
2434-
NVTE_SCOPED_GIL_RELEASE({
2435-
nvte_hadamard_transform_amax(input.data(), out.data(), 0,
2436-
this->rht_matrix_random_sign_mask_t, stream);
2437-
});
2466+
if (compute_amax) {
2467+
NVTE_SCOPED_GIL_RELEASE({
2468+
nvte_hadamard_transform_amax(input.data(), out.data(), 0,
2469+
this->rht_matrix_random_sign_mask_t, stream);
2470+
});
2471+
}
24382472
} else {
24392473
// raise error since it's not supported yet
24402474
NVTE_ERROR(
@@ -2467,27 +2501,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
24672501
}
24682502
}
24692503

2470-
// amax reduction
2471-
if (this->with_amax_reduction) {
2472-
std::vector<at::Tensor> amax_tensors;
2473-
// push amax tensors inside if they need to be reduced
2474-
auto make_amax_tensor = [](void* data_ptr) {
2475-
return at::from_blob(
2476-
data_ptr, std::vector<int64_t>{1},
2477-
[](void*) {}, // deleter doing nothing since it doesn't own the data
2478-
at::device(at::kCUDA).dtype(torch::kFloat32));
2479-
};
2480-
if (rowwise_usage) {
2481-
amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr));
2482-
}
2483-
if (columnwise_usage) {
2484-
amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr));
2485-
}
2486-
c10d::AllreduceCoalescedOptions opts;
2487-
opts.reduceOp = c10d::ReduceOp::MAX;
2488-
NVTE_SCOPED_GIL_RELEASE(
2489-
{ this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); });
2490-
}
2504+
reduce_amaxes();
24912505

24922506
// Fast math toggle: RHT transform can be accelerated
24932507
// What math is accelerated? Only the high precision math, so numerical impact is minimal

0 commit comments

Comments
 (0)