Skip to content

Commit a30a126

Browse files
authored
Fix zero input shape for bgrad_group_quantize (NVIDIA#2854)
Fix zero input shape for dbias. Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
1 parent 77b8681 commit a30a126

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

tests/pytorch/test_grouped_tensor.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -410,6 +410,27 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]], output_dbias
410410
expected_dbias = torch.stack([t.sum(dim=0) for t in input_tensors])
411411
assert torch.allclose(dbias, expected_dbias)
412412

413+
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
414+
def test_bgrad_group_quantize_zero_size_tensor(self) -> None:
415+
"""Test bgrad_group_quantize handles zero-row input without error."""
416+
num_tensors = 3
417+
last_dim = 1024
418+
grouped_input = torch.empty(0, last_dim, dtype=torch.bfloat16, device="cuda")
419+
420+
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
421+
quantizer.set_usage(rowwise=True, columnwise=False)
422+
first_dims = torch.zeros(num_tensors, dtype=torch.int64, device="cuda")
423+
424+
grouped_output, dbias = tex.bgrad_group_quantize(
425+
grouped_input,
426+
quantizer,
427+
num_tensors,
428+
first_dims,
429+
)
430+
431+
assert dbias.shape == (num_tensors, last_dim)
432+
assert torch.all(dbias == 0)
433+
413434
@pytest.mark.parametrize("output_dbias", [False, True])
414435
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
415436
def test_group_quantize_cudagraph_capturable(self, output_dbias: bool) -> None:

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -247,8 +247,7 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
247247
const auto logical_first_dim = logical_shape[0];
248248
const auto logical_last_dim = logical_shape[1];
249249

250-
NVTE_CHECK(logical_first_dim > 0 && logical_last_dim > 0,
251-
"bgrad_group_quantize: empty input tensor is not supported.");
250+
bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0;
252251

253252
NVTE_CHECK(detail::IsMXFP8Quantizers(quantizer.ptr()),
254253
"bgrad_group_quantize: only MXFP8 quantizer is supported.");
@@ -264,6 +263,14 @@ py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
264263
py::reinterpret_borrow<py::object>(quantizer), first_dims, logical_first_dim,
265264
logical_last_dim);
266265

266+
if (empty_input_buffer) {
267+
at::Tensor dbias_torch =
268+
at::zeros({static_cast<int64_t>(num_tensors), static_cast<int64_t>(logical_last_dim)},
269+
tensor.options());
270+
return py::make_tuple(py::reinterpret_borrow<py::object>(grouped_output_py),
271+
py::cast(std::move(dbias_torch)));
272+
}
273+
267274
const std::vector<size_t> dbias_logical_shape = {num_tensors, logical_last_dim};
268275
GroupedTensorWrapper grouped_dbias(num_tensors, dbias_logical_shape, NVTE_DELAYED_TENSOR_SCALING);
269276
at::Tensor dbias_torch =

0 commit comments

Comments (0)