Skip to content

Commit 4a2a15e

Browse files
daniserebmeenchen
authored and committed
Add weights_scaling_factor to quantize in MXFP8QTensor
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
1 parent cb3a0de commit 4a2a15e

File tree

2 files changed

+69
-15
lines changed

2 files changed

+69
-15
lines changed

modelopt/torch/quantization/qtensor/mxfp8_tensor.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,23 +144,24 @@ def get_weights_scaling_factor_from_quantizer(
144144
def quantize_with_scale(
145145
cls,
146146
weight: torch.Tensor,
147-
e8m0_scale: torch.Tensor,
147+
weights_scaling_factor: torch.Tensor,
148148
) -> torch.Tensor:
149149
"""Quantize weight tensor using a pre-computed E8M0 scale.
150150
151151
This method is useful for export paths where the scale has already been computed.
152152
153153
Args:
154154
weight: The weight tensor to quantize. Must be at least 1D.
155-
e8m0_scale: E8M0 scale as uint8 biased exponent (bias = 127).
155+
weights_scaling_factor: E8M0 scale as uint8 biased exponent (bias = 127).
156156
Shape should be [..., out_dim, in_dim // 32] for 2D+ tensors,
157157
or [in_dim // 32] for 1D tensors.
158158
159159
Returns:
160160
torch.Tensor: Quantized weight as float8_e4m3fn with same shape as input.
161161
"""
162-
assert e8m0_scale.dtype == cls.SCALE_DTYPE, (
163-
f"e8m0_scale must be {cls.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}"
162+
assert weights_scaling_factor.dtype == cls.SCALE_DTYPE, (
163+
f"weights_scaling_factor must be {cls.SCALE_DTYPE} (E8M0 format), "
164+
f"got {weights_scaling_factor.dtype}"
164165
)
165166

166167
in_dim = weight.shape[-1]
@@ -171,13 +172,13 @@ def quantize_with_scale(
171172
)
172173

173174
# Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent)
174-
scale_factor = torch.exp2(127 - e8m0_scale.float())
175+
scale_factor = torch.exp2(127 - weights_scaling_factor.float())
175176

176177
# NOTE: vLLM/flashinfer may require this behavior:
177178
# scale_factor = torch.where(
178-
# e8m0_scale == 0,
179+
# weights_scaling_factor == 0,
179180
# 1.0,
180-
# torch.exp2(127 - e8m0_scale.float())
181+
# torch.exp2(127 - weights_scaling_factor.float())
181182
# )
182183

183184
weight_reshaped = weight.view(*weight.shape[:-1], num_blocks, cls.BLOCK_SIZE)
@@ -189,30 +190,39 @@ def quantize_with_scale(
189190
return quantized_weight.view(weight.shape)
190191

191192
@classmethod
192-
def quantize(cls, input: torch.Tensor) -> tuple:
193+
def quantize(
194+
cls,
195+
input: torch.Tensor,
196+
weights_scaling_factor: torch.Tensor | None = None,
197+
) -> tuple:
193198
"""Convert a tensor to MXFP8 quantized format.
194199
195200
Args:
196201
input (torch.Tensor): The input tensor to be quantized.
202+
weights_scaling_factor (torch.Tensor | None): Optional pre-computed E8M0 scale
203+
as uint8 biased exponent. If None, the scale will be computed from the input.
204+
Shape should be [..., in_dim // 32] matching input dimensions.
197205
198206
Returns:
199-
tuple: (MXFP8QTensor, e8m0_scale) where e8m0_scale is uint8 biased exponent.
207+
tuple: (MXFP8QTensor, weights_scaling_factor) where weights_scaling_factor is
208+
E8M0 scale as uint8 biased exponent.
200209
"""
201210
original_shape = input.shape
202211
original_dtype = input.dtype
203212

204213
input = reduce_block_padding(input, block_sizes={-1: cls.BLOCK_SIZE})
205-
input_amax = reduce_block_amax(input, block_sizes={-1: cls.BLOCK_SIZE})
206214

207-
e8m0_exponent = cls._compute_e8m0_exponent(input_amax)
208-
e8m0_scale = (e8m0_exponent + 127).to(cls.SCALE_DTYPE)
215+
if weights_scaling_factor is None:
216+
input_amax = reduce_block_amax(input, block_sizes={-1: cls.BLOCK_SIZE})
217+
e8m0_exponent = cls._compute_e8m0_exponent(input_amax)
218+
weights_scaling_factor = (e8m0_exponent + 127).to(cls.SCALE_DTYPE)
209219

210-
quantized_data = cls.quantize_with_scale(input, e8m0_scale)
220+
quantized_data = cls.quantize_with_scale(input, weights_scaling_factor)
211221

212222
# Crop back to original shape
213223
quantized_data = quantized_data[..., : original_shape[-1]]
214224

215-
return cls(original_shape, original_dtype, quantized_data), e8m0_scale
225+
return cls(original_shape, original_dtype, quantized_data), weights_scaling_factor
216226

217227
def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor:
218228
"""Dequantize MXFP8 tensor back to the target dtype.

tests/gpu/torch/quantization/test_qtensor_cuda.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,50 @@ def test_mxfp8_get_weights_scaling_factor(self, device, input_shape):
797797
# Note: 255 (0xFF) represents NaN in E8M0 and should never appear from valid weights
798798
assert torch.all(e8m0_scale <= 254), "E8M0 scale contains NaN value (255)"
799799

800+
@pytest.mark.parametrize("device", ["cuda", "cpu"])
801+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
802+
@pytest.mark.parametrize(
803+
"input_shape",
804+
[
805+
(64, 64),
806+
(128, 128),
807+
(4, 64, 128), # 3D MoE shape
808+
# Note: All shapes must have last dim divisible by 32 since
809+
# get_weights_scaling_factor() requires this (unlike quantize() which pads)
810+
],
811+
)
812+
def test_mxfp8_quantize_with_precomputed_scale(self, device, input_dtype, input_shape):
813+
"""Test MXFP8 quantize() with pre-computed weights_scaling_factor."""
814+
test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device)
815+
816+
# Quantize without pre-computed scale (baseline)
817+
qtensor_auto, scale_auto = MXFP8QTensor.quantize(test_tensor)
818+
819+
# Pre-compute scale and pass to quantize
820+
precomputed_scale = MXFP8QTensor.get_weights_scaling_factor(test_tensor)
821+
qtensor_precomputed, scale_precomputed = MXFP8QTensor.quantize(
822+
test_tensor, weights_scaling_factor=precomputed_scale
823+
)
824+
825+
# Verify scales match
826+
assert torch.equal(scale_auto, scale_precomputed), (
827+
"Pre-computed scale should match auto-computed scale"
828+
)
829+
830+
# Verify quantized data matches
831+
assert torch.equal(qtensor_auto._quantized_data, qtensor_precomputed._quantized_data), (
832+
"Quantized data should match when using pre-computed scale"
833+
)
834+
835+
# Verify dequantized results match
836+
dequant_auto = qtensor_auto.dequantize(dtype=input_dtype, scale=scale_auto)
837+
dequant_precomputed = qtensor_precomputed.dequantize(
838+
dtype=input_dtype, scale=scale_precomputed
839+
)
840+
assert torch.equal(dequant_auto, dequant_precomputed), (
841+
"Dequantized results should match"
842+
)
843+
800844
@pytest.mark.parametrize(
801845
("amax_value", "expected_exponent"),
802846
[
@@ -834,7 +878,7 @@ def test_mxfp8_quantize_with_scale_asserts(self, device):
834878
# Test wrong scale dtype assertion
835879
weight = torch.randn(64, 64, dtype=torch.float32, device=device)
836880
wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device)
837-
with pytest.raises(AssertionError, match="e8m0_scale must be"):
881+
with pytest.raises(AssertionError, match="weights_scaling_factor must be"):
838882
MXFP8QTensor.quantize_with_scale(weight, wrong_dtype_scale)
839883

840884
# Test non-divisible dimension assertion

0 commit comments

Comments
 (0)