Skip to content

Commit 4631d97

Browse files
ptrendx, pre-commit-ci[bot], vthumbe1503
authored
[pyTorch] Replace the make_empty implementation to use C++ implementation (#2666)
* Replace the make_empty implementation to use C++ implementation for the known quantizers
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Fix lint
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Handle the device passed as string
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Fix
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Fix
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Fixes
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Replace the make_empty implementation to use C++ implementation for the known quantizers
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Fix lint
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Handle the device passed as string
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Fix
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Fixes
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Fix duplicate create_empty_quantized_tensor after merge
  The merge with main introduced duplicate function definition, declaration, and pybind registration for create_empty_quantized_tensor. Remove the duplicates.
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Fix device index resolution in create_tensor
  Change the device parameter from at::Device with default torch::kCUDA to std::optional<at::Device> with default nullopt. When no device is specified, resolve to the current CUDA device via c10::cuda::current_device(), ensuring the device always has a valid index. This fixes autograd engine assertions when tensors created without an explicit device are used in backward passes.
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Guard make_empty for custom quantizers without C++ converter
  Custom quantizers that set self.custom = True and don't override make_empty() will now get a clear NotImplementedError instead of hitting an opaque C++ NVTE_ERROR("Unexpected type for quantizer").
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Fix the device from the passed data case
  Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
1 parent c3a1d30 commit 4631d97

10 files changed

Lines changed: 142 additions & 352 deletions

File tree

transformer_engine/pytorch/csrc/common.h

Lines changed: 25 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -102,8 +102,9 @@ class Quantizer {
102102
virtual void set_quantization_params(TensorWrapper* tensor) const = 0;
103103

104104
/*! @brief Construct a tensor with uninitialized data */
105-
virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
106-
DType dtype) const = 0;
105+
virtual std::pair<TensorWrapper, py::object> create_tensor(
106+
const std::vector<size_t>& shape, DType dtype,
107+
std::optional<at::Device> device = std::nullopt, bool pin_memory = false) const = 0;
107108

108109
/*! @brief Construct a grouped tensor with uninitialized data */
109110
virtual std::pair<GroupedTensorWrapper, py::object> create_grouped_tensor(
@@ -144,8 +145,9 @@ class NoneQuantizer : public Quantizer {
144145

145146
void set_quantization_params(TensorWrapper* tensor) const override {}
146147

147-
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
148-
DType dtype) const override;
148+
std::pair<TensorWrapper, py::object> create_tensor(
149+
const std::vector<size_t>& shape, DType dtype,
150+
std::optional<at::Device> device = std::nullopt, bool pin_memory = false) const override;
149151

150152
std::pair<GroupedTensorWrapper, py::object> create_grouped_tensor(
151153
size_t num_tensors, const std::vector<size_t>& logical_shape, DType dtype,
@@ -174,19 +176,20 @@ class Float8Quantizer : public Quantizer {
174176

175177
void set_quantization_params(TensorWrapper* tensor) const override;
176178

177-
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
178-
DType dtype) const override;
179+
std::pair<TensorWrapper, py::object> create_tensor(
180+
const std::vector<size_t>& shape, DType dtype,
181+
std::optional<at::Device> device = std::nullopt, bool pin_memory = false) const override;
179182

180183
std::pair<GroupedTensorWrapper, py::object> create_grouped_tensor(
181184
size_t num_tensors, const std::vector<size_t>& logical_shape, DType dtype,
182185
py::object quantizer, const std::optional<at::Tensor>& first_dims, size_t logical_first_dim,
183186
size_t logical_last_dim) const override;
184187

185188
/*! @brief Construct a tensor with pre-initialized data */
186-
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
187-
std::optional<at::Tensor> data,
188-
std::optional<at::Tensor> transpose,
189-
std::optional<at::Tensor> scale_inv) const;
189+
std::pair<TensorWrapper, py::object> create_tensor(
190+
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data,
191+
std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv,
192+
std::optional<at::Device> device = std::nullopt, bool pin_memory = false) const;
190193

191194
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
192195

@@ -208,8 +211,9 @@ class Float8CurrentScalingQuantizer : public Quantizer {
208211

209212
void set_quantization_params(TensorWrapper* tensor) const override;
210213

211-
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
212-
DType dtype) const override;
214+
std::pair<TensorWrapper, py::object> create_tensor(
215+
const std::vector<size_t>& shape, DType dtype,
216+
std::optional<at::Device> device = std::nullopt, bool pin_memory = false) const override;
213217

214218
std::pair<GroupedTensorWrapper, py::object> create_grouped_tensor(
215219
size_t num_tensors, const std::vector<size_t>& logical_shape, DType dtype,
@@ -270,8 +274,9 @@ class Float8BlockQuantizer : public Quantizer {
270274
// Create a python Float8BlockQuantized tensor and C++ wrapper
271275
// for the tensor. Should set quantized data, scales for rowwise
272276
// and optionally columnwise usage.
273-
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
274-
DType dtype) const override;
277+
std::pair<TensorWrapper, py::object> create_tensor(
278+
const std::vector<size_t>& shape, DType dtype,
279+
std::optional<at::Device> device = std::nullopt, bool pin_memory = false) const override;
275280

276281
std::pair<GroupedTensorWrapper, py::object> create_grouped_tensor(
277282
size_t num_tensors, const std::vector<size_t>& logical_shape, DType dtype,
@@ -294,8 +299,9 @@ class MXFP8Quantizer : public Quantizer {
294299

295300
void set_quantization_params(TensorWrapper* tensor) const override;
296301

297-
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
298-
DType dtype) const override;
302+
std::pair<TensorWrapper, py::object> create_tensor(
303+
const std::vector<size_t>& shape, DType dtype,
304+
std::optional<at::Device> device = std::nullopt, bool pin_memory = false) const override;
299305

300306
std::pair<GroupedTensorWrapper, py::object> create_grouped_tensor(
301307
size_t num_tensors, const std::vector<size_t>& logical_shape, DType dtype,
@@ -333,8 +339,9 @@ class NVFP4Quantizer : public Quantizer {
333339

334340
void set_quantization_params(TensorWrapper* tensor) const override;
335341

336-
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
337-
DType dtype) const override;
342+
std::pair<TensorWrapper, py::object> create_tensor(
343+
const std::vector<size_t>& shape, DType dtype,
344+
std::optional<at::Device> device = std::nullopt, bool pin_memory = false) const override;
338345

339346
std::pair<GroupedTensorWrapper, py::object> create_grouped_tensor(
340347
size_t num_tensors, const std::vector<size_t>& logical_shape, DType dtype,

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -320,9 +320,12 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
320320
std::optional<std::vector<size_t>> alignments = std::nullopt);
321321

322322
/***************************************************************************************************
323-
* Cast
323+
* Quantize
324324
**************************************************************************************************/
325325

326+
/*! @brief Create an empty quantized tensor through the given Python quantizer.
 *
 *  Takes the quantizer as a Python handle plus the requested shape, ATen
 *  scalar type, device and pinned-memory flag, and returns the resulting
 *  tensor as a Python object.
 */
py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector<size_t> &shape,
327+
at::ScalarType dtype, at::Device device, bool pin_memory);
328+
326329
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
327330
std::optional<at::Tensor> noop_flag);
328331

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,14 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
6565
return output_py;
6666
}
6767

68+
// Create an empty quantized tensor using the C++ counterpart of a Python quantizer.
//
// quantizer:  Python quantizer object; converted with convert_quantizer().
// shape:      shape forwarded to Quantizer::create_tensor().
// dtype:      ATen scalar type; mapped to a Transformer Engine DType via
//             GetTransformerEngineDType() before being forwarded.
// device:     device forwarded to create_tensor().
// pin_memory: pinned-memory flag forwarded to create_tensor().
//
// Returns the Python object half of the pair produced by create_tensor().
py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector<size_t> &shape,
69+
at::ScalarType dtype, at::Device device, bool pin_memory) {
70+
auto quantizer_cpp = convert_quantizer(quantizer);
71+
auto te_dtype = GetTransformerEngineDType(dtype);
72+
// Only the Python object is needed by the caller; the TensorWrapper half of the
// returned pair is bound to `_` and discarded.
auto [_, output_py] = quantizer_cpp->create_tensor(shape, te_dtype, device, pin_memory);
73+
return output_py;
74+
}
75+
6876
namespace {
6977

7078
// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy)

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
139139
py::arg("output") = py::none(), py::arg("noop") = py::none());
140140
m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
141141
py::arg("otype"));
142+
m.def("create_empty_quantized_tensor",
143+
&transformer_engine::pytorch::create_empty_quantized_tensor,
144+
"Create an empty quantized tensor", py::arg("quantizer"), py::arg("shape"),
145+
py::arg("dtype"), py::arg("device"), py::arg("pin_memory"));
142146
m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"),
143147
py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims"));
144148
m.def("group_dequantize", transformer_engine::pytorch::group_dequantize,

0 commit comments

Comments
 (0)