@@ -804,3 +804,86 @@ def test_mxfp8_get_weights_scaling_factor(self, device, input_shape):
        # The code clamps unbiased exponent to [-127, 127], giving biased range [0, 254]
        # Note: 255 (0xFF) represents NaN in E8M0 and should never appear from valid weights
        assert torch.all(e8m0_scale <= 254), "E8M0 scale contains NaN value (255)"

    @pytest.mark.parametrize(
        ("amax_value", "expected_exponent"),
        [
            (0.0, -127.0),  # Zero amax: minimum exponent
            (448.0, 0.0),  # E4M3_MAX: exponent 0
            (1.0, -8.0),  # log2(1/448) ~ -8.8, ceil = -8
            (1e40, 127.0),  # Very large amax: clamps to max
            (1e-50, -127.0),  # Very small amax: clamps to min
        ],
    )
    def test_mxfp8_compute_e8m0_exponent_edge_cases(self, amax_value, expected_exponent):
        """Test _compute_e8m0_exponent handles edge cases correctly."""
        amax = torch.tensor([amax_value], device="cuda")
        exponent = MXFP8QTensor._compute_e8m0_exponent(amax)
        assert exponent.item() == expected_exponent, (
            f"amax={amax_value} should give exponent {expected_exponent}, got {exponent.item()}"
        )

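The parametrized cases above pin down the behavior under test: the exponent is `ceil(log2(amax / E4M3_MAX))` clamped to the unbiased E8M0 range [-127, 127]. A minimal standalone sketch consistent with those cases follows; the helper name and the explicit formula are inferred from the test cases, not taken from `MXFP8QTensor`'s actual implementation:

```python
import torch

E4M3_MAX = 448.0  # largest E4M3 magnitude, per the (448.0, 0.0) case above

def compute_e8m0_exponent_sketch(amax: torch.Tensor) -> torch.Tensor:
    # ceil(log2(amax / E4M3_MAX)); log2(0) = -inf, so zero amax falls to the
    # lower clamp, and float32 overflow (e.g. 1e40 -> inf) hits the upper clamp.
    exponent = torch.ceil(torch.log2(amax / E4M3_MAX))
    return torch.clamp(exponent, min=-127.0, max=127.0)
```
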
    def test_mxfp8_get_weights_scaling_factor_asserts_1d_weight(self):
        """Test get_weights_scaling_factor raises assertion for 1D tensor."""
        weight_1d = torch.randn(64, device="cuda")
        with pytest.raises(AssertionError, match="Weight must be at least 2D"):
            MXFP8QTensor.get_weights_scaling_factor(weight_1d)

    def test_mxfp8_get_weights_scaling_factor_asserts_non_divisible(self):
        """Test get_weights_scaling_factor raises assertion when dim not divisible by 32."""
        # 33 is not divisible by 32
        weight = torch.randn(64, 33, device="cuda")
        with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"):
            MXFP8QTensor.get_weights_scaling_factor(weight)

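For contrast with the failing inputs above, here is a hypothetical valid call. It assumes one E8M0 scale entry per 32-element block along the last dimension (so a (64, 64) weight yields a (64, 2) scale, matching the shapes used elsewhere in these tests) and reuses the [0, 254] range check from the earlier test:

```python
import torch

# Assumes MXFP8QTensor is importable as in this test module.
weight = torch.randn(64, 64, device="cuda")  # last dim divisible by the block size 32
e8m0_scale = MXFP8QTensor.get_weights_scaling_factor(weight)
assert e8m0_scale.shape == (64, 2)  # one scale per 32-element block (assumed layout)
assert torch.all(e8m0_scale <= 254)  # 255 (0xFF) is the E8M0 NaN encoding
```
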
    @pytest.mark.parametrize("device", ["cuda"])
    def test_mxfp8_quantize_with_scale_asserts(self, device):
        """Test quantize_with_scale raises assertions for invalid inputs."""
        # Test 1D weight assertion
        weight_1d = torch.randn(64, dtype=torch.float32, device=device)
        scale = torch.randint(0, 255, (2,), dtype=torch.uint8, device=device)
        with pytest.raises(AssertionError, match="Weight must be at least 2D"):
            MXFP8QTensor.quantize_with_scale(weight_1d, scale)

        # Test wrong scale dtype assertion
        weight = torch.randn(64, 64, dtype=torch.float32, device=device)
        wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
        with pytest.raises(AssertionError, match="e8m0_scale must be"):
            MXFP8QTensor.quantize_with_scale(weight, wrong_dtype_scale)

        # Test non-divisible dimension assertion
        weight_bad_dim = torch.randn(64, 33, dtype=torch.float32, device=device)
        scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device)
        with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"):
            MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale)

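Conversely, a sketch of the call these assertions guard: a 2D float32 weight whose last dimension is divisible by 32, paired with a uint8 E8M0 scale. What `quantize_with_scale` returns is not pinned down by this diff, so the sketch only exercises the call:

```python
import torch

weight = torch.randn(64, 64, dtype=torch.float32, device="cuda")
scale = MXFP8QTensor.get_weights_scaling_factor(weight)  # uint8 E8M0 scale, assumed shape (64, 2)
quantized = MXFP8QTensor.quantize_with_scale(weight, scale)  # return type not shown in this diff
```
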
    @pytest.mark.parametrize("device", ["cuda"])
    def test_mxfp8_quantize_dequantize_asserts(self, device):
        """Test quantize and dequantize raise assertions for invalid inputs."""
        # Test empty tensor assertion
        empty_tensor = torch.empty(0, dtype=torch.float32, device=device)
        with pytest.raises(AssertionError, match="Input tensor must not be empty"):
            MXFP8QTensor.quantize(empty_tensor)

        # Test 0D tensor assertion
        scalar_tensor = torch.tensor(1.0, dtype=torch.float32, device=device)
        with pytest.raises(AssertionError, match="Input must have at least 1 dimension"):
            MXFP8QTensor.quantize(scalar_tensor)

        # Test non-floating point assertion
        int_tensor = torch.randint(0, 10, (32, 32), dtype=torch.int32, device=device)
        with pytest.raises(AssertionError, match="Input must be floating point"):
            MXFP8QTensor.quantize(int_tensor)

        # Create a valid quantized tensor for dequantize tests
        input_tensor = torch.randn(64, 64, dtype=torch.float32, device=device)
        qtensor, e8m0_scale = MXFP8QTensor.quantize(input_tensor)

        # Test missing scale assertion
        with pytest.raises(AssertionError, match="dequantize requires 'scale' in kwargs"):
            qtensor.dequantize(dtype=torch.float32)

        # Test wrong scale dtype assertion
        wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
        with pytest.raises(AssertionError, match="e8m0_scale must be"):
            qtensor.dequantize(dtype=torch.float32, scale=wrong_dtype_scale)