2020import warnings
2121from collections .abc import Callable
2222from functools import partial
23+ from typing import TypeAlias
2324
2425import torch
2526import torch .distributed as dist
3334from modelopt .torch .utils .distributed import DistributedProcessGroup , ParallelState
3435from modelopt .torch .utils .network import bind_forward_method , unpatch_forward_method
3536
36- from .calib import MseCalibrator , NVFP4MSECalibrator
37+ from .calib import MseCalibrator , NVFP4MSECalibrator , _Calibrator
3738from .conversion import create_and_replace_svdquant_linear_on_the_fly , set_quantizer_by_cfg_context
3839from .nn import NVFP4StaticQuantizer , QuantModule , SequentialQuantizer , TensorQuantizer
3940from .utils import (
5253from .utils .calib_utils import GPTQHelper
5354
5455__all__ = [
56+ "CalibratorFactory" ,
5557 "awq" ,
5658 "local_hessian_calibrate" ,
5759 "max_calibrate" ,
6163 "svdquant" ,
6264]
6365
64- # Registry for backends that provide a custom calibrator factory for mse_calibrate().
65- _FP8_SWEEP_CALIBRATOR_REGISTRY : dict [str , type ] = {}
66+ CalibratorFactory : TypeAlias = Callable [
67+ [torch .Tensor , int | tuple | list | None , Callable [..., torch .Tensor ]], _Calibrator
68+ ]
69+
70+ _FP8_SWEEP_CALIBRATOR_REGISTRY : dict [str , CalibratorFactory ] = {}
6671
6772
68- def register_fp8_sweep_calibrator (backend : str , calibrator_factory ) -> None :
73+ def register_fp8_sweep_calibrator (backend : str , calibrator_factory : CalibratorFactory ) -> None :
6974 """Register a custom calibrator factory for a quantization backend.
7075
7176 When ``fp8_scale_sweep=True`` is passed to :func:`mse_calibrate`, any weight
@@ -74,8 +79,9 @@ def register_fp8_sweep_calibrator(backend: str, calibrator_factory) -> None:
7479
7580 Args:
7681 backend: Backend name string (must match ``TensorQuantizer.backend``).
77- calibrator_factory: Callable with signature ``(amax, axis, quant_func)``
78- that returns a calibrator instance.
82+ calibrator_factory: Callable with signature
83+ ``(amax: Tensor, axis: int | tuple | list | None, quant_func: Callable)``
84+ that returns a :class:`_Calibrator` instance.
7985 """
8086 _FP8_SWEEP_CALIBRATOR_REGISTRY [backend ] = calibrator_factory
8187
@@ -358,8 +364,11 @@ def mse_calibrate(
358364
359365 if fp8_scale_sweep :
360366 # Check if backend has a registered custom calibrator factory.
361- backend_factory = _FP8_SWEEP_CALIBRATOR_REGISTRY .get (
362- getattr (module , "backend" , None )
367+ _backend : str | None = getattr (module , "backend" , None )
368+ backend_factory = (
369+ _FP8_SWEEP_CALIBRATOR_REGISTRY .get (_backend )
370+ if _backend is not None
371+ else None
363372 )
364373 if backend_factory is not None :
365374 module ._calibrator = backend_factory (
0 commit comments