Skip to content

Commit baaf80f

Browse files
committed
Add per-mode opt-out for layerwise calibration
Introduce `_supports_layerwise: bool = True` on `BaseCalibrateModeDescriptor` so individual calibration modes can declare incompatibility with layer-by-layer calibration. `wrapped_calib_func` raises a clear `ValueError` when `layerwise=True` is requested on a mode that opts out, instead of failing deep inside the algorithm.

Opt `SVDQuantModeDescriptor` out — `create_and_replace_svdquant_linear_on_the_fly` reads `ModeloptStateManager` from the root model, which is not present when `layerwise_calibrate` dispatches per decoder layer.

Restructure the end-to-end layerwise tests:

- `test_mtq_quantize_layerwise_e2e_max` runs the full happy path on `max`
- `test_mtq_quantize_layerwise_dispatches_for_algorithm` stubs `layerwise_calibrate` and verifies every supporting algorithm (including gptq) routes through it
- `test_mtq_quantize_layerwise_raises_for_unsupported_algorithm` guards the new ValueError for svdquant

Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 2d87322 commit baaf80f

File tree

2 files changed

+153
-1
lines changed

2 files changed

+153
-1
lines changed

modelopt/torch/quantization/mode.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def wrapped_calib_func(
213213
config: QuantizeAlgorithmConfig,
214214
forward_loop: ForwardLoop | None = None,
215215
func: Callable | None = None,
216+
supports_layerwise: bool = True,
216217
) -> ConvertReturnType:
217218
"""Wrap the calibration function to be compatible with the ModelOpt convert entrypoint.
218219
@@ -241,6 +242,13 @@ def wrapped_calib_func(
241242
if layerwise:
242243
# All currently implemented PTQ algorithms support layerwise calibration;
243244
# future algorithms that need full-model context must add a guard here.
245+
if not supports_layerwise:
246+
raise ValueError(
247+
f"Calibration algorithm '{method}' does not support layerwise=True. "
248+
"Set layerwise=False, or override `_supports_layerwise = True` on the "
249+
"corresponding CalibrateModeDescriptor once the algorithm is made "
250+
"compatible with per-layer calibration."
251+
)
244252
if forward_loop is None:
245253
raise ValueError("forward_loop is required for calibration but got None.")
246254
# Wrap with layerwise processing
@@ -282,6 +290,10 @@ class BaseCalibrateModeDescriptor(ModeDescriptor):
282290

283291
_calib_func: Callable | None
284292

293+
# Override to False when the algorithm requires full-model context and
294+
# cannot run per decoder layer (e.g. needs ModeloptStateManager on the root).
295+
_supports_layerwise: bool = True
296+
285297
def __init__(self, *args, **kwargs):
286298
"""Initialize Base calibrate mode descriptor."""
287299
assert issubclass(self.config_class, QuantizeAlgorithmConfig), (
@@ -327,7 +339,13 @@ def convert(self) -> ConvertEntrypoint:
327339
def wrapped_func(model, config, forward_loop=None):
328340
# Access _calib_func as a class attribute to avoid binding
329341
# Check if _calib_func is defined as a class attribute
330-
return wrapped_calib_func(model, config, forward_loop, func=self.__class__._calib_func)
342+
return wrapped_calib_func(
343+
model,
344+
config,
345+
forward_loop,
346+
func=self.__class__._calib_func,
347+
supports_layerwise=self.__class__._supports_layerwise,
348+
)
331349

332350
return wrapped_func
333351

@@ -486,6 +504,9 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
486504
return SVDQuantConfig
487505

488506
_calib_func = svdquant
507+
# create_and_replace_svdquant_linear_on_the_fly reads ModeloptStateManager from the
508+
# root model, which is not present when layerwise_calibrate dispatches per decoder layer.
509+
_supports_layerwise = False
489510

490511
@property
491512
def restore(self) -> RestoreEntrypoint:

tests/unit/torch/quantization/test_layerwise_calibrate.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515

1616
"""Unit tests for layerwise_calibrate and LayerActivationCollector."""
1717

18+
import copy
1819
from collections import deque
1920

2021
import pytest
2122
import torch
2223
import torch.nn as nn
2324

25+
import modelopt.torch.quantization as mtq
2426
from modelopt.torch.quantization.model_calib import layerwise_calibrate
27+
from modelopt.torch.quantization.nn import TensorQuantizer
2528
from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector, _SkipLayer
2629

2730

@@ -593,3 +596,131 @@ def forward_loop(m):
593596
for i, orig in enumerate(originals):
594597
assert model.layers[i] is orig, f"Layer {i} not restored to original after cleanup"
595598
assert not hasattr(orig, "_layerwise_calib"), f"Layer {i} still has _layerwise_calib"
599+
600+
601+
# ---------------------------------------------------------------------------
602+
# End-to-end mtq.quantize(..., algorithm={"layerwise": True}) per PTQ algorithm
603+
# ---------------------------------------------------------------------------
604+
605+
606+
def _int8_layerwise_config(algorithm: dict) -> dict:
607+
"""Start from the shipped INT8 config and enable layerwise in the algorithm block.
608+
609+
Using a real shipped config guarantees the same include/exclude rules
610+
production PTQ relies on, so algorithm dispatch matches real usage.
611+
"""
612+
cfg = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG)
613+
cfg["algorithm"] = algorithm
614+
return cfg
615+
616+
617+
def _awq_layerwise_config() -> dict:
618+
"""INT4 weight-only AWQ config sized for the _DecoderBlock test model."""
619+
cfg = copy.deepcopy(mtq.INT4_AWQ_CFG)
620+
# Resize AWQ block to fit dim=16 hidden.
621+
for entry in cfg["quant_cfg"]:
622+
if entry.get("quantizer_name") == "*weight_quantizer":
623+
entry.setdefault("cfg", {})["block_sizes"] = {-1: 8, "type": "static"}
624+
cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 0.5, "layerwise": True}
625+
return cfg
626+
627+
628+
def _svdquant_layerwise_config() -> dict:
629+
"""SVDQuant config sized for the _DecoderBlock test model."""
630+
cfg = copy.deepcopy(mtq.INT4_AWQ_CFG)
631+
for entry in cfg["quant_cfg"]:
632+
if entry.get("quantizer_name") == "*weight_quantizer":
633+
entry.setdefault("cfg", {})["block_sizes"] = {-1: 8, "type": "static"}
634+
cfg["algorithm"] = {"method": "svdquant", "lowrank": 4, "layerwise": True}
635+
return cfg
636+
637+
638+
def test_mtq_quantize_layerwise_e2e_max(monkeypatch):
639+
"""End-to-end: mtq.quantize with layerwise=True produces populated amax values.
640+
641+
``max`` is the representative algorithm for the layerwise happy path because
642+
every other algorithm seeds amax via max_calibrate first — if max works, the
643+
shared skip/run/capture machinery is sound. Other algorithms are covered by
644+
the dispatch-only test below to avoid hardware requirements (e.g. gptq needs
645+
CUDA) or unnecessary duplication.
646+
"""
647+
_register_test_discoverer(monkeypatch)
648+
config = _int8_layerwise_config({"method": "max", "layerwise": True})
649+
650+
torch.manual_seed(0)
651+
model = _SimpleTransformerModel(n_layers=3, dim=16)
652+
calib_data = [torch.randint(0, 32, (2, 8)) for _ in range(2)]
653+
654+
def forward_loop(m):
655+
for batch in calib_data:
656+
m(batch)
657+
658+
model = mtq.quantize(model, config, forward_loop=forward_loop)
659+
660+
for i, layer in enumerate(model.layers):
661+
assert not isinstance(layer, _SkipLayer), f"layer {i} left as _SkipLayer"
662+
assert not hasattr(layer, "_layerwise_calib"), f"layer {i} leaked _layerwise_calib"
663+
664+
amax_count = sum(
665+
1
666+
for layer in model.layers
667+
for module in layer.modules()
668+
if (
669+
isinstance(module, TensorQuantizer)
670+
and module.is_enabled
671+
and getattr(module, "_amax", None) is not None
672+
)
673+
)
674+
assert amax_count > 0, "no TensorQuantizer in decoder layers had _amax populated"
675+
676+
with torch.no_grad():
677+
model(calib_data[0])
678+
679+
680+
@pytest.mark.parametrize(
681+
"algorithm",
682+
["gptq", "awq_lite", "smoothquant", "mse"],
683+
)
684+
def test_mtq_quantize_layerwise_dispatches_for_algorithm(monkeypatch, algorithm):
685+
"""Every layerwise-supporting algorithm must route through layerwise_calibrate.
686+
687+
Stubs layerwise_calibrate to a spy so the dispatch contract is checked without
688+
running the algorithm's full calibration — lets ``gptq`` (CUDA-only at runtime)
689+
and other expensive algorithms participate in CPU unit tests.
690+
"""
691+
spy: dict = {}
692+
693+
def stub(model, forward_loop, calib_func, **kwargs):
694+
spy["calib_func"] = calib_func
695+
spy["kwargs"] = kwargs
696+
697+
monkeypatch.setattr("modelopt.torch.quantization.mode.layerwise_calibrate", stub)
698+
699+
if algorithm == "awq_lite":
700+
config = _awq_layerwise_config()
701+
else:
702+
config = _int8_layerwise_config({"method": algorithm, "layerwise": True})
703+
704+
torch.manual_seed(0)
705+
model = _SimpleTransformerModel(n_layers=2, dim=16)
706+
mtq.quantize(
707+
model,
708+
config,
709+
forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))),
710+
)
711+
712+
assert "calib_func" in spy, f"{algorithm} did not dispatch through layerwise_calibrate"
713+
assert callable(spy["calib_func"])
714+
715+
716+
def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm():
717+
"""Modes with ``_supports_layerwise = False`` must raise a clear ValueError."""
718+
config = _svdquant_layerwise_config()
719+
torch.manual_seed(0)
720+
model = _SimpleTransformerModel(n_layers=2, dim=16)
721+
with pytest.raises(ValueError, match="does not support layerwise=True"):
722+
mtq.quantize(
723+
model,
724+
config,
725+
forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))),
726+
)

0 commit comments

Comments (0)