From c7b5044a128335b5d6d734a8661102118dd237d9 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 16 Apr 2026 13:50:52 -0700 Subject: [PATCH 1/4] add custom calibration backend registry Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 32 +++++ .../torch/quantization/test_mse_calibrator.py | 114 ++++++++++++++++++ 2 files changed, 146 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 35a0e931c9..f969cee3ef 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -55,11 +55,30 @@ "awq", "local_hessian_calibrate", "max_calibrate", + "register_fp8_sweep_calibrator", "sequential_calibrate", "smoothquant", "svdquant", ] +# Registry for backends that provide a custom calibrator factory for mse_calibrate(). +_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, type] = {} + + +def register_fp8_sweep_calibrator(backend: str, calibrator_factory) -> None: + """Register a custom calibrator factory for a quantization backend. + + When ``fp8_scale_sweep=True`` is passed to :func:`mse_calibrate`, any weight + quantizer whose ``backend`` attribute matches a registered key will use the + corresponding factory instead of the default :class:`MseCalibrator`. + + Args: + backend: Backend name string (must match ``TensorQuantizer.backend``). + calibrator_factory: Callable with signature ``(amax, axis, quant_func)`` + that returns a calibrator instance. + """ + _FP8_SWEEP_CALIBRATOR_REGISTRY[backend] = calibrator_factory + def weight_only_quantize(model: nn.Module): """Just quantize the weights of the model.""" @@ -337,6 +356,19 @@ def mse_calibrate( # Convert to NVFP4StaticQuantizer in-place NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + if fp8_scale_sweep: + # Check if backend has a registered custom calibrator factory. + backend_factory = _FP8_SWEEP_CALIBRATOR_REGISTRY.get( + getattr(module, "backend", None) + ) + if backend_factory is not None: + module._calibrator = backend_factory( + initial_amax, + module._calibrator._axis, + partial(_mse_quant_func, quantizer=module), + ) + continue + if fp8_scale_sweep and is_nvfp4_static: # Replace calibrator with NVFP4MSECalibrator module._calibrator = NVFP4MSECalibrator( diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 5e55465120..f180e464a2 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -526,3 +526,117 @@ def quant_func(x, amax): assert a_best.numel() == 2 assert torch.all(torch.isfinite(a_best)) assert torch.all(a_best > 0) + + +class TestRegisterFP8SweepCalibrator: + """Tests for register_fp8_sweep_calibrator and its dispatch in mse_calibrate.""" + + def setup_method(self): + from modelopt.torch.quantization.model_calib import _FP8_SWEEP_CALIBRATOR_REGISTRY + from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + _QUANT_FUNCTIONAL_BACKENDS, + ) + + self._orig_fp8_registry = dict(_FP8_SWEEP_CALIBRATOR_REGISTRY) + self._orig_quant_backends = dict(_QUANT_FUNCTIONAL_BACKENDS) + + def teardown_method(self): + from modelopt.torch.quantization.model_calib import _FP8_SWEEP_CALIBRATOR_REGISTRY + from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + _QUANT_FUNCTIONAL_BACKENDS, + ) + + _FP8_SWEEP_CALIBRATOR_REGISTRY.clear() + _FP8_SWEEP_CALIBRATOR_REGISTRY.update(self._orig_fp8_registry) + _QUANT_FUNCTIONAL_BACKENDS.clear() + _QUANT_FUNCTIONAL_BACKENDS.update(self._orig_quant_backends) + + def _quantize_and_calibrate(self, backend_name, fp8_scale_sweep=True): + """Quantize a small Linear with the given backend and run mse_calibrate.""" + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.model_calib import mse_calibrate + from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend + + register_quant_backend(backend_name, lambda x, tq: x) + model = torch.nn.Linear(8, 8, bias=False) + inputs = torch.randn(1, 8) + config = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8, "axis": None, "backend": backend_name}, + }, + ], + "algorithm": "max", + } + mtq.quantize(model, config, forward_loop=lambda m: m(inputs)) + mse_calibrate(model, lambda m: m(inputs), fp8_scale_sweep=fp8_scale_sweep) + return model + + def test_register(self): + """register_fp8_sweep_calibrator stores factories by backend key and allows overwrite.""" + from modelopt.torch.quantization.model_calib import ( + _FP8_SWEEP_CALIBRATOR_REGISTRY, + register_fp8_sweep_calibrator, + ) + + def factory_a(amax, axis, qf): + return None + + def factory_b(amax, axis, qf): + return None + + register_fp8_sweep_calibrator("backend_x", factory_a) + assert _FP8_SWEEP_CALIBRATOR_REGISTRY["backend_x"] is factory_a + + register_fp8_sweep_calibrator("backend_x", factory_b) + assert _FP8_SWEEP_CALIBRATOR_REGISTRY["backend_x"] is factory_b + + def test_mse_calibrate_dispatches_to_registered_factory(self): + """mse_calibrate with fp8_scale_sweep=True calls the registered factory once per quantizer.""" + from modelopt.torch.quantization.calib.mse import MseCalibrator + from modelopt.torch.quantization.model_calib import register_fp8_sweep_calibrator + + factory_calls: list = [] + + class _RecordingCalibrator(MseCalibrator): + def collect(self, x): + pass + + def compute_amax(self, verbose=False): + return self._initial_amax + + def my_factory(amax, axis, quant_func): + factory_calls.append(amax) + return _RecordingCalibrator(amax=amax, axis=axis, quant_func=quant_func) + + register_fp8_sweep_calibrator("_test_dispatch", my_factory) + self._quantize_and_calibrate("_test_dispatch", fp8_scale_sweep=True) + + assert len(factory_calls) == 1 + + def test_mse_calibrate_skips_registry_when_fp8_sweep_false(self): + """Registry factory is not invoked when fp8_scale_sweep=False.""" + from modelopt.torch.quantization.model_calib import register_fp8_sweep_calibrator + + factory_calls: list = [] + + def my_factory(amax, axis, quant_func): + factory_calls.append(amax) + return calib.MseCalibrator(amax=amax, axis=axis, quant_func=quant_func) + + register_fp8_sweep_calibrator("_test_no_sweep", my_factory) + self._quantize_and_calibrate("_test_no_sweep", fp8_scale_sweep=False) + + assert len(factory_calls) == 0 + + def test_unregistered_backend_uses_default_mse_calibrator(self): + """A quantizer with an unregistered backend falls through to MseCalibrator.""" + from modelopt.torch.quantization.calib.mse import MseCalibrator + + model = self._quantize_and_calibrate("_test_unregistered", fp8_scale_sweep=True) + for module in model.modules(): + if isinstance(module, TensorQuantizer) and module.is_enabled: + if getattr(module, "_calibrator", None) is not None: + assert isinstance(module._calibrator, MseCalibrator) From 0e43a87df5f1e14d4c4fb115177cd6bfbbb28269 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 16 Apr 2026 23:23:49 +0000 Subject: [PATCH 2/4] add type check Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 25 +++++++++++++++------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index f969cee3ef..5de055931e 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -20,6 +20,7 @@ import warnings from collections.abc import Callable from functools import partial +from typing import TypeAlias import torch import torch.distributed as dist @@ -33,7 +34,7 @@ from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method -from .calib import MseCalibrator, NVFP4MSECalibrator +from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( @@ -52,6 +53,7 @@ from .utils.calib_utils import GPTQHelper __all__ = [ + "CalibratorFactory", "awq", "local_hessian_calibrate", "max_calibrate", @@ -61,11 +63,14 @@ "svdquant", ] -# Registry for backends that provide a custom calibrator factory for mse_calibrate(). -_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, type] = {} +CalibratorFactory: TypeAlias = Callable[ + [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator +] + +_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, CalibratorFactory] = {} -def register_fp8_sweep_calibrator(backend: str, calibrator_factory) -> None: +def register_fp8_sweep_calibrator(backend: str, calibrator_factory: CalibratorFactory) -> None: """Register a custom calibrator factory for a quantization backend. When ``fp8_scale_sweep=True`` is passed to :func:`mse_calibrate`, any weight @@ -74,8 +79,9 @@ def register_fp8_sweep_calibrator(backend: str, calibrator_factory) -> None: Args: backend: Backend name string (must match ``TensorQuantizer.backend``). - calibrator_factory: Callable with signature ``(amax, axis, quant_func)`` - that returns a calibrator instance. + calibrator_factory: Callable with signature + ``(amax: Tensor, axis: int | tuple | list | None, quant_func: Callable)`` + that returns a :class:`_Calibrator` instance. """ _FP8_SWEEP_CALIBRATOR_REGISTRY[backend] = calibrator_factory @@ -358,8 +364,11 @@ def mse_calibrate( if fp8_scale_sweep: # Check if backend has a registered custom calibrator factory. - backend_factory = _FP8_SWEEP_CALIBRATOR_REGISTRY.get( - getattr(module, "backend", None) + _backend: str | None = getattr(module, "backend", None) + backend_factory = ( + _FP8_SWEEP_CALIBRATOR_REGISTRY.get(_backend) + if _backend is not None + else None ) if backend_factory is not None: module._calibrator = backend_factory( From b2a7b32897365cd0e39d315acd7bc588d54c628d Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Sat, 18 Apr 2026 00:25:07 +0000 Subject: [PATCH 3/4] use private name Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 4 ++-- .../torch/quantization/test_mse_calibrator.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5de055931e..d5c9308f00 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -54,10 +54,10 @@ __all__ = [ "CalibratorFactory", + "_register_fp8_sweep_calibrator", "awq", "local_hessian_calibrate", "max_calibrate", - "register_fp8_sweep_calibrator", "sequential_calibrate", "smoothquant", "svdquant", @@ -70,7 +70,7 @@ _FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, CalibratorFactory] = {} -def register_fp8_sweep_calibrator(backend: str, calibrator_factory: CalibratorFactory) -> None: +def _register_fp8_sweep_calibrator(backend: str, calibrator_factory: CalibratorFactory) -> None: """Register a custom calibrator factory for a quantization backend. When ``fp8_scale_sweep=True`` is passed to :func:`mse_calibrate`, any weight diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index f180e464a2..4332b09386 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -529,7 +529,7 @@ def quant_func(x, amax): class TestRegisterFP8SweepCalibrator: - """Tests for register_fp8_sweep_calibrator and its dispatch in mse_calibrate.""" + """Tests for _register_fp8_sweep_calibrator and its dispatch in mse_calibrate.""" def setup_method(self): from modelopt.torch.quantization.model_calib import _FP8_SWEEP_CALIBRATOR_REGISTRY @@ -575,10 +575,10 @@ def _quantize_and_calibrate(self, backend_name, fp8_scale_sweep=True): return model def test_register(self): - """register_fp8_sweep_calibrator stores factories by backend key and allows overwrite.""" + """_register_fp8_sweep_calibrator stores factories by backend key and allows overwrite.""" from modelopt.torch.quantization.model_calib import ( _FP8_SWEEP_CALIBRATOR_REGISTRY, - register_fp8_sweep_calibrator, + _register_fp8_sweep_calibrator, ) def factory_a(amax, axis, qf): @@ -587,16 +587,16 @@ def factory_a(amax, axis, qf): def factory_b(amax, axis, qf): return None - register_fp8_sweep_calibrator("backend_x", factory_a) + _register_fp8_sweep_calibrator("backend_x", factory_a) assert _FP8_SWEEP_CALIBRATOR_REGISTRY["backend_x"] is factory_a - register_fp8_sweep_calibrator("backend_x", factory_b) + _register_fp8_sweep_calibrator("backend_x", factory_b) assert _FP8_SWEEP_CALIBRATOR_REGISTRY["backend_x"] is factory_b def test_mse_calibrate_dispatches_to_registered_factory(self): """mse_calibrate with fp8_scale_sweep=True calls the registered factory once per quantizer.""" from modelopt.torch.quantization.calib.mse import MseCalibrator - from modelopt.torch.quantization.model_calib import register_fp8_sweep_calibrator + from modelopt.torch.quantization.model_calib import _register_fp8_sweep_calibrator factory_calls: list = [] @@ -611,14 +611,14 @@ def my_factory(amax, axis, quant_func): factory_calls.append(amax) return _RecordingCalibrator(amax=amax, axis=axis, quant_func=quant_func) - register_fp8_sweep_calibrator("_test_dispatch", my_factory) + _register_fp8_sweep_calibrator("_test_dispatch", my_factory) self._quantize_and_calibrate("_test_dispatch", fp8_scale_sweep=True) assert len(factory_calls) == 1 def test_mse_calibrate_skips_registry_when_fp8_sweep_false(self): """Registry factory is not invoked when fp8_scale_sweep=False.""" - from modelopt.torch.quantization.model_calib import register_fp8_sweep_calibrator + from modelopt.torch.quantization.model_calib import _register_fp8_sweep_calibrator factory_calls: list = [] @@ -626,7 +626,7 @@ def my_factory(amax, axis, quant_func): factory_calls.append(amax) return calib.MseCalibrator(amax=amax, axis=axis, quant_func=quant_func) - register_fp8_sweep_calibrator("_test_no_sweep", my_factory) + _register_fp8_sweep_calibrator("_test_no_sweep", my_factory) self._quantize_and_calibrate("_test_no_sweep", fp8_scale_sweep=False) assert len(factory_calls) == 0 From 57b33f39f599951fc581f2206d5539776d477c12 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Sat, 18 Apr 2026 00:31:23 +0000 Subject: [PATCH 4/4] minor Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d5c9308f00..3c1dacf022 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -54,7 +54,6 @@ __all__ = [ "CalibratorFactory", - "_register_fp8_sweep_calibrator", "awq", "local_hessian_calibrate", "max_calibrate",