@@ -650,8 +650,31 @@ def _real_quantize(self, inputs):
         assert self._is_real_quantize_support(), "Real quantization not supported for this format."
 
         buffer_to_register = {}
-        if self._num_bits == (4, 3):
-            # FP8 quantization
+        # Check MX formats first (before FP8) since MXFP8 also has num_bits=(4, 3)
+        if (
+            self._block_sizes
+            and self._block_sizes.get("scale_bits") == (8, 0)
+            and self._block_sizes.get("type") == "dynamic"
+        ):
+            # MX quantization (MXFP4/MXFP8)
+            if self._num_bits == (2, 1):
+                # MXFP4
+                outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
+                buffer_to_register["_scale"] = scales
+            elif self._num_bits == (4, 3):
+                # MXFP8
+                assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, (
+                    f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, "
+                    f"got {self._block_sizes[-1]}"
+                )
+                outputs, scales = MXFP8QTensor.quantize(inputs)
+                buffer_to_register["_scale"] = scales
+            else:
+                raise ValueError(
+                    f"Real quantization for MX {self._num_bits} format is not supported."
+                )
+        elif self._num_bits == (4, 3):
+            # FP8 quantization (non-MX)
             # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks
             # For blockwise quantization, amax will be recomputed in the kernel
             use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1)
@@ -684,27 +707,6 @@ def _real_quantize(self, inputs):
             buffer_to_register["_scale"] = _scale
             buffer_to_register["_double_scale"] = _double_scale
             buffer_to_register["_scale_zeros"] = _scale_zeros
-        elif (
-            self._block_sizes.get("scale_bits") == (8, 0)
-            and self._block_sizes.get("type") == "dynamic"
-        ):
-            # MX quantization
-            if self._num_bits == (2, 1):
-                # MXFP4
-                outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
-                buffer_to_register["_scale"] = scales
-            elif self._num_bits == (4, 3):
-                # MXFP8
-                assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, (
-                    f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, "
-                    f"got {self._block_sizes[-1]}"
-                )
-                outputs, scales = MXFP8QTensor.quantize(inputs)
-                buffer_to_register["_scale"] = scales
-            else:
-                raise ValueError(
-                    f"Real quantization for MX {self._num_bits} format is not supported."
-                )
         elif self._block_sizes.get("scale_bits") == (4, 3):
             # NVFP4 default quantization
             # Return real quantized tensor and store scales inside TensorQuantizer
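
For context on why the branches were reordered: testing `num_bits == (4, 3)` first would route MXFP8 tensors into the plain FP8 path, since both formats use (4, 3) exponent/mantissa bits; only the `block_sizes` signature (8-bit exponent-only scales, dynamic per-block scaling) identifies an MX format. Below is a minimal standalone sketch of the resulting dispatch order; the `pick_format` helper is hypothetical and not part of the library:

# A minimal sketch (hypothetical helper, not the library API) of the
# dispatch ordering the diff establishes: MX formats are detected by their
# block_sizes signature *before* num_bits is consulted.

def pick_format(num_bits, block_sizes):
    """Return the name of the quantization branch that would be taken."""
    is_mx = (
        bool(block_sizes)
        and block_sizes.get("scale_bits") == (8, 0)
        and block_sizes.get("type") == "dynamic"
    )
    if is_mx:
        if num_bits == (2, 1):
            return "MXFP4"
        if num_bits == (4, 3):
            return "MXFP8"
        raise ValueError(f"Real quantization for MX {num_bits} format is not supported.")
    if num_bits == (4, 3):
        return "FP8"
    return "other"

# MXFP8 and FP8 carry identical num_bits; only block_sizes tells them apart.
assert pick_format((4, 3), {"scale_bits": (8, 0), "type": "dynamic"}) == "MXFP8"
assert pick_format((4, 3), None) == "FP8"
assert pick_format((2, 1), {"scale_bits": (8, 0), "type": "dynamic"}) == "MXFP4"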