|
49 | 49 | INT4QTensor, |
50 | 50 | INT8QTensor, |
51 | 51 | MXFP4QTensor, |
| 52 | + MXFP8QTensor, |
52 | 53 | NF4QTensor, |
53 | 54 | NVFP4QTensor, |
54 | 55 | QTensorWrapper, |
@@ -649,8 +650,32 @@ def _real_quantize(self, inputs): |
649 | 650 | assert self._is_real_quantize_support(), "Real quantization not supported for this format." |
650 | 651 |
|
651 | 652 | buffer_to_register = {} |
652 | | - if self._num_bits == (4, 3): |
653 | | - # FP8 quantization |
| 653 | + # Check MX formats first (before FP8) since MXFP8 also has num_bits=(4,3) |
| 654 | + if ( |
| 655 | + self._block_sizes |
| 656 | + and self._block_sizes.get("scale_bits") == (8, 0) |
| 657 | + and self._block_sizes.get("type") == "dynamic" |
| 658 | + ): |
| 659 | + # MX quantization (MXFP4/MXFP8) |
| 660 | + if self._num_bits == (2, 1): |
| 661 | + # MXFP4 |
| 662 | + outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) |
| 663 | + buffer_to_register["_scale"] = scales |
| 664 | + elif self._num_bits == (4, 3): |
| 665 | + # MXFP8 |
| 666 | + assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, ( |
| 667 | + f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, " |
| 668 | + f"got {self._block_sizes[-1]}" |
| 669 | + ) |
| 670 | + outputs, scales = MXFP8QTensor.quantize(inputs) |
| 671 | + buffer_to_register["_scale"] = scales |
| 672 | + else: |
| 673 | + raise ValueError( |
| 674 | + f"Unsupported MX format: num_bits={self._num_bits}. " |
| 675 | + f"Expected (2, 1) for MXFP4 or (4, 3) for MXFP8." |
| 676 | + ) |
| 677 | + elif self._num_bits == (4, 3): |
| 678 | + # FP8 quantization (non-MX) |
654 | 679 | # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks |
655 | 680 | # For blockwise quantization, amax will be recomputed in the kernel |
656 | 681 | use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1) |
@@ -683,18 +708,6 @@ def _real_quantize(self, inputs): |
683 | 708 | buffer_to_register["_scale"] = _scale |
684 | 709 | buffer_to_register["_double_scale"] = _double_scale |
685 | 710 | buffer_to_register["_scale_zeros"] = _scale_zeros |
686 | | - elif ( |
687 | | - self._block_sizes.get("scale_bits") == (8, 0) |
688 | | - and self._block_sizes.get("type") == "dynamic" |
689 | | - ): |
690 | | - # MX quantization |
691 | | - if self._num_bits == (2, 1): |
692 | | - outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) |
693 | | - buffer_to_register["_scale"] = scales |
694 | | - else: |
695 | | - raise ValueError( |
696 | | - f"Real quantization for MX {self._num_bits} format is not supported." |
697 | | - ) |
698 | 711 | elif self._block_sizes.get("scale_bits") == (4, 3): |
699 | 712 | # NVFP4 default quantization |
700 | 713 | # Return real quantized tensor and store scales inside TensorQuantizer |
|
0 commit comments