@@ -250,14 +250,6 @@ def test_amax_from_tensor_quantizer(
250250 torch .randn ([512 , 512 ], dtype = torch .float32 ),
251251 None ,
252252 ),
253- # MXFP8
254- (
255- (4 , 3 ),
256- {- 1 : 32 , "type" : "dynamic" , "scale_bits" : (8 , 0 )},
257- None ,
258- torch .randn ([512 , 512 ], dtype = torch .float32 ),
259- None ,
260- ),
261253 ],
262254 )
263255 @pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
@@ -909,3 +901,47 @@ def test_mxfp8_dequantize_default_dtype(self, device, input_dtype):
909901 dequant = qtensor .dequantize (scale = e8m0_scale )
910902
911903 assert dequant .dtype == input_dtype
904+
905+ @pytest .mark .parametrize ("device" , ["cuda" ])
906+ @pytest .mark .parametrize ("input_dtype" , [torch .float32 , torch .float16 , torch .bfloat16 ])
907+ @pytest .mark .parametrize (
908+ "input_shape" ,
909+ [
910+ (64 , 64 ),
911+ (128 , 128 ),
912+ (4 , 64 , 128 ), # 3D MoE shape
913+ ],
914+ )
915+ def test_mxfp8_fake_quant (self , device , input_dtype , input_shape ):
916+ """Test MXFP8 fake quantization via TensorQuantizer matches real quant+dequant."""
917+ block_sizes = {- 1 : 32 , "type" : "dynamic" , "scale_bits" : (8 , 0 )}
918+
919+ # Create fake quant quantizer
920+ fake_quant_cfg = QuantizerAttributeConfig (
921+ num_bits = (4 , 3 ), block_sizes = block_sizes , fake_quant = True , axis = None
922+ )
923+ fake_quantizer = TensorQuantizer (fake_quant_cfg ).to (device )
924+
925+ # Create real quant quantizer
926+ real_quant_cfg = QuantizerAttributeConfig (
927+ num_bits = (4 , 3 ), block_sizes = block_sizes , fake_quant = False , axis = None
928+ )
929+ real_quantizer = TensorQuantizer (real_quant_cfg ).to (device )
930+
931+ # Test tensor
932+ test_tensor = torch .randn (input_shape , dtype = input_dtype , device = device )
933+
934+ # Fake quant output
935+ fake_quant_output = fake_quantizer (test_tensor )
936+
937+ # Real quant + dequant
938+ q_tensor = real_quantizer (test_tensor )
939+ real_dequant_output = real_quantizer (q_tensor )
940+
941+ # Verify fake quant matches real quant+dequant
942+ assert fake_quant_output .shape == test_tensor .shape
943+ assert fake_quant_output .dtype == test_tensor .dtype
944+ assert torch .allclose (fake_quant_output , real_dequant_output , rtol = 5e-2 , atol = 5e-2 ), (
945+ f"Fake quant differs from real quant+dequant: "
946+ f"max diff = { (fake_quant_output - real_dequant_output ).abs ().max ()} "
947+ )
0 commit comments