Skip to content

Commit 0e43a87

Browse files
committed
add type check
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent c7b5044 commit 0e43a87

1 file changed

Lines changed: 17 additions & 8 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import warnings
2121
from collections.abc import Callable
2222
from functools import partial
23+
from typing import TypeAlias
2324

2425
import torch
2526
import torch.distributed as dist
@@ -33,7 +34,7 @@
3334
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
3435
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
3536

36-
from .calib import MseCalibrator, NVFP4MSECalibrator
37+
from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator
3738
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
3839
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
3940
from .utils import (
@@ -52,6 +53,7 @@
5253
from .utils.calib_utils import GPTQHelper
5354

5455
__all__ = [
56+
"CalibratorFactory",
5557
"awq",
5658
"local_hessian_calibrate",
5759
"max_calibrate",
@@ -61,11 +63,14 @@
6163
"svdquant",
6264
]
6365

64-
# Registry for backends that provide a custom calibrator factory for mse_calibrate().
65-
_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, type] = {}
66+
CalibratorFactory: TypeAlias = Callable[
67+
[torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
68+
]
69+
70+
_FP8_SWEEP_CALIBRATOR_REGISTRY: dict[str, CalibratorFactory] = {}
6671

6772

68-
def register_fp8_sweep_calibrator(backend: str, calibrator_factory) -> None:
73+
def register_fp8_sweep_calibrator(backend: str, calibrator_factory: CalibratorFactory) -> None:
6974
"""Register a custom calibrator factory for a quantization backend.
7075
7176
When ``fp8_scale_sweep=True`` is passed to :func:`mse_calibrate`, any weight
@@ -74,8 +79,9 @@ def register_fp8_sweep_calibrator(backend: str, calibrator_factory) -> None:
7479
7580
Args:
7681
backend: Backend name string (must match ``TensorQuantizer.backend``).
77-
calibrator_factory: Callable with signature ``(amax, axis, quant_func)``
78-
that returns a calibrator instance.
82+
calibrator_factory: Callable with signature
83+
``(amax: Tensor, axis: int | tuple | list | None, quant_func: Callable)``
84+
that returns a :class:`_Calibrator` instance.
7985
"""
8086
_FP8_SWEEP_CALIBRATOR_REGISTRY[backend] = calibrator_factory
8187

@@ -358,8 +364,11 @@ def mse_calibrate(
358364

359365
if fp8_scale_sweep:
360366
# Check if backend has a registered custom calibrator factory.
361-
backend_factory = _FP8_SWEEP_CALIBRATOR_REGISTRY.get(
362-
getattr(module, "backend", None)
367+
_backend: str | None = getattr(module, "backend", None)
368+
backend_factory = (
369+
_FP8_SWEEP_CALIBRATOR_REGISTRY.get(_backend)
370+
if _backend is not None
371+
else None
363372
)
364373
if backend_factory is not None:
365374
module._calibrator = backend_factory(

0 commit comments

Comments
 (0)