Commit aafd388
add FP8 sweep option for static NVFP4 MSE (#758)
## What does this PR do?

**Type of change:** new feature

**Overview:** Adds an `fp8_scale_sweep` mode to the MSE calibrator for optimizing the FP8-quantized per-block scales in the NVFP4 format.

## Usage

Tested with this config:

```python
NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "enable": False,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "mse",
        "fp8_scale_sweep": True,
    },
}
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

## Summary by CodeRabbit

## Release Notes

**New Features**

- Added FP8 scale sweep option for quantization calibration, enabling optimized scale value sweeping for NVFP4 per-block quantization.
- Introduced new NVFP4_WEIGHT_MSE_CFG configuration preset for improved weight quantization workflows.

**Tests**

- Added test coverage validating FP8 scale sweep functionality and reset behavior.

---------

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
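For context, here is a minimal sketch of how a config like the one above would typically be applied with ModelOpt's `mtq.quantize` API. The `model` and `calib_dataloader` names are placeholders, and the forward loop only matters when activation quantizers are calibrated (the sweep config above is weight-only):

```python
# Sketch only: applying the FP8-sweep MSE config with ModelOpt.
# `model` and `calib_dataloader` are hypothetical placeholders.
import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Run a handful of calibration batches through the model.
    for batch in calib_dataloader:
        model(batch)

# The weight quantizer is calibrated directly on the weights; the input
# quantizer is disabled in this config, so the forward loop is optional.
model = mtq.quantize(model, NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, forward_loop)
```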
1 parent 3840309 commit aafd388

9 files changed: 221 additions & 66 deletions

modelopt/torch/quantization/calib/mse.py

Lines changed: 37 additions & 8 deletions

```diff
@@ -39,20 +39,24 @@ def __init__(
         stop_multiplier: float = 4.0,
         quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        fp8_scale_sweep: bool = False,
     ):
         """Initialize MSE calibrator.

         Args:
             amax: Initial amax value (required).
             axis: Quantization axis. None means per-tensor quantization.
             step_size: Step size for amax search. The number of steps is computed as
-                ceil((stop_multiplier - start_multiplier) / step_size) + 1.
+                ceil((stop_multiplier - start_multiplier) / step_size) + 1.
             start_multiplier: Starting multiplier for amax search.
             stop_multiplier: Ending multiplier for amax search.
             quant_func: Function that quantizes input tensor given an amax value.
-                Should have signature: quant_func(x, amax) -> quantized_x.
+                Should have signature: quant_func(x, amax) -> quantized_x.
             error_func: Function to compute error between x and xq.
-                Default is F.mse_loss(x, xq, reduction='none').
+                Default is F.mse_loss(x, xq, reduction='none').
+            fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
+                instead of using multipliers. This is specifically for NVFP4
+                per-block quantization where scales are stored in FP8 format.
         """
         super().__init__(num_bits=None, axis=axis, unsigned=None)
         self._initial_amax = amax
@@ -65,6 +69,13 @@ def __init__(
         self._error_func = error_func
         self._losses_sum = [None] * self._num_steps
         self._candidate_amaxs = [None] * self._num_steps
+        self._fp8_scale_sweep = fp8_scale_sweep
+        if fp8_scale_sweep:
+            # For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
+            # (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
+            self._num_steps = 126
+            self._losses_sum = [None] * self._num_steps
+            self._candidate_amaxs = [None] * self._num_steps

         self._amax = None

@@ -83,14 +94,32 @@ def collect(self, x: torch.Tensor):
         x = x.detach().to(dtype=torch.float32)

         device = x.device
-        multipliers = torch.linspace(
-            self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
-        )
+
+        if self._fp8_scale_sweep:
+            global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
+
+            # Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
+            # Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
+            uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
+            fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
+
+            # Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
+            valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
+            fp8_values_valid = fp8_values[valid_mask]
+
+            candidates = fp8_values_valid / 448.0
+        else:
+            candidates = torch.linspace(
+                self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
+            )
         # Get reduce axis for per-channel quantization
         reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)

-        for step, multiplier in enumerate(multipliers):
-            candidate_amax = self._initial_amax * multiplier
+        for step, candidate in enumerate(candidates):
+            if self._fp8_scale_sweep:
+                candidate_amax = (global_amax * candidate) * torch.ones_like(self._initial_amax)
+            else:
+                candidate_amax = self._initial_amax * candidate
             xq = self._quant_func(x, candidate_amax)

             if self._error_func is not None:
```

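As a side note, the candidate generation in the `collect` hunk above can be reproduced in isolation. This standalone sketch (not part of the patch; requires a PyTorch build with `float8_e4m3fn` support) shows why exactly 126 candidates survive the filter and what range they cover after dividing by the FP8 E4M3 maximum of 448:

```python
import torch

# Enumerate every FP8 E4M3 bit pattern in [0, 127] by reinterpreting uint8 bytes.
uint8_values = torch.arange(0, 128, dtype=torch.uint8)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()

# Byte 0 decodes to 0.0 and byte 127 decodes to NaN, so 126 values remain.
valid = fp8_values[torch.isfinite(fp8_values) & (fp8_values > 0)]
assert valid.numel() == 126

# Dividing by 448 (the largest finite E4M3 value) maps the candidates into (0, 1],
# so each one acts as a fractional multiplier on the per-tensor global amax.
candidates = valid / 448.0
print(candidates.min().item(), candidates.max().item())  # ~4.4e-06 ... 1.0
```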
modelopt/torch/quantization/config.py

Lines changed: 10 additions & 24 deletions

```diff
@@ -387,30 +387,6 @@
     "algorithm": "max",
 }

-NVFP4_WEIGHT_ACT_MSE_CFG = {
-    "quant_cfg": {
-        "*weight_quantizer": {
-            "num_bits": (2, 1),
-            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
-            "axis": None,
-            "enable": True,
-        },
-        "*input_quantizer": {
-            "num_bits": (2, 1),
-            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
-            "axis": None,
-            "enable": True,
-        },
-        **_default_disabled_quantizer_cfg,
-    },
-    "algorithm": {
-        "method": "mse",
-        "step_size": 0.25,
-        "start_multiplier": 0.25,
-        "stop_multiplier": 2.0,
-    },
-}
-
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
@@ -1040,6 +1016,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
     reconstruction error of a tensor after uniform Q→DQ:

         s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}
+
+    When fp8_scale_sweep is enabled, step_size is ignored.
     """

     method: Literal["mse"] = ModeloptField("mse")
@@ -1066,6 +1044,14 @@
         description="Ending multiplier for amax search range (multiplies initial amax).",
     )

+    fp8_scale_sweep: bool | None = ModeloptField(
+        default=False,
+        title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
+        description="If True, sweep all 128 FP8 E4M3 scale values instead of using multipliers. "
+        "Only applies to NVFP4 weight quantization. When enabled, num_steps, step_size, "
+        "start_multiplier, and stop_multiplier are ignored.",
+    )
+
     distributed_sync: bool | None = ModeloptField(
         default=True,
         title="Whether to sync the amax across the distributed processes.",
```

modelopt/torch/quantization/model_calib.py

Lines changed: 42 additions & 22 deletions

```diff
@@ -197,6 +197,30 @@ def sync_quantizer_amax_across_tp(
     )


+def _mse_quant_func(x, amax, quantizer):
+    """Quantization function for MSE calibration."""
+    original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
+    quantizer._amax = amax
+
+    with (
+        enable_quant(quantizer),
+        disable_calib(quantizer),
+        enable_fake_quant(quantizer),
+    ):
+        if hasattr(quantizer, "_original_shape"):
+            x = quantizer._reset_to_original_shape(x)
+        xq = quantizer(x)
+        if hasattr(quantizer, "_block_reshape_size"):
+            xq = xq.reshape(quantizer._block_reshape_size)
+
+    if original_amax is not None:
+        quantizer._amax = original_amax
+    else:
+        delattr(quantizer, "_amax")
+
+    return xq
+
+
 @torch.no_grad()
 def mse_calibrate(
     model: nn.Module,
@@ -205,6 +229,7 @@ def mse_calibrate(
     step_size: float = 0.1,
     start_multiplier: float = 0.25,
     stop_multiplier: float = 4.0,
+    fp8_scale_sweep: bool = False,
 ):
     """Calibrate the model using MSE-based amax search.

@@ -220,6 +245,10 @@ def mse_calibrate(
         step_size: Step size for amax search (default: 0.1).
         start_multiplier: Starting multiplier for amax search (default: 0.25).
         stop_multiplier: Ending multiplier for amax search (default: 4.0).
+        fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
+            for NVFP4 per-block quantization instead of using multipliers.
+            This is specifically designed for optimizing the FP8-quantized
+            per-block scales in NVFP4 format (default: False).

     See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
     details on the remaining arguments.
@@ -238,27 +267,17 @@
             # Get the initial amax from max calibration
             initial_amax = module._amax.clone().detach()

-            def quant_func(x, amax, quantizer=module):
-                original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
-                quantizer._amax = amax
-
-                with (
-                    enable_quant(quantizer),
-                    disable_calib(quantizer),
-                    enable_fake_quant(quantizer),
-                ):
-                    if hasattr(quantizer, "_original_shape"):
-                        x = quantizer._reset_to_original_shape(x)
-                    xq = quantizer(x)
-                    if hasattr(quantizer, "_block_reshape_size"):
-                        xq = xq.reshape(quantizer._block_reshape_size)
-
-                if original_amax is not None:
-                    quantizer._amax = original_amax
-                else:
-                    delattr(quantizer, "_amax")
-
-                return xq
+            is_nvfp4_static = (
+                module.is_static_block_quant
+                and module._num_bits == (2, 1)
+                and module._block_sizes is not None
+                and module._block_sizes.get("scale_bits") == (4, 3)
+            )
+            if fp8_scale_sweep and not is_nvfp4_static:
+                warnings.warn(
+                    f"fp8_scale_sweep is enabled but quantizer '{name}' is not NVFP4 static "
+                    "block quantization. fp8_scale_sweep will be ignored for this quantizer."
+                )

             # Create MSE calibrator with quant_func
@@ -267,7 +286,8 @@ def quant_func(x, amax, quantizer=module):
                 step_size=step_size,
                 start_multiplier=start_multiplier,
                 stop_multiplier=stop_multiplier,
-                quant_func=quant_func,
+                quant_func=partial(_mse_quant_func, quantizer=module),
+                fp8_scale_sweep=fp8_scale_sweep and is_nvfp4_static,
             )

     # Identify weight quantizers by checking if they have corresponding weight parameters
```

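The model_calib.py refactor swaps a per-module closure for a module-level function bound with `functools.partial`. A toy illustration of that equivalence (the names here are made up and unrelated to the patch):

```python
from functools import partial

class ToyQuantizer:
    """Stand-in for a quantizer object; clamps to [-amax, amax]."""

    def apply(self, x, amax):
        return min(max(x, -amax), amax)

def _quant(x, amax, quantizer):
    # Module-level function: one definition shared by every quantizer.
    return quantizer.apply(x, amax)

q = ToyQuantizer()

# Old style: a closure capturing the quantizer via a default argument.
def quant_func(x, amax, quantizer=q):
    return quantizer.apply(x, amax)

# New style: the same binding expressed with partial.
bound = partial(_quant, quantizer=q)

assert quant_func(7.0, 3.0) == bound(7.0, 3.0) == 3.0
```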
modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -753,11 +753,12 @@ def _fake_quantize(self, inputs):
         elif self._num_bits == (2, 1) and self.is_static_block_quant:
             outputs = static_blockwise_fp4_fake_quant(
                 inputs,
-                amax / 6.0,
+                None,  # scale
                 None,  # scale_fp8_quant_amax
                 False,  # skip_scale_quant
                 inputs.dtype,  # out_dtype
                 self._pass_through_bwd,  # pass_through_bwd
+                amax,  # amax
             )
         elif isinstance(self._num_bits, tuple):
             # Float-point quantization, e.g., FP8
```

modelopt/torch/quantization/tensor_quant.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -574,21 +574,23 @@ def forward(
         skip_scale_quant,
         out_dtype,
         pass_through_bwd=False,
+        amax=None,
     ):
         """Forward method."""
-        _save_for_backward_if_needed(ctx, pass_through_bwd, x, scale)
+        _save_for_backward_if_needed(ctx, pass_through_bwd, x, scale if scale is not None else amax)
         return triton_kernel.static_blockwise_fp4_fake_quant(
             x,
             scale,
             scale_fp8_quant_amax,
             skip_scale_quant,
             out_dtype,
+            amax,
         )

     @staticmethod
     def backward(ctx, grad_outputs):
         """Implements straight through estimation with clipping."""
-        return _fake_quant_backward_function(ctx, grad_outputs, num_args=6)
+        return _fake_quant_backward_function(ctx, grad_outputs, num_args=7)


 def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
```

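The `num_args` bump from 6 to 7 follows the general `torch.autograd.Function` contract: `backward` must return one gradient (or `None`) per positional input of `forward`, so adding `amax` to `forward` adds one more return slot. A minimal, self-contained illustration of that rule (a toy function, not the actual Triton wrapper):

```python
import torch

class ClampQuant(torch.autograd.Function):
    """Toy fake-quant with an optional amax argument, to show the backward arity rule."""

    @staticmethod
    def forward(ctx, x, scale, amax=None):
        if scale is None and amax is not None:
            scale = amax / 6.0  # derive the scale from amax, as in the patch
        ctx.save_for_backward(x, scale)
        return torch.clamp(x / scale, -6.0, 6.0) * scale

    @staticmethod
    def backward(ctx, grad_output):
        x, scale = ctx.saved_tensors
        mask = (x.abs() <= 6.0 * scale).to(grad_output.dtype)
        # One return value per forward input: x, scale, amax.
        return grad_output * mask, None, None

x = torch.randn(4, requires_grad=True)
y = ClampQuant.apply(x, None, torch.tensor(3.0))
y.sum().backward()
```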
modelopt/torch/quantization/triton/fp4_kernel.py

Lines changed: 14 additions & 2 deletions

```diff
@@ -406,20 +406,32 @@ def static_blockwise_fp4_fake_quant_kernel(

 def static_blockwise_fp4_fake_quant(
     x: torch.Tensor,
-    scale: torch.Tensor,
+    scale: torch.Tensor | None = None,
     scale_fp8_quant_amax: torch.Tensor | None = None,
     skip_scale_quant: bool = False,
     out_dtype: torch.dtype | None = None,
+    amax: torch.Tensor | None = None,
 ):
     """Static blockwise FP4 fake quantization using Triton kernel.

     Args:
         x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.
-        scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA.
+        scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. Mutually exclusive with amax.
         scale_fp8_quant_amax: Absolute max range for FP8 quantization of scale. If None, computed from scale.
         skip_scale_quant: If True, skip FP8 quantization of scale.
         out_dtype: Output dtype. Defaults to x.dtype if None.
+        amax: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. If provided, scale = amax / 6.0.
+            Mutually exclusive with scale.
     """
+    if scale is None and amax is None:
+        raise ValueError("Either scale or amax must be provided")
+    if scale is not None and amax is not None:
+        raise ValueError("Cannot provide both scale and amax")
+
+    if amax is not None:
+        scale = amax / 6.0  # FP4 max representable value is 6.0
+
+    assert scale is not None  # Guaranteed by validation above
     assert x.ndim == 2
     NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape

```

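To make the new argument handling concrete, here is a standalone sketch of the scale/amax resolution the wrapper performs before launching the kernel (pure PyTorch, no Triton, illustrative only):

```python
import torch

def resolve_block_scale(scale: torch.Tensor | None, amax: torch.Tensor | None) -> torch.Tensor:
    """Resolve the per-block scale: exactly one of `scale` or `amax` must be given."""
    if scale is None and amax is None:
        raise ValueError("Either scale or amax must be provided")
    if scale is not None and amax is not None:
        raise ValueError("Cannot provide both scale and amax")
    if amax is not None:
        # FP4 (E2M1) has a maximum representable magnitude of 6.0, so the
        # decode scale for a block with absolute range amax is amax / 6.0.
        scale = amax / 6.0
    return scale

amax = torch.tensor([1.5, 3.0, 6.0])
print(resolve_block_scale(None, amax))  # tensor([0.2500, 0.5000, 1.0000])
```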
tests/gpu/torch/quantization/test_quantize_cuda.py

Lines changed: 26 additions & 1 deletion

```diff
@@ -43,7 +43,30 @@
             "enable": True,
         },
     },
-    "algorithm": "mse",
+    "algorithm": {
+        "method": "mse",
+        "step_size": 0.25,
+        "start_multiplier": 0.25,
+        "stop_multiplier": 2.0,
+    },
+}
+
+NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+    },
+    "algorithm": {
+        "method": "mse",
+        "fp8_scale_sweep": True,
+    },
 }


@@ -71,6 +94,7 @@
         mtq.NVFP4_KV_ROTATE_CFG,
         mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
         NVFP4_WEIGHT_ACT_MSE_CFG,
+        NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
     ],
 )
 def test_quantize(model_cls, config):
@@ -88,6 +112,7 @@ def test_quantize(model_cls, config):
         mtq.NVFP4_KV_ROTATE_CFG,
         mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
         NVFP4_WEIGHT_ACT_MSE_CFG,
+        NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
     ]:
         if get_cuda_ext_mx() is None:
             pytest.skip("cuda_ext_mx is not available")
```
