Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings
from collections.abc import Callable
from functools import partial
from typing import TypeAlias

import torch
import torch.distributed as dist
Expand All @@ -36,7 +37,7 @@
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

from .calib import MseCalibrator, NVFP4MSECalibrator
from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import (
Expand All @@ -56,6 +57,7 @@
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper

__all__ = [
"CalibratorFactory",
"awq",
"layerwise_calibrate",
"local_hessian_calibrate",
Expand All @@ -64,6 +66,28 @@
"svdquant",
]

# Factory signature for building a calibrator for one weight quantizer:
# (amax: Tensor, axis: int | tuple | list | None, quant_func: Callable) -> _Calibrator.
CalibratorFactory: TypeAlias = Callable[
    [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
]

# Maps a backend name (matching ``TensorQuantizer.backend``) to the factory used in
# place of the default MseCalibrator when ``mse_calibrate(..., fp8_scale_sweep=True)``
# runs. Populated via ``_register_fp8_sweep_calibrator``.
_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, CalibratorFactory] = {}


def _register_fp8_sweep_calibrator(backend: str, calibrator_factory: CalibratorFactory) -> None:
    """Associate a quantization backend with a custom calibrator factory.

    During :func:`mse_calibrate` with ``fp8_scale_sweep=True``, any weight
    quantizer whose ``backend`` attribute equals a registered key is calibrated
    with the calibrator produced by this factory instead of the default
    :class:`MseCalibrator`. Registering the same backend again replaces the
    previously stored factory.

    Args:
        backend: Backend name string (must match ``TensorQuantizer.backend``).
        calibrator_factory: Callable taking
            ``(amax: Tensor, axis: int | tuple | list | None, quant_func: Callable)``
            and returning a :class:`_Calibrator` instance.
    """
    _FP8_SWEEP_CALIBRATOR_REGISTRY[backend] = calibrator_factory


def weight_only_quantize(model: nn.Module):
"""Just quantize the weights of the model."""
Expand Down Expand Up @@ -341,6 +365,22 @@ def mse_calibrate(
# Convert to NVFP4StaticQuantizer in-place
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

if fp8_scale_sweep:
# Check if backend has a registered custom calibrator factory.
_backend: str | None = getattr(module, "backend", None)
backend_factory = (
_FP8_SWEEP_CALIBRATOR_REGISTRY.get(_backend)
if _backend is not None
else None
)
Comment thread
Fridah-nv marked this conversation as resolved.
if backend_factory is not None:
module._calibrator = backend_factory(
initial_amax,
module._calibrator._axis,
partial(_mse_quant_func, quantizer=module),
)
continue

if fp8_scale_sweep and is_nvfp4_static:
# Replace calibrator with NVFP4MSECalibrator
module._calibrator = NVFP4MSECalibrator(
Expand Down
114 changes: 114 additions & 0 deletions tests/unit/torch/quantization/test_mse_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,117 @@ def quant_func(x, amax):
assert a_best.numel() == 2
assert torch.all(torch.isfinite(a_best))
assert torch.all(a_best > 0)


class TestRegisterFP8SweepCalibrator:
    """Tests for _register_fp8_sweep_calibrator and its dispatch in mse_calibrate."""

    def setup_method(self):
        """Snapshot the global registries that these tests mutate."""
        from modelopt.torch.quantization.model_calib import _FP8_SWEEP_CALIBRATOR_REGISTRY
        from modelopt.torch.quantization.nn.modules.tensor_quantizer import (
            _QUANT_FUNCTIONAL_BACKENDS,
        )

        self._orig_fp8_registry = dict(_FP8_SWEEP_CALIBRATOR_REGISTRY)
        self._orig_quant_backends = dict(_QUANT_FUNCTIONAL_BACKENDS)

    def teardown_method(self):
        """Restore both registries to the snapshots taken in setup_method."""
        from modelopt.torch.quantization.model_calib import _FP8_SWEEP_CALIBRATOR_REGISTRY
        from modelopt.torch.quantization.nn.modules.tensor_quantizer import (
            _QUANT_FUNCTIONAL_BACKENDS,
        )

        _FP8_SWEEP_CALIBRATOR_REGISTRY.clear()
        _FP8_SWEEP_CALIBRATOR_REGISTRY.update(self._orig_fp8_registry)
        _QUANT_FUNCTIONAL_BACKENDS.clear()
        _QUANT_FUNCTIONAL_BACKENDS.update(self._orig_quant_backends)

    def _quantize_and_calibrate(self, backend_name, fp8_scale_sweep=True):
        """Quantize a small Linear with the given backend and run mse_calibrate.

        Args:
            backend_name: Name registered as a no-op functional backend and set as
                the ``backend`` of the weight quantizer config.
            fp8_scale_sweep: Forwarded to ``mse_calibrate``; controls whether the
                registry dispatch path is exercised.

        Returns:
            The quantized and MSE-calibrated model.
        """
        import modelopt.torch.quantization as mtq
        from modelopt.torch.quantization.model_calib import mse_calibrate
        from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend

        # Identity quant function: these tests only verify calibrator dispatch,
        # not the numerics of quantization.
        register_quant_backend(backend_name, lambda x, tq: x)
        model = torch.nn.Linear(8, 8, bias=False)
        inputs = torch.randn(1, 8)
        config = {
            "quant_cfg": [
                {"quantizer_name": "*", "enable": False},
                {
                    "quantizer_name": "*weight_quantizer",
                    "cfg": {"num_bits": 8, "axis": None, "backend": backend_name},
                },
            ],
            "algorithm": "max",
        }
        mtq.quantize(model, config, forward_loop=lambda m: m(inputs))
        mse_calibrate(model, lambda m: m(inputs), fp8_scale_sweep=fp8_scale_sweep)
        return model

    def test_register(self):
        """_register_fp8_sweep_calibrator stores factories by backend key and allows overwrite."""
        from modelopt.torch.quantization.model_calib import (
            _FP8_SWEEP_CALIBRATOR_REGISTRY,
            _register_fp8_sweep_calibrator,
        )

        def factory_a(amax, axis, qf):
            return None

        def factory_b(amax, axis, qf):
            return None

        _register_fp8_sweep_calibrator("backend_x", factory_a)
        assert _FP8_SWEEP_CALIBRATOR_REGISTRY["backend_x"] is factory_a

        # Re-registering the same key must replace the earlier factory.
        _register_fp8_sweep_calibrator("backend_x", factory_b)
        assert _FP8_SWEEP_CALIBRATOR_REGISTRY["backend_x"] is factory_b

    def test_mse_calibrate_dispatches_to_registered_factory(self):
        """mse_calibrate with fp8_scale_sweep=True calls the registered factory once per quantizer."""
        from modelopt.torch.quantization.calib.mse import MseCalibrator
        from modelopt.torch.quantization.model_calib import _register_fp8_sweep_calibrator

        factory_calls: list = []

        class _RecordingCalibrator(MseCalibrator):
            # Minimal calibrator: skips collection and echoes the initial amax so
            # calibration completes without doing any sweep work.
            def collect(self, x):
                pass

            def compute_amax(self, verbose=False):
                return self._initial_amax

        def my_factory(amax, axis, quant_func):
            factory_calls.append(amax)
            return _RecordingCalibrator(amax=amax, axis=axis, quant_func=quant_func)

        _register_fp8_sweep_calibrator("_test_dispatch", my_factory)
        self._quantize_and_calibrate("_test_dispatch", fp8_scale_sweep=True)

        # Exactly one enabled weight quantizer in the model -> one factory call.
        assert len(factory_calls) == 1

    def test_mse_calibrate_skips_registry_when_fp8_sweep_false(self):
        """Registry factory is not invoked when fp8_scale_sweep=False."""
        from modelopt.torch.quantization.calib.mse import MseCalibrator
        from modelopt.torch.quantization.model_calib import _register_fp8_sweep_calibrator

        factory_calls: list = []

        def my_factory(amax, axis, quant_func):
            factory_calls.append(amax)
            # Not reached when fp8_scale_sweep=False; return a real calibrator so a
            # dispatch regression fails on the call count, not on a broken factory.
            return MseCalibrator(amax=amax, axis=axis, quant_func=quant_func)

        _register_fp8_sweep_calibrator("_test_no_sweep", my_factory)
        self._quantize_and_calibrate("_test_no_sweep", fp8_scale_sweep=False)

        assert len(factory_calls) == 0

    def test_unregistered_backend_uses_default_mse_calibrator(self):
        """A quantizer with an unregistered backend falls through to MseCalibrator."""
        from modelopt.torch.quantization.calib.mse import MseCalibrator

        model = self._quantize_and_calibrate("_test_unregistered", fp8_scale_sweep=True)
        for module in model.modules():
            if isinstance(module, TensorQuantizer) and module.is_enabled:
                if getattr(module, "_calibrator", None) is not None:
                    assert isinstance(module._calibrator, MseCalibrator)
Loading