Skip to content

Commit ac30686

Browse files
authored
Track global_amax for weight FP4 MSE sweep; Refactor to NVFP4StaticQantizer, NVFP4MSECalibrator (#849)
## What does this PR do? **Type of change:** ? <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** ? ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added NVFP4StaticQuantizer for improved 4-bit quantization with enhanced precision control * Introduced NVFP4MSECalibrator with flexible candidate generation for calibration optimization * **Improvements** * Optimized GPU kernels for Hopper+ graphics cards with better performance * Extended Triton support to broader GPU compatibility * Enhanced backward compatibility for restoring previously quantized models * **Tests** * Added comprehensive test coverage for new quantizers and calibration methods <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 3393e98 commit ac30686

13 files changed

Lines changed: 685 additions & 454 deletions

File tree

modelopt/torch/quantization/calib/mse.py

Lines changed: 69 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .. import utils as quant_utils
2525
from .calibrator import _Calibrator
2626

27-
__all__ = ["MseCalibrator"]
27+
__all__ = ["MseCalibrator", "NVFP4MSECalibrator"]
2828

2929

3030
class MseCalibrator(_Calibrator):
@@ -39,7 +39,6 @@ def __init__(
3939
stop_multiplier: float = 4.0,
4040
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
4141
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
42-
fp8_scale_sweep: bool = False,
4342
):
4443
"""Initialize MSE calibrator.
4544
@@ -54,9 +53,6 @@ def __init__(
5453
Should have signature: quant_func(x, amax) -> quantized_x.
5554
error_func: Function to compute error between x and xq.
5655
Default is F.mse_loss(x, xq, reduction='none').
57-
fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
58-
instead of using multipliers. This is specifically for NVFP4
59-
per-block quantization where scales are stored in FP8 format.
6056
"""
6157
super().__init__(num_bits=None, axis=axis, unsigned=None)
6258
self._initial_amax = amax
@@ -67,17 +63,21 @@ def __init__(
6763

6864
self._quant_func = quant_func
6965
self._error_func = error_func
70-
self._losses_sum = [None] * self._num_steps
71-
self._candidate_amaxs = [None] * self._num_steps
72-
self._fp8_scale_sweep = fp8_scale_sweep
73-
if fp8_scale_sweep:
74-
# For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
75-
# (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
76-
self._num_steps = 126
77-
self._losses_sum = [None] * self._num_steps
78-
self._candidate_amaxs = [None] * self._num_steps
79-
80-
self._amax = None
66+
self._losses_sum: list[torch.Tensor | None] | None = None
67+
self._candidates: torch.Tensor | None = None
68+
self._amax: torch.Tensor | None = None
69+
70+
def _generate_candidates(self, device: torch.device) -> torch.Tensor:
71+
"""Generate candidate multipliers. Override in subclasses for different candidate sets."""
72+
return torch.linspace(
73+
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
74+
)
75+
76+
def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
77+
"""Compute amax from candidates. Override in subclasses for different amax computation."""
78+
if candidates.ndim != 0: # Called during final compute amax
79+
candidates = candidates.view_as(self._initial_amax)
80+
return self._initial_amax * candidates
8181

8282
@torch.no_grad()
8383
def collect(self, x: torch.Tensor):
@@ -87,39 +87,22 @@ def collect(self, x: torch.Tensor):
8787
x: Input tensor.
8888
"""
8989
if self._quant_func is None:
90-
raise RuntimeError(
91-
"Quantization function not set. Msecalibrator requires a quant_func to be provided."
92-
)
90+
raise RuntimeError("Quantization function not set.")
9391

9492
x = x.detach().to(dtype=torch.float32)
95-
9693
device = x.device
9794

98-
if self._fp8_scale_sweep:
99-
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
100-
101-
# Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
102-
# Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
103-
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
104-
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
105-
106-
# Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
107-
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
108-
fp8_values_valid = fp8_values[valid_mask]
95+
candidates = self._generate_candidates(device)
96+
if self._candidates is None:
97+
self._candidates = candidates
98+
self._num_steps = len(candidates)
99+
self._losses_sum = [None] * self._num_steps
109100

110-
candidates = fp8_values_valid / 448.0
111-
else:
112-
candidates = torch.linspace(
113-
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
114-
)
115-
# Get reduce axis for per-channel quantization
101+
assert self._losses_sum is not None
116102
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)
117103

118104
for step, candidate in enumerate(candidates):
119-
if self._fp8_scale_sweep:
120-
candidate_amax = (global_amax * candidate) * torch.ones_like(self._initial_amax)
121-
else:
122-
candidate_amax = self._initial_amax * candidate
105+
candidate_amax = self._compute_candidate_amax(candidate)
123106
xq = self._quant_func(x, candidate_amax)
124107

125108
if self._error_func is not None:
@@ -129,28 +112,16 @@ def collect(self, x: torch.Tensor):
129112

130113
loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)
131114

132-
if self._candidate_amaxs[step] is None:
133-
self._candidate_amaxs[step] = candidate_amax
134-
135115
if self._losses_sum[step] is None:
136116
self._losses_sum[step] = loss.clone()
137117
else:
138118
self._losses_sum[step] += loss
139119

140120
def reset(self):
141121
"""Reset the stored losses and amax value."""
142-
self._losses_sum = [None] * self._num_steps
143-
self._candidate_amaxs = [None] * self._num_steps
122+
self._losses_sum = None
123+
self._candidates = None
144124
self._amax = None
145-
146-
def clear(self):
147-
"""Clear all cached data to free GPU memory.
148-
149-
Call this after compute_amax() and load_calib_amax() are done.
150-
"""
151-
self._losses_sum = []
152-
self._candidate_amaxs = []
153-
154125
if self._initial_amax is not None:
155126
del self._initial_amax
156127
self._initial_amax = None
@@ -162,49 +133,28 @@ def compute_amax(self, verbose: bool = False):
162133
Args:
163134
verbose: If True, print the ratio of best_amax to initial_amax.
164135
"""
165-
if not any(loss_sum is not None for loss_sum in self._losses_sum):
136+
if self._losses_sum is None or not any(loss is not None for loss in self._losses_sum):
166137
return None
167138

168-
# Check if this is per-tensor or per-channel based on the first loss
169-
first_loss_sum = None
170-
for loss_sum in self._losses_sum:
171-
if loss_sum is not None:
172-
first_loss_sum = loss_sum
173-
break
174-
175-
if first_loss_sum is None:
139+
first_loss = next((loss for loss in self._losses_sum if loss is not None), None)
140+
if first_loss is None:
176141
return None
177142

178-
# Collect losses for all steps
179-
losses_per_step = []
143+
# Stack losses: [num_steps] or [num_steps, num_channels]
144+
losses = []
180145
for step in range(self._num_steps):
181146
if self._losses_sum[step] is not None:
182-
losses_per_step.append(self._losses_sum[step])
183-
# No data for this step, use inf
184-
elif first_loss_sum.ndim == 0:
185-
losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device))
147+
losses.append(self._losses_sum[step])
148+
elif first_loss.ndim == 0:
149+
losses.append(torch.tensor(float("inf"), device=first_loss.device))
186150
else:
187-
losses_per_step.append(torch.full_like(first_loss_sum, float("inf")))
188-
189-
# Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel
190-
losses_per_step = torch.stack(losses_per_step)
151+
losses.append(torch.full_like(first_loss, float("inf")))
191152

192-
# Find best step(s): scalar for per-tensor, [num_channels] for per-channel
193-
best_steps = torch.argmin(losses_per_step, dim=0)
194-
195-
# Stack candidate amaxs and select based on best_steps
196-
candidate_amaxs = torch.stack(self._candidate_amaxs)
197-
198-
if first_loss_sum.ndim == 0:
199-
# Per-tensor case: best_steps is a scalar
200-
self._amax = self._candidate_amaxs[best_steps.item()]
201-
else:
202-
# Per-channel case: best_steps is a tensor
203-
num_channels = best_steps.shape[0]
204-
self._amax = candidate_amaxs[
205-
best_steps, torch.arange(num_channels, device=best_steps.device)
206-
]
207-
self._amax = self._amax.reshape(self._initial_amax.shape)
153+
losses = torch.stack(losses)
154+
best_indices = torch.argmin(losses, dim=0)
155+
assert self._candidates is not None
156+
best_candidates = self._candidates[best_indices]
157+
self._amax = self._compute_candidate_amax(best_candidates)
208158

209159
if verbose:
210160
ratio = self._amax / self._initial_amax
@@ -219,3 +169,32 @@ def compute_amax(self, verbose: bool = False):
219169
)
220170

221171
return self._amax
172+
173+
174+
class NVFP4MSECalibrator(MseCalibrator):
175+
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization."""
176+
177+
def __init__(
178+
self,
179+
amax: torch.Tensor, # per_block_amax shape [num_blocks]
180+
global_amax: torch.Tensor, # scalar
181+
axis: int | tuple | list | None = None,
182+
quant_func: Callable | None = None,
183+
error_func: Callable | None = None,
184+
):
185+
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
186+
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
187+
self._global_amax = global_amax
188+
189+
def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
190+
if candidates.ndim != 0: # Called during final compute amax
191+
candidates = candidates.view_as(self._initial_amax)
192+
return torch.ones_like(self._initial_amax) * self._global_amax * candidates
193+
194+
def _generate_candidates(self, device: torch.device) -> torch.Tensor:
195+
"""Generate 126 valid FP8 E4M3 scale candidates."""
196+
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
197+
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
198+
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
199+
fp8_values = fp8_values[valid_mask]
200+
return fp8_values / 448.0

modelopt/torch/quantization/conversion.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_QuantizeExportConfig,
3636
)
3737
from .nn import (
38+
NVFP4StaticQuantizer,
3839
QuantModule,
3940
QuantModuleRegistry,
4041
SequentialQuantizer,
@@ -125,6 +126,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
125126
for name, module in model.named_modules():
126127
if isinstance(module, TensorQuantizer):
127128
name = get_unwrapped_name(name, model)
129+
state = quantizer_state_dict[name]
130+
# TODO: Add a registry for TensorQuantizers and avoid this manual conversion.
131+
if state.get("_is_nvfp4_static_quantizer") and not isinstance(
132+
module, NVFP4StaticQuantizer
133+
):
134+
NVFP4StaticQuantizer.from_tensor_quantizer(module)
128135
module.set_from_modelopt_state(quantizer_state_dict[name])
129136

130137
for name, module in model.named_modules():

modelopt/torch/quantization/model_calib.py

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
3333
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
3434

35-
from .calib import MseCalibrator
35+
from .calib import MseCalibrator, NVFP4MSECalibrator
3636
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
37-
from .nn import QuantModule, SequentialQuantizer, TensorQuantizer
37+
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
3838
from .utils import (
3939
disable_calib,
4040
enable_fake_quant,
@@ -44,6 +44,7 @@
4444
is_quantized_linear,
4545
is_quantized_row_parallel_linear,
4646
quantizer_attr_names,
47+
reduce_amax,
4748
weight_attr_names,
4849
)
4950

@@ -299,7 +300,7 @@ def mse_calibrate(
299300
weight_quantizers = []
300301
seen_modules = set()
301302

302-
for name, module in model.named_modules():
303+
for name, module in list(model.named_modules()):
303304
if isinstance(module, TensorQuantizer) and not module._disabled:
304305
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
305306
# Get the initial amax from max calibration
@@ -311,6 +312,24 @@ def mse_calibrate(
311312
and module._block_sizes is not None
312313
and module._block_sizes.get("scale_bits") == (4, 3)
313314
)
315+
316+
if is_nvfp4_static:
317+
# Compute and set global_amax
318+
global_amax = reduce_amax(initial_amax, axis=None)
319+
320+
# Convert to NVFP4StaticQuantizer in-place
321+
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
322+
323+
if fp8_scale_sweep and is_nvfp4_static:
324+
# Replace calibrator with NVFP4MSECalibrator
325+
module._calibrator = NVFP4MSECalibrator(
326+
amax=initial_amax,
327+
axis=module._calibrator._axis,
328+
global_amax=module.global_amax,
329+
quant_func=partial(_mse_quant_func, quantizer=module),
330+
)
331+
continue
332+
314333
if fp8_scale_sweep and not is_nvfp4_static:
315334
warnings.warn(
316335
f"fp8_scale_sweep is enabled but quantizer '{name}' is not NVFP4 static "
@@ -325,7 +344,6 @@ def mse_calibrate(
325344
start_multiplier=start_multiplier,
326345
stop_multiplier=stop_multiplier,
327346
quant_func=partial(_mse_quant_func, quantizer=module),
328-
fp8_scale_sweep=fp8_scale_sweep and is_nvfp4_static,
329347
)
330348

331349
# Identify weight quantizers by checking if they have corresponding weight parameters
@@ -344,40 +362,12 @@ def mse_calibrate(
344362
# This ensures weights are only calibrated once, not during every forward pass
345363
for parent_module, weight_name, weight_quantizer in weight_quantizers:
346364
# Enable calibration mode for the weight quantizer
347-
weight_quantizer.disable_quant()
348-
weight_quantizer.enable_calib()
349-
365+
enable_stats_collection(parent_module)
350366
with enable_weight_access_and_writeback(parent_module, model):
351367
weight = getattr(parent_module, weight_name)
352368
weight_quantizer(weight)
353-
354-
# Step 4: Disable weight quantizers during forward loop
355-
for _, _, weight_quantizer in weight_quantizers:
356-
weight_quantizer.disable()
357-
358-
# Step 5: Collect data with MSE calibrators for activation quantizers only
359-
enable_stats_collection(model)
360-
if forward_loop is None:
361-
# If no forward loop, nothing else to do since weights are already calibrated
362-
pass
363-
else:
364-
# Run forward loop - only activation quantizers will collect data
365-
forward_loop(model)
366-
367-
# Step 6: Re-enable weight quantizers before finalizing calibration
368-
# This ensures finish_stats_collection processes them correctly
369-
for _, _, weight_quantizer in weight_quantizers:
370-
weight_quantizer.enable()
371-
372-
# Step 7: Compute optimal amax and load it for all quantizers (weights + activations)
373-
finish_stats_collection(model, method="mse")
374-
375-
# Step 8: Free GPU memory by clearing calibrator data
376-
for name, module in model.named_modules():
377-
if isinstance(module, TensorQuantizer) and not module._disabled:
378-
if hasattr(module, "_calibrator") and getattr(module, "_calibrator", None) is not None:
379-
if hasattr(module._calibrator, "clear"):
380-
module._calibrator.clear()
369+
finish_stats_collection(parent_module, method="mse")
370+
weight_quantizer._calibrator.reset()
381371

382372
# TODO: Sync amax across distributed processes
383373

@@ -393,23 +383,19 @@ def enable_stats_collection(model: nn.Module):
393383
module.disable()
394384

395385

396-
def finish_stats_collection(model: nn.Module, method: str | None = None):
386+
def finish_stats_collection(model: nn.Module, method: str | None = None, **kwargs):
397387
"""Finish stats collection for all quantizers in the model."""
398388
for _, module in model.named_modules():
399389
if not isinstance(module, TensorQuantizer) or module._disabled:
400390
continue
401391

402392
cal = getattr(module, "_calibrator", None)
403393
if cal and not getattr(module, "_dynamic", False):
404-
if method in {"mse", "entropy"}:
394+
if method in {"entropy"}:
405395
if cal.compute_amax(method) is not None:
406-
if method == "entropy":
407-
module.load_calib_amax("entropy")
408-
else:
409-
module.load_calib_amax()
410-
elif cal.compute_amax() is not None:
411-
# Max calibrator
412-
module.load_calib_amax()
396+
module.load_calib_amax("entropy", **kwargs)
397+
elif cal.compute_amax(**kwargs) is not None:
398+
module.load_calib_amax(**kwargs)
413399

414400
if module.bias_calibrator is not None and module.bias_type == "static":
415401
module.load_calib_bias()

0 commit comments

Comments
 (0)