diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index b653369693..9b1cc5bc0c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -20,6 +20,7 @@ import warnings from collections.abc import Callable from functools import partial +from typing import TypeAlias import torch import torch.distributed as dist @@ -36,7 +37,7 @@ from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from .calib import MseCalibrator, NVFP4MSECalibrator +from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( @@ -56,6 +57,7 @@ from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper __all__ = [ + "CalibratorFactory", "awq", "layerwise_calibrate", "local_hessian_calibrate", @@ -64,6 +66,28 @@ "svdquant", ] +CalibratorFactory: TypeAlias = Callable[ + [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator +] + +_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, CalibratorFactory] = {} + + +def _register_fp8_sweep_calibrator(backend: str, calibrator_factory: CalibratorFactory) -> None: + """Register a custom calibrator factory for a quantization backend. + + When ``fp8_scale_sweep=True`` is passed to :func:`mse_calibrate`, any weight + quantizer whose ``backend`` attribute matches a registered key will use the + corresponding factory instead of the default :class:`MseCalibrator`. + + Args: + backend: Backend name string (must match ``TensorQuantizer.backend``). + calibrator_factory: Callable with signature + ``(amax: Tensor, axis: int | tuple | list | None, quant_func: Callable)`` + that returns a :class:`_Calibrator` instance. 
+ """ + _FP8_SWEEP_CALIBRATOR_REGISTRY[backend] = calibrator_factory + def weight_only_quantize(model: nn.Module): """Just quantize the weights of the model.""" @@ -341,6 +365,22 @@ def mse_calibrate( # Convert to NVFP4StaticQuantizer in-place NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + if fp8_scale_sweep: + # Check if backend has a registered custom calibrator factory. + _backend: str | None = getattr(module, "backend", None) + backend_factory = ( + _FP8_SWEEP_CALIBRATOR_REGISTRY.get(_backend) + if _backend is not None + else None + ) + if backend_factory is not None: + module._calibrator = backend_factory( + initial_amax, + module._calibrator._axis, + partial(_mse_quant_func, quantizer=module), + ) + continue + if fp8_scale_sweep and is_nvfp4_static: # Replace calibrator with NVFP4MSECalibrator module._calibrator = NVFP4MSECalibrator( diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 5e55465120..4332b09386 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -526,3 +526,117 @@ def quant_func(x, amax): assert a_best.numel() == 2 assert torch.all(torch.isfinite(a_best)) assert torch.all(a_best > 0) + + +class TestRegisterFP8SweepCalibrator: + """Tests for _register_fp8_sweep_calibrator and its dispatch in mse_calibrate.""" + + def setup_method(self): + from modelopt.torch.quantization.model_calib import _FP8_SWEEP_CALIBRATOR_REGISTRY + from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + _QUANT_FUNCTIONAL_BACKENDS, + ) + + self._orig_fp8_registry = dict(_FP8_SWEEP_CALIBRATOR_REGISTRY) + self._orig_quant_backends = dict(_QUANT_FUNCTIONAL_BACKENDS) + + def teardown_method(self): + from modelopt.torch.quantization.model_calib import _FP8_SWEEP_CALIBRATOR_REGISTRY + from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + _QUANT_FUNCTIONAL_BACKENDS, + ) 
+ + _FP8_SWEEP_CALIBRATOR_REGISTRY.clear() + _FP8_SWEEP_CALIBRATOR_REGISTRY.update(self._orig_fp8_registry) + _QUANT_FUNCTIONAL_BACKENDS.clear() + _QUANT_FUNCTIONAL_BACKENDS.update(self._orig_quant_backends) + + def _quantize_and_calibrate(self, backend_name, fp8_scale_sweep=True): + """Quantize a small Linear with the given backend and run mse_calibrate.""" + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.model_calib import mse_calibrate + from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend + + register_quant_backend(backend_name, lambda x, tq: x) + model = torch.nn.Linear(8, 8, bias=False) + inputs = torch.randn(1, 8) + config = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8, "axis": None, "backend": backend_name}, + }, + ], + "algorithm": "max", + } + mtq.quantize(model, config, forward_loop=lambda m: m(inputs)) + mse_calibrate(model, lambda m: m(inputs), fp8_scale_sweep=fp8_scale_sweep) + return model + + def test_register(self): + """_register_fp8_sweep_calibrator stores factories by backend key and allows overwrite.""" + from modelopt.torch.quantization.model_calib import ( + _FP8_SWEEP_CALIBRATOR_REGISTRY, + _register_fp8_sweep_calibrator, + ) + + def factory_a(amax, axis, qf): + return None + + def factory_b(amax, axis, qf): + return None + + _register_fp8_sweep_calibrator("backend_x", factory_a) + assert _FP8_SWEEP_CALIBRATOR_REGISTRY["backend_x"] is factory_a + + _register_fp8_sweep_calibrator("backend_x", factory_b) + assert _FP8_SWEEP_CALIBRATOR_REGISTRY["backend_x"] is factory_b + + def test_mse_calibrate_dispatches_to_registered_factory(self): + """mse_calibrate with fp8_scale_sweep=True calls the registered factory once per quantizer.""" + from modelopt.torch.quantization.calib.mse import MseCalibrator + from modelopt.torch.quantization.model_calib import _register_fp8_sweep_calibrator + + 
factory_calls: list = []
+
+        class _RecordingCalibrator(MseCalibrator):
+            def collect(self, x):
+                pass
+
+            def compute_amax(self, verbose=False):
+                return self._initial_amax
+
+        def my_factory(amax, axis, quant_func):
+            factory_calls.append(amax)
+            return _RecordingCalibrator(amax=amax, axis=axis, quant_func=quant_func)
+
+        _register_fp8_sweep_calibrator("_test_dispatch", my_factory)
+        self._quantize_and_calibrate("_test_dispatch", fp8_scale_sweep=True)
+
+        assert len(factory_calls) == 1
+
+    def test_mse_calibrate_skips_registry_when_fp8_sweep_false(self):
+        """Registry factory is not invoked when fp8_scale_sweep=False."""
+        from modelopt.torch.quantization.calib.mse import MseCalibrator
+        from modelopt.torch.quantization.model_calib import _register_fp8_sweep_calibrator
+
+        factory_calls: list = []
+
+        def my_factory(amax, axis, quant_func):
+            factory_calls.append(amax)
+            return MseCalibrator(amax=amax, axis=axis, quant_func=quant_func)
+
+        _register_fp8_sweep_calibrator("_test_no_sweep", my_factory)
+        self._quantize_and_calibrate("_test_no_sweep", fp8_scale_sweep=False)
+
+        assert len(factory_calls) == 0
+
+    def test_unregistered_backend_uses_default_mse_calibrator(self):
+        """A quantizer with an unregistered backend falls through to MseCalibrator."""
+        from modelopt.torch.quantization.calib.mse import MseCalibrator
+
+        model = self._quantize_and_calibrate("_test_unregistered", fp8_scale_sweep=True)
+        for module in model.modules():
+            if isinstance(module, TensorQuantizer) and module.is_enabled:
+                if getattr(module, "_calibrator", None) is not None:
+                    assert isinstance(module._calibrator, MseCalibrator)