
Commit eda660a

danisereb authored and meenchen committed
Fix TensorQuantizer to handle MXFP8.
Tested by test_qtensor_accuracy. Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
1 parent 69f9a76 commit eda660a
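
Context for the fix: before this commit, _real_quantize tested self._num_bits == (4, 3) before the MX branch, so MXFP8 tensors (which share the (4, 3) bit layout with plain FP8) were routed down the non-MX FP8 path. A minimal sketch of the dispatch problem, using a hypothetical block_sizes dict shaped like the one the diff reads (an axis entry plus "type" and "scale_bits" keys; the exact modelopt config values are assumptions):

    # Hypothetical MXFP8-style settings; keys mirror what the diff checks,
    # the concrete values (axis -1, block size 32) are assumptions.
    num_bits = (4, 3)
    block_sizes = {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}

    def is_mx(bs):
        # The MX test from the diff: dynamic per-block E8M0 scales.
        return bool(bs) and bs.get("scale_bits") == (8, 0) and bs.get("type") == "dynamic"

    # Old order: FP8 was tested first, so MXFP8 never reached the MX branch.
    old_path = "FP8" if num_bits == (4, 3) else ("MX" if is_mx(block_sizes) else "other")

    # New order (this commit): MX formats are checked first.
    new_path = "MX" if is_mx(block_sizes) else ("FP8" if num_bits == (4, 3) else "other")

    print(old_path, "->", new_path)  # FP8 -> MX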

1 file changed

Lines changed: 25 additions & 23 deletions

File tree

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

@@ -650,8 +650,31 @@ def _real_quantize(self, inputs):
         assert self._is_real_quantize_support(), "Real quantization not supported for this format."
 
         buffer_to_register = {}
-        if self._num_bits == (4, 3):
-            # FP8 quantization
+        # Check MX formats first (before FP8) since MXFP8 also has num_bits=(4,3)
+        if (
+            self._block_sizes
+            and self._block_sizes.get("scale_bits") == (8, 0)
+            and self._block_sizes.get("type") == "dynamic"
+        ):
+            # MX quantization (MXFP4/MXFP8)
+            if self._num_bits == (2, 1):
+                # MXFP4
+                outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
+                buffer_to_register["_scale"] = scales
+            elif self._num_bits == (4, 3):
+                # MXFP8
+                assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, (
+                    f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, "
+                    f"got {self._block_sizes[-1]}"
+                )
+                outputs, scales = MXFP8QTensor.quantize(inputs)
+                buffer_to_register["_scale"] = scales
+            else:
+                raise ValueError(
+                    f"Real quantization for MX {self._num_bits} format is not supported."
+                )
+        elif self._num_bits == (4, 3):
+            # FP8 quantization (non-MX)
             # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks
             # For blockwise quantization, amax will be recomputed in the kernel
             use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1)
@@ -684,27 +707,6 @@ def _real_quantize(self, inputs):
             buffer_to_register["_scale"] = _scale
             buffer_to_register["_double_scale"] = _double_scale
             buffer_to_register["_scale_zeros"] = _scale_zeros
-        elif (
-            self._block_sizes.get("scale_bits") == (8, 0)
-            and self._block_sizes.get("type") == "dynamic"
-        ):
-            # MX quantization
-            if self._num_bits == (2, 1):
-                # MXFP4
-                outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
-                buffer_to_register["_scale"] = scales
-            elif self._num_bits == (4, 3):
-                # MXFP8
-                assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, (
-                    f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, "
-                    f"got {self._block_sizes[-1]}"
-                )
-                outputs, scales = MXFP8QTensor.quantize(inputs)
-                buffer_to_register["_scale"] = scales
-            else:
-                raise ValueError(
-                    f"Real quantization for MX {self._num_bits} format is not supported."
-                )
         elif self._block_sizes.get("scale_bits") == (4, 3):
             # NVFP4 default quantization
             # Return real quantized tensor and store scales inside TensorQuantizer
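
For readers unfamiliar with the MX branch: MXFP8 stores each block of elements as FP8 E4M3 values plus one shared power-of-two (E8M0) scale, which is why the diff asserts self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE and registers the returned scales as a "_scale" buffer. Below is a minimal sketch of that scheme, not the modelopt kernel; the block size of 32 (per the OCP MX spec), the helper name, and the scale-selection rule are assumptions, and it requires a PyTorch build with fp8 dtypes:

    import torch

    BLOCK_SIZE = 32    # assumed MX block size (OCP MX spec)
    E4M3_MAX = 448.0   # largest finite FP8 E4M3 value

    def mxfp8_quantize_sketch(x: torch.Tensor):
        # Group the last dim into blocks and take each block's absolute max.
        blocks = x.reshape(*x.shape[:-1], -1, BLOCK_SIZE)
        amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2.0**-127)
        # One power-of-two (E8M0-style) scale per block, chosen so the
        # scaled block max stays within the E4M3 range.
        scale = torch.exp2(torch.ceil(torch.log2(amax / E4M3_MAX)))
        # Clamp is a safety net against rounding at the range boundary.
        q = (blocks / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
        return q.reshape(x.shape), scale.squeeze(-1)

    w = torch.randn(4, 64)
    q, scales = mxfp8_quantize_sketch(w)
    deq = q.float().reshape(4, -1, BLOCK_SIZE) * scales.unsqueeze(-1)
    print((deq.reshape(w.shape) - w).abs().max())  # small quantization error

This also illustrates why the branch order matters: nothing in the tensor values distinguishes MXFP8 from plain FP8; only the block_sizes metadata ("scale_bits" == (8, 0), "type" == "dynamic") signals that per-block E8M0 scales must be produced, so that check has to run before the generic num_bits == (4, 3) test.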
