Skip to content

Commit 61e506a

Browse files
committed
Move AutoQuant review helpers
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 6093276 commit 61e506a

5 files changed

Lines changed: 80 additions & 74 deletions

File tree

examples/llm_ptq/example_utils.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
ProcessorMixin,
4343
)
4444

45+
from modelopt.torch.export.model_utils import is_multimodal_model
46+
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg
47+
4548
try:
4649
from huggingface_hub import snapshot_download
4750
except ImportError:
@@ -51,6 +54,58 @@
5154

5255
SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
5356

57+
# TODO: Refactor into the config system.
# Extra layer-name patterns appended to the AutoQuantize disabled-layer list
# when the model is detected as Qwen-family (see `_is_qwen_model`).
_QWEN36_AUTOQ_DISABLED_LAYERS = (
    "*shared_expert_gate*",
    "*linear_attn.in_proj_a*",
    "*linear_attn.in_proj_b*",
)
# Extra layer-name patterns used for multimodal models: appended to the
# AutoQuantize disabled-layer list and also returned as the cost-accounting
# exclusion patterns.
_VLM_AUTOQ_DISABLED_LAYERS = ("*visual*", "*mtp*", "*vision_tower*")
64+
65+
66+
def _is_qwen_model(model) -> bool:
    """Return True when model/config identifiers indicate a Qwen-family model.

    The check is a case-insensitive substring match for ``"qwen"`` against:

    * the class name of ``model``;
    * for ``model.config`` and its optional ``text_config`` /
      ``language_config`` sub-configs: the config class name, ``model_type``,
      and every entry of ``architectures``.

    Args:
        model: Object to inspect. A missing ``config`` attribute, missing
            sub-configs, or missing fields are tolerated and contribute nothing.

    Returns:
        True if any collected identifier contains ``"qwen"``, otherwise False.
    """
    candidates = [type(model).__name__]
    config = getattr(model, "config", None)
    # getattr on a None config simply yields None, which is skipped below.
    configs = [
        config,
        getattr(config, "text_config", None),
        getattr(config, "language_config", None),
    ]
    for cfg in configs:
        if cfg is None:
            continue
        candidates.append(type(cfg).__name__)
        model_type = getattr(cfg, "model_type", None)
        if model_type is not None:
            candidates.append(str(model_type))
        # ``architectures`` may be absent or None; normalize both to empty.
        architectures = getattr(cfg, "architectures", ()) or ()
        # A bare string would otherwise be iterated character by character.
        if isinstance(architectures, str):
            architectures = (architectures,)
        candidates.extend(str(architecture) for architecture in architectures)
    return any("qwen" in candidate.lower() for candidate in candidates)
87+
88+
89+
def _get_auto_quantize_disabled_layers(model) -> list[str]:
    """Return layer patterns that should be excluded from AutoQuantize search.

    The list starts with the ``quantizer_name`` of every entry in
    ``_default_disabled_quantizer_cfg`` that has no ``parent_class`` key,
    skipping the ``"*lm_head*"`` pattern. Model-specific patterns are then
    appended, without duplicates and in this order:

    * ``_QWEN36_AUTOQ_DISABLED_LAYERS`` when ``_is_qwen_model(model)`` is true;
    * ``_VLM_AUTOQ_DISABLED_LAYERS`` when ``is_multimodal_model(model)`` is true.

    Args:
        model: Model object passed to the Qwen and multimodal detectors.

    Returns:
        A new list of layer-name patterns.
    """
    disabled_layers = [
        entry["quantizer_name"]
        for entry in _default_disabled_quantizer_cfg
        if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
    ]

    extra_patterns: list[str] = []
    if _is_qwen_model(model):
        extra_patterns.extend(_QWEN36_AUTOQ_DISABLED_LAYERS)
    if is_multimodal_model(model):
        extra_patterns.extend(_VLM_AUTOQ_DISABLED_LAYERS)

    # Append one at a time so the membership test always sees the current
    # list, instead of filtering a list while it is being extended.
    for pattern in extra_patterns:
        if pattern not in disabled_layers:
            disabled_layers.append(pattern)
    return disabled_layers
101+
102+
103+
def _get_auto_quantize_cost_excluded_patterns(model) -> list[str]:
    """Return layer patterns excluded only from AutoQuantize cost accounting.

    Args:
        model: Model object passed to ``is_multimodal_model``.

    Returns:
        A fresh list copy of ``_VLM_AUTOQ_DISABLED_LAYERS`` for multimodal
        models, otherwise an empty list.
    """
    if is_multimodal_model(model):
        # Copy so callers can mutate the result without touching the constant.
        return list(_VLM_AUTOQ_DISABLED_LAYERS)
    return []
108+
54109

55110
def run_nemotron_vl_preview(
56111
full_model,
@@ -133,7 +188,6 @@ def is_nemotron_vl(model_or_config):
133188
# Try to get config from model, or use directly if it's a config
134189
if hasattr(model_or_config, "config"):
135190
config = model_or_config.config
136-
from modelopt.torch.export.model_utils import is_multimodal_model
137191

138192
if not is_multimodal_model(model_or_config):
139193
return False

examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4
2828
from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static
2929
from example_utils import (
30+
_get_auto_quantize_cost_excluded_patterns,
31+
_get_auto_quantize_disabled_layers,
3032
build_quant_cfg,
3133
copy_custom_model_files,
3234
create_vlm_calibration_loop,
@@ -73,7 +75,7 @@
7375
)
7476
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
7577
from modelopt.torch.quantization._auto_quantize_cost import EXCLUDED_MODULE_NAME_PATTERNS_KEY
76-
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
78+
from modelopt.torch.quantization.config import need_calibration
7779
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
7880
from modelopt.torch.quantization.utils import is_quantized
7981
from modelopt.torch.speculative.eagle.utils import (
@@ -159,59 +161,6 @@ def _canonical_qformat(name: str) -> str:
159161
mto.enable_huggingface_checkpointing()
160162

161163

162-
# TODO: Refactor into the config system.
163-
_QWEN36_AUTOQ_DISABLED_LAYERS = (
164-
"*shared_expert_gate*",
165-
"*linear_attn.in_proj_a*",
166-
"*linear_attn.in_proj_b*",
167-
)
168-
_VLM_AUTOQ_DISABLED_LAYERS = ("*visual*", "*mtp*", "*vision_tower*")
169-
170-
171-
def _is_qwen_model(model) -> bool:
172-
"""Return True when model/config identifiers indicate a Qwen-family model."""
173-
candidates = [type(model).__name__]
174-
config = getattr(model, "config", None)
175-
configs = [
176-
config,
177-
getattr(config, "text_config", None),
178-
getattr(config, "language_config", None),
179-
]
180-
for cfg in configs:
181-
if cfg is None:
182-
continue
183-
candidates.append(type(cfg).__name__)
184-
model_type = getattr(cfg, "model_type", None)
185-
if model_type is not None:
186-
candidates.append(str(model_type))
187-
architectures = getattr(cfg, "architectures", ()) or ()
188-
if isinstance(architectures, str):
189-
architectures = (architectures,)
190-
candidates.extend(str(architecture) for architecture in architectures)
191-
return any("qwen" in candidate.lower() for candidate in candidates)
192-
193-
194-
def _get_auto_quantize_disabled_layers(model) -> list[str]:
195-
"""Return layer patterns that should be excluded from AutoQuantize search."""
196-
disabled_layers = [
197-
entry["quantizer_name"]
198-
for entry in _default_disabled_quantizer_cfg
199-
if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
200-
]
201-
if _is_qwen_model(model):
202-
disabled_layers.extend(p for p in _QWEN36_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
203-
if is_multimodal_model(model):
204-
disabled_layers.extend(p for p in _VLM_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
205-
return disabled_layers
206-
207-
208-
def _get_auto_quantize_cost_excluded_patterns(model) -> list[str]:
209-
"""Return layer patterns excluded only from AutoQuantize cost accounting."""
210-
if is_multimodal_model(model):
211-
return list(_VLM_AUTOQ_DISABLED_LAYERS)
212-
return []
213-
214-
215164
def extract_and_prepare_language_model_from_vl(full_model):
216165
"""Extract language model from VL model and disable quantization for non-language components.
217166

modelopt/torch/quantization/algorithms.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,15 @@
5555
from .utils import is_quantized_linear
5656

5757

58-
def _is_fused_experts_module(module: nn.Module) -> bool:
59-
"""Return True if ``module`` is a quantized fused-MoE-experts container.
60-
61-
These modules expose plural ``*_input_quantizer`` and ``*_weight_quantizers``
62-
(an ``nn.ModuleList`` of per-expert quantizers) instead of the singular
63-
``input_quantizer`` / ``weight_quantizer`` attrs found on standard
64-
``nn.Linear``-derived QuantModules. AutoQuantize hparam discovery and cost
65-
accounting need to recognize this layout to enumerate fused experts as
66-
search dimensions.
67-
"""
68-
# Late import to avoid a circular import at module load time.
58+
def _is_hf_quant_fused_experts_module(module: nn.Module) -> bool:
59+
"""Return True for a converted HF fused-MoE-experts quantization wrapper."""
60+
# Late import avoids a circular import: the HF plugin registers AutoQuantize
61+
# support from this module at import time.
6962
try:
70-
from .plugins.huggingface import _QuantFusedExperts
63+
from .plugins.huggingface import _is_quant_fused_experts_module
7164
except ImportError:
7265
return False
73-
return isinstance(module, _QuantFusedExperts)
66+
return _is_quant_fused_experts_module(module)
7467

7568

7669
# Quantizer attribute names that participate in AutoQuantize snapshot/restore.
@@ -90,7 +83,7 @@ def _get_quantizer_attrs(module: nn.Module) -> tuple[str, ...]:
9083
shared input quantizers + two ``ModuleList`` of per-expert weight quantizers).
9184
For standard Linear-derived QuantModules, returns the canonical trio.
9285
"""
93-
if _is_fused_experts_module(module):
86+
if _is_hf_quant_fused_experts_module(module):
9487
return _FUSED_EXPERTS_QUANTIZER_ATTRS
9588
return _STD_QUANTIZER_ATTRS
9689

@@ -517,7 +510,7 @@ def _is_auto_quantize_module(module):
517510
# weight quantizers in an ``nn.ModuleList`` plus shared input quantizers.
518511
# All N experts in a layer share one search dimension (one recipe per
519512
# fused module).
520-
return _is_fused_experts_module(module) and isinstance(module, QuantModule)
513+
return _is_hf_quant_fused_experts_module(module) and isinstance(module, QuantModule)
521514

522515
@staticmethod
523516
def _get_search_recipes(quantization_formats):

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,11 @@ def fold_weight(self, keep_attrs: bool = False):
946946
delattr(q, attr_name)
947947

948948

949+
def _is_quant_fused_experts_module(module) -> bool:
    """Return True for a converted HF fused-MoE-experts quantization wrapper.

    Args:
        module: Object to test.

    Returns:
        True when ``module`` is an instance of ``_QuantFusedExperts``.
    """
    return isinstance(module, _QuantFusedExperts)
952+
953+
949954
class _QuantDbrxFFN(_QuantSparseSequentialMoe):
950955
@property
951956
def num_experts(self):

tests/examples/llm_ptq/test_hf_ptq_args.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ def _import_hf_ptq(monkeypatch):
2828
return importlib.import_module("hf_ptq")
2929

3030

31+
def _import_example_utils(monkeypatch):
    """Import and return the ``example_utils`` module from the examples directory.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture, used to prepend
            ``_EXAMPLES_DIR`` to ``sys.path`` for the duration of the test.

    Returns:
        The imported ``example_utils`` module.
    """
    monkeypatch.syspath_prepend(str(_EXAMPLES_DIR))
    return importlib.import_module("example_utils")
34+
35+
3136
def _parse_hf_ptq_args(monkeypatch, *args):
3237
hf_ptq = _import_hf_ptq(monkeypatch)
3338
monkeypatch.setattr(sys, "argv", ["hf_ptq.py", *args])
@@ -87,7 +92,7 @@ def test_load_model_keeps_nemotron_vl_text_calibration_for_autoquant(monkeypatch
8792

8893

8994
def test_qwen_autoquant_disabled_layers_are_scoped_to_qwen_models(monkeypatch):
90-
hf_ptq = _import_hf_ptq(monkeypatch)
95+
example_utils = _import_example_utils(monkeypatch)
9196
qwen_model = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
9297
llama_model = SimpleNamespace(config=SimpleNamespace(model_type="llama"))
9398
qwen_only_patterns = {
@@ -96,10 +101,10 @@ def test_qwen_autoquant_disabled_layers_are_scoped_to_qwen_models(monkeypatch):
96101
"*linear_attn.in_proj_b*",
97102
}
98103

99-
monkeypatch.setattr(hf_ptq, "is_multimodal_model", lambda model: False)
104+
monkeypatch.setattr(example_utils, "is_multimodal_model", lambda model: False)
100105

101-
qwen_disabled_layers = set(hf_ptq._get_auto_quantize_disabled_layers(qwen_model))
102-
llama_disabled_layers = set(hf_ptq._get_auto_quantize_disabled_layers(llama_model))
106+
qwen_disabled_layers = set(example_utils._get_auto_quantize_disabled_layers(qwen_model))
107+
llama_disabled_layers = set(example_utils._get_auto_quantize_disabled_layers(llama_model))
103108

104109
assert qwen_only_patterns <= qwen_disabled_layers
105110
assert qwen_only_patterns.isdisjoint(llama_disabled_layers)

0 commit comments

Comments
 (0)