Skip to content

Commit 9adc057

Browse files
realAsma and sugunav14
authored and committed
Track global_amax for weight FP4 MSE sweep; Refactor to NVFP4StaticQuantizer, NVFP4MSECalibrator (#849)
## What does this PR do? **Type of change:** ? <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** ? ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added NVFP4StaticQuantizer for improved 4-bit quantization with enhanced precision control * Introduced NVFP4MSECalibrator with flexible candidate generation for calibration optimization * **Improvements** * Optimized GPU kernels for Hopper+ graphics cards with better performance * Extended Triton support to broader GPU compatibility * Enhanced backward compatibility for restoring previously quantized models * **Tests** * Added comprehensive test coverage for new quantizers and calibration methods <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent eef96cb commit 9adc057

13 files changed

Lines changed: 685 additions & 454 deletions

File tree

modelopt/torch/quantization/calib/mse.py

Lines changed: 69 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .. import utils as quant_utils
2525
from .calibrator import _Calibrator
2626

27-
__all__ = ["MseCalibrator"]
27+
__all__ = ["MseCalibrator", "NVFP4MSECalibrator"]
2828

2929

3030
class MseCalibrator(_Calibrator):
@@ -39,7 +39,6 @@ def __init__(
3939
stop_multiplier: float = 4.0,
4040
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
4141
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
42-
fp8_scale_sweep: bool = False,
4342
):
4443
"""Initialize MSE calibrator.
4544
@@ -54,9 +53,6 @@ def __init__(
5453
Should have signature: quant_func(x, amax) -> quantized_x.
5554
error_func: Function to compute error between x and xq.
5655
Default is F.mse_loss(x, xq, reduction='none').
57-
fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
58-
instead of using multipliers. This is specifically for NVFP4
59-
per-block quantization where scales are stored in FP8 format.
6056
"""
6157
super().__init__(num_bits=None, axis=axis, unsigned=None)
6258
self._initial_amax = amax
@@ -67,17 +63,21 @@ def __init__(
6763

6864
self._quant_func = quant_func
6965
self._error_func = error_func
70-
self._losses_sum = [None] * self._num_steps
71-
self._candidate_amaxs = [None] * self._num_steps
72-
self._fp8_scale_sweep = fp8_scale_sweep
73-
if fp8_scale_sweep:
74-
# For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
75-
# (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
76-
self._num_steps = 126
77-
self._losses_sum = [None] * self._num_steps
78-
self._candidate_amaxs = [None] * self._num_steps
79-
80-
self._amax = None
66+
self._losses_sum: list[torch.Tensor | None] | None = None
67+
self._candidates: torch.Tensor | None = None
68+
self._amax: torch.Tensor | None = None
69+
70+
def _generate_candidates(self, device: torch.device) -> torch.Tensor:
71+
"""Generate candidate multipliers. Override in subclasses for different candidate sets."""
72+
return torch.linspace(
73+
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
74+
)
75+
76+
def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
77+
"""Compute amax from candidates. Override in subclasses for different amax computation."""
78+
if candidates.ndim != 0: # Called during final compute amax
79+
candidates = candidates.view_as(self._initial_amax)
80+
return self._initial_amax * candidates
8181

8282
@torch.no_grad()
8383
def collect(self, x: torch.Tensor):
@@ -87,39 +87,22 @@ def collect(self, x: torch.Tensor):
8787
x: Input tensor.
8888
"""
8989
if self._quant_func is None:
90-
raise RuntimeError(
91-
"Quantization function not set. Msecalibrator requires a quant_func to be provided."
92-
)
90+
raise RuntimeError("Quantization function not set.")
9391

9492
x = x.detach().to(dtype=torch.float32)
95-
9693
device = x.device
9794

98-
if self._fp8_scale_sweep:
99-
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
100-
101-
# Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
102-
# Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
103-
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
104-
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
105-
106-
# Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
107-
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
108-
fp8_values_valid = fp8_values[valid_mask]
95+
candidates = self._generate_candidates(device)
96+
if self._candidates is None:
97+
self._candidates = candidates
98+
self._num_steps = len(candidates)
99+
self._losses_sum = [None] * self._num_steps
109100

110-
candidates = fp8_values_valid / 448.0
111-
else:
112-
candidates = torch.linspace(
113-
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
114-
)
115-
# Get reduce axis for per-channel quantization
101+
assert self._losses_sum is not None
116102
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)
117103

118104
for step, candidate in enumerate(candidates):
119-
if self._fp8_scale_sweep:
120-
candidate_amax = (global_amax * candidate) * torch.ones_like(self._initial_amax)
121-
else:
122-
candidate_amax = self._initial_amax * candidate
105+
candidate_amax = self._compute_candidate_amax(candidate)
123106
xq = self._quant_func(x, candidate_amax)
124107

125108
if self._error_func is not None:
@@ -129,28 +112,16 @@ def collect(self, x: torch.Tensor):
129112

130113
loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)
131114

132-
if self._candidate_amaxs[step] is None:
133-
self._candidate_amaxs[step] = candidate_amax
134-
135115
if self._losses_sum[step] is None:
136116
self._losses_sum[step] = loss.clone()
137117
else:
138118
self._losses_sum[step] += loss
139119

140120
def reset(self):
141121
"""Reset the stored losses and amax value."""
142-
self._losses_sum = [None] * self._num_steps
143-
self._candidate_amaxs = [None] * self._num_steps
122+
self._losses_sum = None
123+
self._candidates = None
144124
self._amax = None
145-
146-
def clear(self):
147-
"""Clear all cached data to free GPU memory.
148-
149-
Call this after compute_amax() and load_calib_amax() are done.
150-
"""
151-
self._losses_sum = []
152-
self._candidate_amaxs = []
153-
154125
if self._initial_amax is not None:
155126
del self._initial_amax
156127
self._initial_amax = None
@@ -162,49 +133,28 @@ def compute_amax(self, verbose: bool = False):
162133
Args:
163134
verbose: If True, print the ratio of best_amax to initial_amax.
164135
"""
165-
if not any(loss_sum is not None for loss_sum in self._losses_sum):
136+
if self._losses_sum is None or not any(loss is not None for loss in self._losses_sum):
166137
return None
167138

168-
# Check if this is per-tensor or per-channel based on the first loss
169-
first_loss_sum = None
170-
for loss_sum in self._losses_sum:
171-
if loss_sum is not None:
172-
first_loss_sum = loss_sum
173-
break
174-
175-
if first_loss_sum is None:
139+
first_loss = next((loss for loss in self._losses_sum if loss is not None), None)
140+
if first_loss is None:
176141
return None
177142

178-
# Collect losses for all steps
179-
losses_per_step = []
143+
# Stack losses: [num_steps] or [num_steps, num_channels]
144+
losses = []
180145
for step in range(self._num_steps):
181146
if self._losses_sum[step] is not None:
182-
losses_per_step.append(self._losses_sum[step])
183-
# No data for this step, use inf
184-
elif first_loss_sum.ndim == 0:
185-
losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device))
147+
losses.append(self._losses_sum[step])
148+
elif first_loss.ndim == 0:
149+
losses.append(torch.tensor(float("inf"), device=first_loss.device))
186150
else:
187-
losses_per_step.append(torch.full_like(first_loss_sum, float("inf")))
188-
189-
# Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel
190-
losses_per_step = torch.stack(losses_per_step)
151+
losses.append(torch.full_like(first_loss, float("inf")))
191152

192-
# Find best step(s): scalar for per-tensor, [num_channels] for per-channel
193-
best_steps = torch.argmin(losses_per_step, dim=0)
194-
195-
# Stack candidate amaxs and select based on best_steps
196-
candidate_amaxs = torch.stack(self._candidate_amaxs)
197-
198-
if first_loss_sum.ndim == 0:
199-
# Per-tensor case: best_steps is a scalar
200-
self._amax = self._candidate_amaxs[best_steps.item()]
201-
else:
202-
# Per-channel case: best_steps is a tensor
203-
num_channels = best_steps.shape[0]
204-
self._amax = candidate_amaxs[
205-
best_steps, torch.arange(num_channels, device=best_steps.device)
206-
]
207-
self._amax = self._amax.reshape(self._initial_amax.shape)
153+
losses = torch.stack(losses)
154+
best_indices = torch.argmin(losses, dim=0)
155+
assert self._candidates is not None
156+
best_candidates = self._candidates[best_indices]
157+
self._amax = self._compute_candidate_amax(best_candidates)
208158

209159
if verbose:
210160
ratio = self._amax / self._initial_amax
@@ -219,3 +169,32 @@ def compute_amax(self, verbose: bool = False):
219169
)
220170

221171
return self._amax
172+
173+
174+
class NVFP4MSECalibrator(MseCalibrator):
175+
"""Per-block FP8 scale sweep calibrator for NVFP4 static quantization."""
176+
177+
def __init__(
178+
self,
179+
amax: torch.Tensor, # per_block_amax shape [num_blocks]
180+
global_amax: torch.Tensor, # scalar
181+
axis: int | tuple | list | None = None,
182+
quant_func: Callable | None = None,
183+
error_func: Callable | None = None,
184+
):
185+
"""Initialize NVFP4 MSE calibrator with per-block and global amax."""
186+
super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func)
187+
self._global_amax = global_amax
188+
189+
def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
190+
if candidates.ndim != 0: # Called during final compute amax
191+
candidates = candidates.view_as(self._initial_amax)
192+
return torch.ones_like(self._initial_amax) * self._global_amax * candidates
193+
194+
def _generate_candidates(self, device: torch.device) -> torch.Tensor:
195+
"""Generate 126 valid FP8 E4M3 scale candidates."""
196+
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
197+
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
198+
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
199+
fp8_values = fp8_values[valid_mask]
200+
return fp8_values / 448.0

modelopt/torch/quantization/conversion.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_QuantizeExportConfig,
3636
)
3737
from .nn import (
38+
NVFP4StaticQuantizer,
3839
QuantModule,
3940
QuantModuleRegistry,
4041
SequentialQuantizer,
@@ -125,6 +126,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
125126
for name, module in model.named_modules():
126127
if isinstance(module, TensorQuantizer):
127128
name = get_unwrapped_name(name, model)
129+
state = quantizer_state_dict[name]
130+
# TODO: Add a registry for TensorQuantizers and avoid this manual conversion.
131+
if state.get("_is_nvfp4_static_quantizer") and not isinstance(
132+
module, NVFP4StaticQuantizer
133+
):
134+
NVFP4StaticQuantizer.from_tensor_quantizer(module)
128135
module.set_from_modelopt_state(quantizer_state_dict[name])
129136

130137
for name, module in model.named_modules():

modelopt/torch/quantization/model_calib.py

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
)
3939
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
4040

41-
from .calib import MseCalibrator
41+
from .calib import MseCalibrator, NVFP4MSECalibrator
4242
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
43-
from .nn import QuantModule, SequentialQuantizer, TensorQuantizer
43+
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
4444
from .utils import (
4545
disable_calib,
4646
enable_fake_quant,
@@ -50,6 +50,7 @@
5050
is_quantized_linear,
5151
is_quantized_row_parallel_linear,
5252
quantizer_attr_names,
53+
reduce_amax,
5354
weight_attr_names,
5455
)
5556

@@ -305,7 +306,7 @@ def mse_calibrate(
305306
weight_quantizers = []
306307
seen_modules = set()
307308

308-
for name, module in model.named_modules():
309+
for name, module in list(model.named_modules()):
309310
if isinstance(module, TensorQuantizer) and not module._disabled:
310311
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
311312
# Get the initial amax from max calibration
@@ -317,6 +318,24 @@ def mse_calibrate(
317318
and module._block_sizes is not None
318319
and module._block_sizes.get("scale_bits") == (4, 3)
319320
)
321+
322+
if is_nvfp4_static:
323+
# Compute and set global_amax
324+
global_amax = reduce_amax(initial_amax, axis=None)
325+
326+
# Convert to NVFP4StaticQuantizer in-place
327+
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
328+
329+
if fp8_scale_sweep and is_nvfp4_static:
330+
# Replace calibrator with NVFP4MSECalibrator
331+
module._calibrator = NVFP4MSECalibrator(
332+
amax=initial_amax,
333+
axis=module._calibrator._axis,
334+
global_amax=module.global_amax,
335+
quant_func=partial(_mse_quant_func, quantizer=module),
336+
)
337+
continue
338+
320339
if fp8_scale_sweep and not is_nvfp4_static:
321340
warnings.warn(
322341
f"fp8_scale_sweep is enabled but quantizer '{name}' is not NVFP4 static "
@@ -331,7 +350,6 @@ def mse_calibrate(
331350
start_multiplier=start_multiplier,
332351
stop_multiplier=stop_multiplier,
333352
quant_func=partial(_mse_quant_func, quantizer=module),
334-
fp8_scale_sweep=fp8_scale_sweep and is_nvfp4_static,
335353
)
336354

337355
# Identify weight quantizers by checking if they have corresponding weight parameters
@@ -350,40 +368,12 @@ def mse_calibrate(
350368
# This ensures weights are only calibrated once, not during every forward pass
351369
for parent_module, weight_name, weight_quantizer in weight_quantizers:
352370
# Enable calibration mode for the weight quantizer
353-
weight_quantizer.disable_quant()
354-
weight_quantizer.enable_calib()
355-
371+
enable_stats_collection(parent_module)
356372
with enable_weight_access_and_writeback(parent_module, model):
357373
weight = getattr(parent_module, weight_name)
358374
weight_quantizer(weight)
359-
360-
# Step 4: Disable weight quantizers during forward loop
361-
for _, _, weight_quantizer in weight_quantizers:
362-
weight_quantizer.disable()
363-
364-
# Step 5: Collect data with MSE calibrators for activation quantizers only
365-
enable_stats_collection(model)
366-
if forward_loop is None:
367-
# If no forward loop, nothing else to do since weights are already calibrated
368-
pass
369-
else:
370-
# Run forward loop - only activation quantizers will collect data
371-
forward_loop(model)
372-
373-
# Step 6: Re-enable weight quantizers before finalizing calibration
374-
# This ensures finish_stats_collection processes them correctly
375-
for _, _, weight_quantizer in weight_quantizers:
376-
weight_quantizer.enable()
377-
378-
# Step 7: Compute optimal amax and load it for all quantizers (weights + activations)
379-
finish_stats_collection(model, method="mse")
380-
381-
# Step 8: Free GPU memory by clearing calibrator data
382-
for name, module in model.named_modules():
383-
if isinstance(module, TensorQuantizer) and not module._disabled:
384-
if hasattr(module, "_calibrator") and getattr(module, "_calibrator", None) is not None:
385-
if hasattr(module._calibrator, "clear"):
386-
module._calibrator.clear()
375+
finish_stats_collection(parent_module, method="mse")
376+
weight_quantizer._calibrator.reset()
387377

388378
# TODO: Sync amax across distributed processes
389379

@@ -399,23 +389,19 @@ def enable_stats_collection(model: nn.Module):
399389
module.disable()
400390

401391

402-
def finish_stats_collection(model: nn.Module, method: str | None = None):
392+
def finish_stats_collection(model: nn.Module, method: str | None = None, **kwargs):
403393
"""Finish stats collection for all quantizers in the model."""
404394
for _, module in model.named_modules():
405395
if not isinstance(module, TensorQuantizer) or module._disabled:
406396
continue
407397

408398
cal = getattr(module, "_calibrator", None)
409399
if cal and not getattr(module, "_dynamic", False):
410-
if method in {"mse", "entropy"}:
400+
if method in {"entropy"}:
411401
if cal.compute_amax(method) is not None:
412-
if method == "entropy":
413-
module.load_calib_amax("entropy")
414-
else:
415-
module.load_calib_amax()
416-
elif cal.compute_amax() is not None:
417-
# Max calibrator
418-
module.load_calib_amax()
402+
module.load_calib_amax("entropy", **kwargs)
403+
elif cal.compute_amax(**kwargs) is not None:
404+
module.load_calib_amax(**kwargs)
419405

420406
if module.bias_calibrator is not None and module.bias_type == "static":
421407
module.load_calib_bias()

0 commit comments

Comments (0)