Skip to content

Commit 8c28bb5

Browse files
daniserebmeenchen
authored and committed
Remove excessive assertions in MXFP8QTensor
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
1 parent d9c22e5 commit 8c28bb5

2 files changed

Lines changed: 3 additions & 47 deletions

File tree

modelopt/torch/quantization/qtensor/mxfp8_tensor.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,14 @@ def quantize_with_scale(
148148
This method is useful for export paths where the scale has already been computed.
149149
150150
Args:
151-
weight: The weight tensor to quantize. Must be at least 2D.
151+
weight: The weight tensor to quantize. Must be at least 1D.
152152
e8m0_scale: E8M0 scale as uint8 biased exponent (bias = 127).
153-
Shape should be [..., out_dim, in_dim // 32].
153+
Shape should be [..., out_dim, in_dim // 32] for 2D+ tensors,
154+
or [in_dim // 32] for 1D tensors.
154155
155156
Returns:
156157
torch.Tensor: Quantized weight as float8_e4m3fn with same shape as input.
157158
"""
158-
assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D"
159159
assert e8m0_scale.dtype == cls.SCALE_DTYPE, (
160160
f"e8m0_scale must be {cls.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}"
161161
)
@@ -201,10 +201,6 @@ def quantize(cls, input: torch.Tensor) -> tuple:
201201
Returns:
202202
tuple: (MXFP8QTensor, e8m0_scale) where e8m0_scale is uint8 biased exponent.
203203
"""
204-
assert input.numel() > 0, "Input tensor must not be empty"
205-
assert input.dim() >= 1, f"Input must have at least 1 dimension, got {input.dim()}D"
206-
assert input.is_floating_point(), f"Input must be floating point, got {input.dtype}"
207-
208204
original_shape = input.shape
209205
original_dtype = input.dtype
210206

@@ -234,9 +230,6 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor:
234230
assert "scale" in kwargs, "dequantize requires 'scale' in kwargs"
235231

236232
e8m0_scale = kwargs["scale"]
237-
assert e8m0_scale.dtype == self.SCALE_DTYPE, (
238-
f"e8m0_scale must be {self.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}"
239-
)
240233

241234
if dtype is None:
242235
dtype = self.metadata["dtype"]

tests/gpu/torch/quantization/test_qtensor_cuda.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -839,12 +839,6 @@ def test_mxfp8_get_weights_scaling_factor_asserts_non_divisible(self):
839839
@pytest.mark.parametrize("device", ["cuda"])
840840
def test_mxfp8_quantize_with_scale_asserts(self, device):
841841
"""Test quantize_with_scale raises assertions for invalid inputs."""
842-
# Test 1D weight assertion
843-
weight_1d = torch.randn(64, dtype=torch.float32, device=device)
844-
scale = torch.randint(0, 255, (2,), dtype=torch.uint8, device=device)
845-
with pytest.raises(AssertionError, match="Weight must be at least 2D"):
846-
MXFP8QTensor.quantize_with_scale(weight_1d, scale)
847-
848842
# Test wrong scale dtype assertion
849843
weight = torch.randn(64, 64, dtype=torch.float32, device=device)
850844
wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
@@ -856,34 +850,3 @@ def test_mxfp8_quantize_with_scale_asserts(self, device):
856850
scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device)
857851
with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"):
858852
MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale)
859-
860-
@pytest.mark.parametrize("device", ["cuda"])
861-
def test_mxfp8_quantize_dequantize_asserts(self, device):
862-
"""Test quantize and dequantize raise assertions for invalid inputs."""
863-
# Test empty tensor assertion
864-
empty_tensor = torch.empty(0, dtype=torch.float32, device=device)
865-
with pytest.raises(AssertionError, match="Input tensor must not be empty"):
866-
MXFP8QTensor.quantize(empty_tensor)
867-
868-
# Test 0D tensor assertion
869-
scalar_tensor = torch.tensor(1.0, dtype=torch.float32, device=device)
870-
with pytest.raises(AssertionError, match="Input must have at least 1 dimension"):
871-
MXFP8QTensor.quantize(scalar_tensor)
872-
873-
# Test non-floating point assertion
874-
int_tensor = torch.randint(0, 10, (32, 32), dtype=torch.int32, device=device)
875-
with pytest.raises(AssertionError, match="Input must be floating point"):
876-
MXFP8QTensor.quantize(int_tensor)
877-
878-
# Create a valid quantized tensor for dequantize tests
879-
input_tensor = torch.randn(64, 64, dtype=torch.float32, device=device)
880-
qtensor, e8m0_scale = MXFP8QTensor.quantize(input_tensor)
881-
882-
# Test missing scale assertion
883-
with pytest.raises(AssertionError, match="dequantize requires 'scale' in kwargs"):
884-
qtensor.dequantize(dtype=torch.float32)
885-
886-
# Test wrong scale dtype assertion
887-
wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
888-
with pytest.raises(AssertionError, match="e8m0_scale must be"):
889-
qtensor.dequantize(dtype=torch.float32, scale=wrong_dtype_scale)

0 commit comments

Comments (0)