Skip to content

Commit d7e72f4

Browse files
authored
Refine static NVFP4 MSE calibration (#1536)
### What does this PR do? Type of change: Bug fix Refines static NVFP4 MSE calibration and forces static NVFP4 amax state to stay FP32 across calibration loading, quantizer promotion, dtype casts, and restore paths. Main changes: - Tighten max/MSE calibration bootstrap and static NVFP4 quantizer promotion. - Keep static NVFP4 `_amax` and `_global_amax` in FP32. - Update focused GPU/unit coverage for FP8 sweep calibration, promotion, restore, and FP32 amax preservation. ### Usage ```yaml algorithm: method: mse fp8_scale_sweep: true ``` ### Testing ```bash pre-commit run --files modelopt/torch/quantization/nn/modules/tensor_quantizer.py tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py pytest_pwd tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py -q ``` ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors). - Is this change backward compatible?: Yes - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: Yes - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A - Did you get Claude approval on this PR?: Yes ### Additional Information N/A Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 7aa0c95 commit d7e72f4

12 files changed

Lines changed: 695 additions & 306 deletions

File tree

examples/llm_ptq/cast_mxfp4_to_nvfp4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ def force_weight_quantizers_static(quant_cfg: list) -> None:
297297
The MXFP4 -> NVFP4 cast needs the per-block weight ``_amax`` to be recorded
298298
by max-cal (so it can be paired with the pinned global_amax later). Setting
299299
``block_sizes['type'] = 'static'`` makes ``is_static_block_quant`` True so
300-
``promote_nvfp4_static_quantizers`` picks the entry up automatically at the
301-
end of max_calibrate.
300+
static NVFP4 finalization picks the entry up automatically at the end of
301+
max_calibrate.
302302
"""
303303
for i, entry in enumerate(quant_cfg):
304304
qname = entry.get("quantizer_name", "")

modelopt/torch/quantization/calib/mse.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,8 @@ class NVFP4MSECalibrator(MseCalibrator):
178178
Uses a fused Triton kernel as an internal fast path on the first ``collect`` call
179179
when (a) ``error_func is None``, (b) the input tensor is on CUDA in the standard
180180
blocked ``[n_blocks, block_size]`` layout, and (c) Triton + the kernel package are
181-
importable. Falls back to the reference 126-step Python sweep otherwise (custom
182-
``error_func`` users, multi-``collect`` activation flows, CPU inputs, or when the
183-
fast path is disabled via ``MODELOPT_NVFP4_TRITON_SWEEP=0``).
181+
importable. Falls back to the reference 126-step Python sweep otherwise and caches
182+
the final amax immediately, so this calibrator is one-shot between resets.
184183
"""
185184

186185
def __init__(
@@ -193,14 +192,16 @@ def __init__(
193192
):
194193
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
195194
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
196-
self._global_amax = global_amax
197-
# Set by the Triton fast path on its (one-shot) collect; consumed by compute_amax.
198-
self._best_amax_fast: torch.Tensor | None = None
195+
self._global_amax = global_amax.to(dtype=torch.float32)
196+
# Set by collect() after either sweep path; consumed by compute_amax.
197+
self._best_amax: torch.Tensor | None = None
199198

200199
def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
201200
if candidates.ndim != 0: # Called during final compute amax
202201
candidates = candidates.view_as(self._initial_amax)
203-
return torch.ones_like(self._initial_amax) * self._global_amax * candidates
202+
return torch.ones_like(self._initial_amax, dtype=torch.float32) * (
203+
self._global_amax * candidates
204+
)
204205

205206
def _generate_candidates(self, device: torch.device) -> torch.Tensor:
206207
"""Generate the 126 valid FP8 E4M3 scale candidates."""
@@ -235,35 +236,39 @@ def _can_use_triton_fast_path(self, x: torch.Tensor) -> bool:
235236

236237
@torch.no_grad()
237238
def collect(self, x: torch.Tensor):
238-
"""Collect input statistics. Uses the Triton fast path when eligible."""
239-
if self._best_amax_fast is not None:
239+
"""Collect input statistics and cache the final per-block amax."""
240+
if self._best_amax is not None:
240241
raise RuntimeError(
241-
"NVFP4MSECalibrator: the Triton fast path produced a final amax on a "
242-
"previous collect() call; multi-collect after the fast path is not "
243-
"supported. Call reset() to start a fresh cycle, set "
244-
"MODELOPT_NVFP4_TRITON_SWEEP=0, or pass a non-None error_func to force "
245-
"the reference path for activation-style accumulation."
242+
"NVFP4MSECalibrator: a previous collect() call produced a final amax; "
243+
"multi-collect is not supported. Call reset() to start a fresh cycle."
246244
)
247-
# Fast path is eligible only on the first call, before the reference accumulator
248-
# has produced any state.
249-
if self._losses_sum is None and self._can_use_triton_fast_path(x):
245+
if self._can_use_triton_fast_path(x):
250246
from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep
251247

252248
best_flat = nvfp4_fp8_scale_sweep(x.detach(), self._global_amax, block_size=x.shape[-1])
253-
# Match the original shape/dtype of the initial amax so downstream
254-
# load_calib_amax behaves identically to the reference path.
255-
self._best_amax_fast = best_flat.reshape(self._initial_amax.shape).to(
256-
self._initial_amax.dtype
257-
)
249+
# Store the selected amax in fp32; the fake-quant kernel still returns
250+
# tensors in the requested output dtype.
251+
self._best_amax = best_flat.reshape(self._initial_amax.shape).to(dtype=torch.float32)
258252
return
253+
254+
self._run_reference_collect(x)
255+
256+
def _run_reference_collect(self, x: torch.Tensor):
259257
super().collect(x)
258+
best_amax = super().compute_amax(verbose=False)
259+
self._best_amax = best_amax.to(dtype=torch.float32) if best_amax is not None else None
260+
self._losses_sum = None
261+
# Synchronize before calibrating another weight so reference MSE sweeps do
262+
# not overlap. _losses_sum stores one fp32 reduced loss per candidate per
263+
# block. With 16-element NVFP4 blocks and bf16 weights, this is roughly
264+
# 128 / 16 * (4 / 2) = 16x the calibrated weight size.
265+
if x.is_cuda:
266+
torch.cuda.synchronize(x.device)
260267

261268
@torch.no_grad()
262269
def compute_amax(self, verbose: bool = False):
263-
"""Return the per-block amax — from the fast path if it ran, else from the reference sweep."""
264-
if self._best_amax_fast is not None:
265-
return self._best_amax_fast
266-
return super().compute_amax(verbose=verbose)
270+
"""Return the cached per-block amax."""
271+
return self._best_amax
267272

268273
def reset(self):
269274
"""Reset per-cycle state. Keep ``_initial_amax`` so the calibrator stays reusable.
@@ -273,7 +278,7 @@ def reset(self):
273278
small enough to keep so a follow-up ``collect()`` can run again on the same
274279
calibrator instance.
275280
"""
276-
self._best_amax_fast = None
281+
self._best_amax = None
277282
self._losses_sum = None
278283
self._candidates = None
279284
self._amax = None

modelopt/torch/quantization/config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -722,9 +722,9 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
722722
Finds a scale s (via amax a, with s = a / q_max) that minimizes the
723723
reconstruction error of a tensor after uniform Q→DQ:
724724
725-
s* = argmin_s E[(X - DQ(Q(X; s)))^2], X{weights | activations}
725+
s* = argmin_s E[(W - DQ(Q(W; s)))^2], W ∈ weights
726726
727-
When fp8_scale_sweep is enabled, step_size is ignored.
727+
When fp8_scale_sweep is enabled for a supported FP8-scale format, step_size is ignored.
728728
"""
729729

730730
method: Literal["mse"] = ModeloptField("mse")
@@ -755,8 +755,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
755755
default=False,
756756
title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
757757
description="If True, sweep all 128 FP8 E4M3 scale values instead of using multipliers. "
758-
"Only applies to NVFP4 weight quantization. When enabled, num_steps, step_size, "
759-
"start_multiplier, and stop_multiplier are ignored.",
758+
"Applies to ModelOpt static NVFP4 weight quantizers and registered custom backends with "
759+
"FP8 sweep support. Other weight quantizers use the multiplier search.",
760760
)
761761

762762
distributed_sync: bool | None = ModeloptField(

0 commit comments

Comments
 (0)