@@ -31,6 +31,18 @@ std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
3131 return std::vector<size_t >(shape.data , shape.data + shape.ndim );
3232}
3333
34+ void allreduce_nvfp4_amax_tensors (NVFP4Quantizer *nvfp4_quantizer_cpp,
35+ std::vector<at::Tensor> &&amax_tensors) {
36+ if (!nvfp4_quantizer_cpp->with_amax_reduction || amax_tensors.empty ()) {
37+ return ;
38+ }
39+ c10d::AllreduceCoalescedOptions opts;
40+ opts.reduceOp = c10d::ReduceOp::MAX ;
41+ NVTE_SCOPED_GIL_RELEASE ({
42+ nvfp4_quantizer_cpp->amax_reduction_group ->allreduce_coalesced (amax_tensors, opts)->wait ();
43+ });
44+ }
45+
3446} // namespace
3547
3648py::object quantize (const at::Tensor &tensor, py::handle quantizer, const py::object &output,
@@ -71,6 +83,51 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
7183 return output_py;
7284}
7385
86+ py::object nvfp4_quantize_with_amax (const at::Tensor &tensor, py::handle quantizer,
87+ const at::Tensor &rowwise_amax,
88+ const at::Tensor &columnwise_amax) {
89+ using namespace transformer_engine ::pytorch::detail;
90+ init_extension ();
91+
92+ NVTE_CHECK (tensor.dim () >= 2 , " Tensor must be at least 2D" );
93+ NVTE_CHECK (rowwise_amax.is_cuda () && columnwise_amax.is_cuda (),
94+ " Precomputed amax tensors must be CUDA tensors." );
95+ NVTE_CHECK (
96+ rowwise_amax.scalar_type () == at::kFloat && columnwise_amax.scalar_type () == at::kFloat ,
97+ " Precomputed amax tensors must be float32." );
98+ NVTE_CHECK (rowwise_amax.numel () == 1 && columnwise_amax.numel () == 1 ,
99+ " nvfp4_quantize_with_amax expects scalar rowwise and columnwise amaxes." );
100+
101+ auto quantizer_cpp = convert_quantizer (quantizer);
102+ NVTE_CHECK (IsNVFP4Quantizers (quantizer.ptr ()),
103+ " nvfp4_quantize_with_amax only supports NVFP4 quantizers." );
104+ NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast <NVFP4Quantizer *>(quantizer_cpp.get ());
105+
106+ auto input_contiguous = tensor.contiguous ();
107+ auto input_cpp = makeTransformerEngineTensor (input_contiguous);
108+
109+ const auto shape = get_tensor_shape (input_cpp);
110+ const auto fake_dtype = input_cpp.dtype ();
111+ auto [output_cpp, output_py] = quantizer_cpp->create_tensor (shape, fake_dtype);
112+
113+ if (output_cpp.get_amax ().data_ptr != nullptr ) {
114+ output_cpp.set_amax (rowwise_amax.data_ptr (), DType::kFloat32 , getTensorShape (rowwise_amax));
115+ output_py.attr (" _amax_rowwise" ) = py::cast (rowwise_amax);
116+ }
117+ if (output_cpp.get_columnwise_amax ().data_ptr != nullptr ) {
118+ output_cpp.set_columnwise_amax (columnwise_amax.data_ptr (), DType::kFloat32 ,
119+ getTensorShape (columnwise_amax));
120+ output_py.attr (" _amax_columnwise" ) = py::cast (columnwise_amax);
121+ }
122+
123+ nvfp4_quantizer_cpp->quantize_impl (input_cpp, output_cpp, std::nullopt , false );
124+ if (quantizer_cpp->optimize_for_gemm && !output_cpp.get_with_gemm_swizzled_scales ()) {
125+ inplace_swizzle_scale_for_gemm (output_py);
126+ }
127+
128+ return output_py;
129+ }
130+
74131py::object create_empty_quantized_tensor (py::handle quantizer, const std::vector<size_t > &shape,
75132 at::ScalarType dtype, at::Device device, bool pin_memory) {
76133 auto quantizer_cpp = convert_quantizer (quantizer);
@@ -84,14 +141,19 @@ namespace {
84141// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy)
85142void group_quantize_nvfp4_impl (const GroupedTensorWrapper &grouped_input_tensor,
86143 GroupedTensorWrapper &grouped_output_tensor,
87- NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream) {
144+ NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream,
145+ bool compute_amax) {
88146 size_t num_tensors = grouped_input_tensor.num_tensors ();
89147
90148 // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet
91149 NVTE_CHECK (!nvfp4_quantizer_cpp->with_2d_quantization ,
92150 " 2D scaling grouped quant kernel is not ready yet" );
93151 NVTE_CHECK (nvfp4_quantizer_cpp->nvfp4_4over6_mode == kNVTENVFP44Over6Disabled ,
94152 " NVFP4 4over6 quantization is not supported for grouped quantization." );
153+ NVTE_CHECK (nvfp4_quantizer_cpp->with_rht ,
154+ " graph safe grouped quant kernel for non-RHT path is not ready yet" );
155+ NVTE_CHECK (nvfp4_quantizer_cpp->with_post_rht_amax ,
156+ " grouped NVFP4 RHT quantization expects post-RHT amax buffers." );
95157
96158 auto quant_config_cpp = QuantizationConfigWrapper ();
97159
@@ -122,37 +184,24 @@ void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor,
122184 quant_config_cpp.set_use_fast_math (true );
123185 }
124186
125- // so far, only the RHT path has grouped kernel support
126- // grouped kernels for non-RHT path will be added later
127-
128- if (nvfp4_quantizer_cpp->with_rht ) {
129- // post-RHT amax or not
130- if (nvfp4_quantizer_cpp->with_post_rht_amax ) {
131- NVTE_SCOPED_GIL_RELEASE ({
132- nvte_group_hadamard_transform_amax_graph_safe (
133- grouped_input_tensor.data (), grouped_output_tensor.data (), 0 ,
134- nvfp4_quantizer_cpp->rht_matrix_random_sign_mask_t , stream);
135- });
136- } else {
137- NVTE_ERROR (" graph safe grouped quant kernel for non-RHT path is not ready yet" );
138- }
139-
140- // RHT cast fusion
141- auto tile_scheduler_workspace_torch =
142- at::empty ({1 }, at::device (at::kCUDA ).dtype (torch::kInt32 ));
143- auto nvte_tile_scheduler_workspace =
144- makeTransformerEngineTensor (tile_scheduler_workspace_torch);
145-
146- auto rht_matrix_nvte = makeTransformerEngineTensor (nvfp4_quantizer_cpp->rht_matrix );
187+ if (compute_amax) {
147188 NVTE_SCOPED_GIL_RELEASE ({
148- nvte_group_hadamard_transform_cast_fusion_graph_safe (
149- grouped_input_tensor.data (), grouped_output_tensor.data (), rht_matrix_nvte. data () ,
150- quant_config_cpp, nvte_tile_scheduler_workspace. data () , stream);
189+ nvte_group_hadamard_transform_amax_graph_safe (
190+ grouped_input_tensor.data (), grouped_output_tensor.data (), 0 ,
191+ nvfp4_quantizer_cpp-> rht_matrix_random_sign_mask_t , stream);
151192 });
152-
153- } else {
154- NVTE_ERROR (" graph safe grouped quant kernel for non-RHT path is not ready yet" );
155193 }
194+
195+ // RHT cast fusion
196+ auto tile_scheduler_workspace_torch = at::empty ({1 }, at::device (at::kCUDA ).dtype (torch::kInt32 ));
197+ auto nvte_tile_scheduler_workspace = makeTransformerEngineTensor (tile_scheduler_workspace_torch);
198+
199+ auto rht_matrix_nvte = makeTransformerEngineTensor (nvfp4_quantizer_cpp->rht_matrix );
200+ NVTE_SCOPED_GIL_RELEASE ({
201+ nvte_group_hadamard_transform_cast_fusion_graph_safe (
202+ grouped_input_tensor.data (), grouped_output_tensor.data (), rht_matrix_nvte.data (),
203+ quant_config_cpp, nvte_tile_scheduler_workspace.data (), stream);
204+ });
156205}
157206
158207} // namespace
@@ -214,7 +263,7 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
214263 // NVFP4 grouped quantization
215264 NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast <NVFP4Quantizer *>(quantizer_cpp.get ());
216265 group_quantize_nvfp4_impl (grouped_input_tensor, grouped_output_tensor_cpp,
217- nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream ());
266+ nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream (), true );
218267 break ;
219268 }
220269 case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE : {
@@ -234,6 +283,79 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
234283 return py::reinterpret_borrow<py::object>(grouped_output_py);
235284}
236285
286+ py::object nvfp4_group_quantize_with_amax (const at::Tensor &tensor, py::handle quantizer,
287+ const size_t num_tensors,
288+ std::optional<at::Tensor> first_dims,
289+ const at::Tensor &rowwise_amax,
290+ const at::Tensor &columnwise_amax,
291+ std::optional<at::Tensor> tensor_offsets) {
292+ using namespace transformer_engine ::pytorch::detail;
293+ init_extension ();
294+
295+ NVTE_CHECK (tensor.dim () == 2 , " Tensor must be 2D" );
296+ NVTE_CHECK (rowwise_amax.is_cuda () && columnwise_amax.is_cuda (),
297+ " Precomputed amax tensors must be CUDA tensors." );
298+ NVTE_CHECK (
299+ rowwise_amax.scalar_type () == at::kFloat && columnwise_amax.scalar_type () == at::kFloat ,
300+ " Precomputed amax tensors must be float32." );
301+ NVTE_CHECK (rowwise_amax.numel () == static_cast <int64_t >(num_tensors),
302+ " Rowwise amax must contain one value per group." );
303+ NVTE_CHECK (columnwise_amax.numel () == static_cast <int64_t >(num_tensors),
304+ " Columnwise amax must contain one value per group." );
305+
306+ std::vector<size_t > logical_shape;
307+ for (const auto &d : tensor.sizes ()) {
308+ logical_shape.push_back (d);
309+ }
310+ const auto logical_first_dim = logical_shape[0 ];
311+ const auto logical_last_dim = logical_shape[1 ];
312+
313+ bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0 ;
314+
315+ auto quantizer_cpp = convert_quantizer (quantizer);
316+ NVTE_CHECK (IsNVFP4Quantizers (quantizer.ptr ()),
317+ " nvfp4_group_quantize_with_amax only supports NVFP4 quantizers." );
318+ NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast <NVFP4Quantizer *>(quantizer_cpp.get ());
319+
320+ auto grouped_input_tensor = GroupedTensorWrapper (num_tensors, logical_shape);
321+ grouped_input_tensor.set_rowwise_data (
322+ tensor.data_ptr (), GetTransformerEngineDType (tensor.scalar_type ()), getTensorShape (tensor));
323+
324+ auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor (
325+ num_tensors, logical_shape, GetTransformerEngineDType (tensor.scalar_type ()),
326+ py::reinterpret_borrow<py::object>(quantizer), first_dims, tensor_offsets, logical_first_dim,
327+ logical_last_dim);
328+
329+ if (grouped_output_tensor_cpp.get_amax ().data_ptr != nullptr ) {
330+ grouped_output_tensor_cpp.set_amax (rowwise_amax.data_ptr (), DType::kFloat32 ,
331+ getTensorShape (rowwise_amax));
332+ grouped_output_py.attr (" amax" ) = py::cast (rowwise_amax);
333+ }
334+ if (grouped_output_tensor_cpp.get_columnwise_amax ().data_ptr != nullptr ) {
335+ grouped_output_tensor_cpp.set_columnwise_amax (columnwise_amax.data_ptr (), DType::kFloat32 ,
336+ getTensorShape (columnwise_amax));
337+ grouped_output_py.attr (" columnwise_amax" ) = py::cast (columnwise_amax);
338+ }
339+
340+ std::vector<at::Tensor> amax_tensors;
341+ if (grouped_output_tensor_cpp.get_amax ().data_ptr != nullptr ) {
342+ amax_tensors.push_back (rowwise_amax);
343+ }
344+ if (grouped_output_tensor_cpp.get_columnwise_amax ().data_ptr != nullptr ) {
345+ amax_tensors.push_back (columnwise_amax);
346+ }
347+ allreduce_nvfp4_amax_tensors (nvfp4_quantizer_cpp, std::move (amax_tensors));
348+
349+ if (empty_input_buffer) {
350+ return py::reinterpret_borrow<py::object>(grouped_output_py);
351+ }
352+
353+ group_quantize_nvfp4_impl (grouped_input_tensor, grouped_output_tensor_cpp, nvfp4_quantizer_cpp,
354+ at::cuda::getCurrentCUDAStream (), false );
355+
356+ return py::reinterpret_borrow<py::object>(grouped_output_py);
357+ }
358+
237359py::object bgrad_group_quantize (const at::Tensor &tensor, py::handle quantizer,
238360 const size_t num_tensors, std::optional<at::Tensor> first_dims,
239361 std::optional<at::Tensor> tensor_offsets) {
0 commit comments