Skip to content

Commit cb3a0de

Browse files
daniserebmeenchen
authored and committed
Add fake quant test for MXFP8
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
1 parent 5d45ba1 commit cb3a0de

File tree

1 file changed

+44
-8
lines changed

1 file changed

+44
-8
lines changed

tests/gpu/torch/quantization/test_qtensor_cuda.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -250,14 +250,6 @@ def test_amax_from_tensor_quantizer(
250250
torch.randn([512, 512], dtype=torch.float32),
251251
None,
252252
),
253-
# MXFP8
254-
(
255-
(4, 3),
256-
{-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
257-
None,
258-
torch.randn([512, 512], dtype=torch.float32),
259-
None,
260-
),
261253
],
262254
)
263255
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -909,3 +901,47 @@ def test_mxfp8_dequantize_default_dtype(self, device, input_dtype):
909901
dequant = qtensor.dequantize(scale=e8m0_scale)
910902

911903
assert dequant.dtype == input_dtype
904+
905+
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
    "input_shape",
    [
        (64, 64),
        (128, 128),
        (4, 64, 128),  # 3D MoE shape
    ],
)
def test_mxfp8_fake_quant(self, device, input_dtype, input_shape):
    """MXFP8 fake quantization through TensorQuantizer must agree with an
    explicit real-quantize -> dequantize round trip on the same input.
    """
    mx_block = {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}

    def build_quantizer(use_fake):
        # Identical MXFP8 config (E4M3 elements, block-32 dynamic E8M0
        # scales); only the fake_quant flag differs between the two
        # quantizers under test.
        cfg = QuantizerAttributeConfig(
            num_bits=(4, 3), block_sizes=mx_block, fake_quant=use_fake, axis=None
        )
        return TensorQuantizer(cfg).to(device)

    fake_quantizer = build_quantizer(True)
    real_quantizer = build_quantizer(False)

    sample = torch.randn(input_shape, dtype=input_dtype, device=device)

    # Path 1: fake quant returns a dequantized tensor in one call.
    fake_quant_output = fake_quantizer(sample)

    # Path 2: real quant yields a quantized tensor; feeding it back
    # through the quantizer dequantizes it.
    real_dequant_output = real_quantizer(real_quantizer(sample))

    # Fake quant must preserve shape/dtype and match the explicit
    # round trip within a tolerance appropriate for 8-bit formats.
    assert fake_quant_output.shape == sample.shape
    assert fake_quant_output.dtype == sample.dtype
    assert torch.allclose(fake_quant_output, real_dequant_output, rtol=5e-2, atol=5e-2), (
        f"Fake quant differs from real quant+dequant: "
        f"max diff = {(fake_quant_output - real_dequant_output).abs().max()}"
    )

0 commit comments

Comments
 (0)