@@ -839,12 +839,6 @@ def test_mxfp8_get_weights_scaling_factor_asserts_non_divisible(self):
839 839     @pytest.mark.parametrize("device", ["cuda"])
840 840     def test_mxfp8_quantize_with_scale_asserts(self, device):
841 841         """Test quantize_with_scale raises assertions for invalid inputs."""
842     -       # Test 1D weight assertion
843     -       weight_1d = torch.randn(64, dtype=torch.float32, device=device)
844     -       scale = torch.randint(0, 255, (2,), dtype=torch.uint8, device=device)
845     -       with pytest.raises(AssertionError, match="Weight must be at least 2D"):
846     -           MXFP8QTensor.quantize_with_scale(weight_1d, scale)
847     -
848 842         # Test wrong scale dtype assertion
849 843         weight = torch.randn(64, 64, dtype=torch.float32, device=device)
850 844         wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
@@ -856,34 +850,3 @@ def test_mxfp8_quantize_with_scale_asserts(self, device):
856 850         scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device)
857 851         with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"):
858 852             MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale)
859     -
860     -   @pytest.mark.parametrize("device", ["cuda"])
861     -   def test_mxfp8_quantize_dequantize_asserts(self, device):
862     -       """Test quantize and dequantize raise assertions for invalid inputs."""
863     -       # Test empty tensor assertion
864     -       empty_tensor = torch.empty(0, dtype=torch.float32, device=device)
865     -       with pytest.raises(AssertionError, match="Input tensor must not be empty"):
866     -           MXFP8QTensor.quantize(empty_tensor)
867     -
868     -       # Test 0D tensor assertion
869     -       scalar_tensor = torch.tensor(1.0, dtype=torch.float32, device=device)
870     -       with pytest.raises(AssertionError, match="Input must have at least 1 dimension"):
871     -           MXFP8QTensor.quantize(scalar_tensor)
872     -
873     -       # Test non-floating point assertion
874     -       int_tensor = torch.randint(0, 10, (32, 32), dtype=torch.int32, device=device)
875     -       with pytest.raises(AssertionError, match="Input must be floating point"):
876     -           MXFP8QTensor.quantize(int_tensor)
877     -
878     -       # Create a valid quantized tensor for dequantize tests
879     -       input_tensor = torch.randn(64, 64, dtype=torch.float32, device=device)
880     -       qtensor, e8m0_scale = MXFP8QTensor.quantize(input_tensor)
881     -
882     -       # Test missing scale assertion
883     -       with pytest.raises(AssertionError, match="dequantize requires 'scale' in kwargs"):
884     -           qtensor.dequantize(dtype=torch.float32)
885     -
886     -       # Test wrong scale dtype assertion
887     -       wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
888     -       with pytest.raises(AssertionError, match="e8m0_scale must be"):
889     -           qtensor.dequantize(dtype=torch.float32, scale=wrong_dtype_scale)
0 commit comments