Skip to content

Commit d9c22e5

Browse files
daniserebmeenchen
authored and committed
Add more tests for MXFP8 (error handling)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
1 parent ddfa953 commit d9c22e5

1 file changed

Lines changed: 83 additions & 0 deletions

File tree

tests/gpu/torch/quantization/test_qtensor_cuda.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,3 +804,86 @@ def test_mxfp8_get_weights_scaling_factor(self, device, input_shape):
804804
# The code clamps unbiased exponent to [-127, 127], giving biased range [0, 254]
805805
# Note: 255 (0xFF) represents NaN in E8M0 and should never appear from valid weights
806806
assert torch.all(e8m0_scale <= 254), "E8M0 scale contains NaN value (255)"
807+
808+
@pytest.mark.parametrize(
    ("amax_value", "expected_exponent"),
    [
        # Each pair is (input amax, expected E8M0 unbiased exponent).
        (0.0, -127.0),  # zero amax clamps to the minimum exponent
        (448.0, 0.0),  # E4M3_MAX maps exactly to exponent 0
        (1.0, -8.0),  # log2(1/448) ~ -8.8 rounds up (ceil) to -8
        (1e40, 127.0),  # huge amax saturates at the maximum exponent
        (1e-50, -127.0),  # tiny amax saturates at the minimum exponent
    ],
)
def test_mxfp8_compute_e8m0_exponent_edge_cases(self, amax_value, expected_exponent):
    """Test _compute_e8m0_exponent handles edge cases correctly."""
    amax = torch.tensor([amax_value], device="cuda")
    exponent = MXFP8QTensor._compute_e8m0_exponent(amax)
    # Compare the scalar result against the table above.
    assert exponent.item() == expected_exponent, (
        f"amax={amax_value} should give exponent {expected_exponent}, got {exponent.item()}"
    )
825+
826+
def test_mxfp8_get_weights_scaling_factor_asserts_1d_weight(self):
    """Test get_weights_scaling_factor raises assertion for 1D tensor."""
    # A rank-1 weight is invalid; the API requires at least two dimensions.
    vector_weight = torch.randn(64, device="cuda")
    with pytest.raises(AssertionError, match="Weight must be at least 2D"):
        MXFP8QTensor.get_weights_scaling_factor(vector_weight)
831+
832+
def test_mxfp8_get_weights_scaling_factor_asserts_non_divisible(self):
    """Test get_weights_scaling_factor raises assertion when dim not divisible by 32."""
    # Last dimension 33 deliberately breaks the 32-element block requirement.
    bad_weight = torch.randn(64, 33, device="cuda")
    with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"):
        MXFP8QTensor.get_weights_scaling_factor(bad_weight)
838+
839+
@pytest.mark.parametrize("device", ["cuda"])
def test_mxfp8_quantize_with_scale_asserts(self, device):
    """Test quantize_with_scale raises assertions for invalid inputs."""
    # Case 1: rank-1 weight must be rejected even with a plausible scale.
    rank1_weight = torch.randn(64, dtype=torch.float32, device=device)
    rank1_scale = torch.randint(0, 255, (2,), dtype=torch.uint8, device=device)
    with pytest.raises(AssertionError, match="Weight must be at least 2D"):
        MXFP8QTensor.quantize_with_scale(rank1_weight, rank1_scale)

    # Case 2: scale must be uint8 E8M0 — a float scale must be rejected.
    valid_weight = torch.randn(64, 64, dtype=torch.float32, device=device)
    float_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
    with pytest.raises(AssertionError, match="e8m0_scale must be"):
        MXFP8QTensor.quantize_with_scale(valid_weight, float_scale)

    # Case 3: last dim 33 is not a multiple of the 32-wide MXFP8 block.
    odd_dim_weight = torch.randn(64, 33, dtype=torch.float32, device=device)
    uint8_scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device)
    with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"):
        MXFP8QTensor.quantize_with_scale(odd_dim_weight, uint8_scale)
859+
860+
@pytest.mark.parametrize("device", ["cuda"])
def test_mxfp8_quantize_dequantize_asserts(self, device):
    """Test quantize and dequantize raise assertions for invalid inputs."""
    # quantize: zero-element input is rejected.
    no_elements = torch.empty(0, dtype=torch.float32, device=device)
    with pytest.raises(AssertionError, match="Input tensor must not be empty"):
        MXFP8QTensor.quantize(no_elements)

    # quantize: rank-0 (scalar) input is rejected.
    rank0 = torch.tensor(1.0, dtype=torch.float32, device=device)
    with pytest.raises(AssertionError, match="Input must have at least 1 dimension"):
        MXFP8QTensor.quantize(rank0)

    # quantize: integer dtypes are rejected.
    integer_input = torch.randint(0, 10, (32, 32), dtype=torch.int32, device=device)
    with pytest.raises(AssertionError, match="Input must be floating point"):
        MXFP8QTensor.quantize(integer_input)

    # Build a valid quantized tensor so the dequantize checks can run.
    source = torch.randn(64, 64, dtype=torch.float32, device=device)
    qtensor, _ = MXFP8QTensor.quantize(source)

    # dequantize: the 'scale' kwarg is mandatory.
    with pytest.raises(AssertionError, match="dequantize requires 'scale' in kwargs"):
        qtensor.dequantize(dtype=torch.float32)

    # dequantize: a non-uint8 scale must be rejected.
    float_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
    with pytest.raises(AssertionError, match="e8m0_scale must be"):
        qtensor.dequantize(dtype=torch.float32, scale=float_scale)

0 commit comments

Comments
 (0)