
Commit 69f9a76

daniserebrenik authored and meenchen committed
Cleanup code in get_weights_scaling_factor_from_quantizer of MXFP8QTensor
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
1 parent aaec5a1 · commit 69f9a76

2 files changed: 12 additions & 22 deletions


modelopt/torch/quantization/qtensor/mxfp8_tensor.py

Lines changed: 0 additions & 22 deletions
@@ -132,15 +132,6 @@ def get_weights_scaling_factor_from_quantizer(
         assert scale.dtype == cls.SCALE_DTYPE, (
             f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}"
         )
-
-        # Reshape if needed (same number of elements but wrong shape)
-        if scale.shape != expected_shape:
-            expected_numel = 1
-            for dim in expected_shape:
-                expected_numel *= dim
-            if scale.numel() == expected_numel:
-                scale = scale.reshape(expected_shape)
-
         assert scale.shape == expected_shape, (
             f"Scale shape {scale.shape} does not match expected shape {expected_shape}"
         )
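With the silent reshape removed, the scale coming out of the quantizer must already have the expected per-block shape, or the assertion fires. A minimal sketch of that shape contract, assuming the OCP MX block size of 32 (the class itself reads cls.BLOCK_SIZE, so the constant here is an assumption):

import torch

BLOCK_SIZE = 32  # assumption: OCP MX default; MXFP8QTensor uses cls.BLOCK_SIZE

def expected_scale_shape(weight: torch.Tensor) -> tuple[int, ...]:
    # One E8M0 scale per block of BLOCK_SIZE elements along the last dim.
    assert weight.shape[-1] % BLOCK_SIZE == 0
    return (*weight.shape[:-1], weight.shape[-1] // BLOCK_SIZE)

w = torch.randn(64, 64)
print(expected_scale_shape(w))  # (64, 2)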
@@ -179,12 +170,6 @@ def quantize_with_scale(
             f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})"
         )
 
-        # Reshape scale if needed (same number of elements but wrong shape)
-        expected_shape = (*weight.shape[:-1], num_blocks)
-        if e8m0_scale.shape != expected_shape:
-            if e8m0_scale.numel() == weight.numel() // cls.BLOCK_SIZE:
-                e8m0_scale = e8m0_scale.reshape(expected_shape)
-
         # Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent)
         scale_factor = torch.exp2(127 - e8m0_scale.float())
 
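For context, E8M0 stores only a biased exponent, so the surviving line maps exponents to powers of two. A standalone sketch of that conversion (illustrative, not part of the diff):

import torch

# E8M0 biased exponent -> scale factor: scale = 2^(127 - exponent)
e8m0 = torch.tensor([127, 126, 130], dtype=torch.uint8)
print(torch.exp2(127 - e8m0.float()))  # tensor([1.0000, 2.0000, 0.1250])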
@@ -258,13 +243,6 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor:
         # Convert E8M0 biased exponent back to scale factor: descale = 2^(exponent - 127)
         descale = torch.exp2(e8m0_scale.float() - 127)
 
-        # Reshape descale to match blocked tensor for broadcasting
-        expected_scale_shape = (*quantized_data.shape[:-1], num_blocks)
-        if descale.shape != expected_scale_shape and descale.numel() == num_blocks * (
-            quantized_data.numel() // quantized_data.shape[-1]
-        ):
-            descale = descale.view(expected_scale_shape)
-
         dequantized = quantized_blocked * descale.unsqueeze(-1)
 
         # Reshape and crop back to original shape
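The surviving multiply assumes descale already matches the blocked layout, so unsqueeze(-1) broadcasts one scale over each block. A sketch with illustrative shapes (block size 32 assumed, as above):

import torch

BLOCK_SIZE = 32  # assumption: OCP MX default
quantized_blocked = torch.randn(64, 2, BLOCK_SIZE)  # stand-in for FP8 payload cast to float
e8m0_scale = torch.randint(120, 135, (64, 2), dtype=torch.uint8)
descale = torch.exp2(e8m0_scale.float() - 127)
dequantized = quantized_blocked * descale.unsqueeze(-1)  # (64, 2, 1) broadcasts over the block dim
print(dequantized.shape)  # torch.Size([64, 2, 32])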

tests/gpu/torch/quantization/test_qtensor_cuda.py

Lines changed: 12 additions & 0 deletions
@@ -897,3 +897,15 @@ class MockQuantizer:
 
         with pytest.raises(AssertionError, match="Scale shape .* does not match expected shape"):
             MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer)
+
+    @pytest.mark.parametrize("device", ["cuda"])
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
+    def test_mxfp8_dequantize_default_dtype(self, device, input_dtype):
+        """Test dequantize uses original dtype when dtype=None."""
+        input_tensor = torch.randn(64, 64, dtype=input_dtype, device=device)
+        qtensor, e8m0_scale = MXFP8QTensor.quantize(input_tensor)
+
+        # Dequantize without specifying dtype
+        dequant = qtensor.dequantize(scale=e8m0_scale)
+
+        assert dequant.dtype == input_dtype
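On a CUDA-capable machine (the test parametrizes device to "cuda"), the new case can be run on its own with:

pytest tests/gpu/torch/quantization/test_qtensor_cuda.py -k test_mxfp8_dequantize_default_dtype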
