Skip to content

Commit 7f2dca2

Browse files
committed
Add AutoQuant support for VLMs
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 902d369 commit 7f2dca2

6 files changed

Lines changed: 371 additions & 44 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
save_expert_token_count_table,
6767
)
6868
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
69+
from modelopt.torch.quantization._auto_quantize_cost import EXCLUDED_MODULE_NAME_PATTERNS_KEY
6970
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
7071
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
7172
from modelopt.torch.quantization.utils import is_quantized
@@ -140,6 +141,36 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
140141
mto.enable_huggingface_checkpointing()
141142

142143

144+
# TODO: To be refacored into config system.
145+
_QWEN36_AUTOQ_DISABLED_LAYERS = (
146+
"*shared_expert_gate*",
147+
"*linear_attn.in_proj_a*",
148+
"*linear_attn.in_proj_b*",
149+
)
150+
_VLM_AUTOQ_DISABLED_LAYERS = ("*visual*", "*mtp*", "*vision_tower*")
151+
152+
153+
def get_auto_quantize_disabled_layers(model) -> list[str]:
154+
"""Return layer patterns that should be excluded from AutoQuantize search."""
155+
disabled_layers = [
156+
entry["quantizer_name"]
157+
for entry in _default_disabled_quantizer_cfg
158+
if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
159+
]
160+
disabled_layers.extend(p for p in _QWEN36_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
161+
if is_multimodal_model(model):
162+
disabled_layers.extend(p for p in _VLM_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
163+
return disabled_layers
164+
165+
166+
def get_auto_quantize_cost_excluded_patterns(args, model) -> list[str]:
167+
"""Return layer patterns excluded only from AutoQuantize cost accounting."""
168+
excluded_patterns = list(args.auto_quantize_cost_exclude_patterns or [])
169+
if args.auto_quantize_cost_exclude_vlm_modules and is_multimodal_model(model):
170+
excluded_patterns.extend(_VLM_AUTOQ_DISABLED_LAYERS)
171+
return list(dict.fromkeys(excluded_patterns))
172+
173+
143174
def extract_and_prepare_language_model_from_vl(full_model):
144175
"""Extract language model from VL model and disable quantization for non-language components.
145176
@@ -323,6 +354,7 @@ def auto_quantize(
323354
"nvfp4_awq",
324355
"nvfp4_mse",
325356
"w4a8_awq",
357+
"w4a16_nvfp4",
326358
"fp8_pb_wo",
327359
"w4a8_mxfp4_fp8",
328360
"nvfp4_mlp_only",
@@ -386,10 +418,14 @@ def forward_step(model, batch):
386418
"effective_bits": args.auto_quantize_bits,
387419
"cost_model": args.auto_quantize_cost_model,
388420
}
421+
auto_quantize_cost = {}
389422
if args.auto_quantize_active_moe_expert_ratio is not None:
390-
auto_quantize_constraints["cost"] = {
391-
"active_moe_expert_ratio": args.auto_quantize_active_moe_expert_ratio
392-
}
423+
auto_quantize_cost["active_moe_expert_ratio"] = args.auto_quantize_active_moe_expert_ratio
424+
cost_excluded_patterns = get_auto_quantize_cost_excluded_patterns(args, language_model)
425+
if cost_excluded_patterns:
426+
auto_quantize_cost[EXCLUDED_MODULE_NAME_PATTERNS_KEY] = cost_excluded_patterns
427+
if auto_quantize_cost:
428+
auto_quantize_constraints["cost"] = auto_quantize_cost
393429

394430
language_model, _ = mtq.auto_quantize(
395431
language_model,
@@ -405,12 +441,7 @@ def forward_step(model, batch):
405441
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
406442
),
407443
verbose=True,
408-
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
409-
disabled_layers=[
410-
entry["quantizer_name"]
411-
for entry in _default_disabled_quantizer_cfg
412-
if "parent_class" not in entry
413-
],
444+
disabled_layers=get_auto_quantize_disabled_layers(language_model),
414445
method=auto_quantize_method,
415446
checkpoint=auto_quantize_checkpoint,
416447
)
@@ -550,12 +581,10 @@ def load_model(args: argparse.Namespace):
550581
: len(args.dataset)
551582
]
552583

553-
# We only quantize the language model for VLMs other than the type supported above.
554-
# Recipe mode is the exception: in Qwen3.5/3.6-MoE VLMs, lm_head sits
555-
# on the outer CausalLM, not the inner language backbone. A recipe that targets
556-
# lm_head must therefore quantize against the full model and explicitly keep visual
557-
# and MTP siblings disabled.
558-
if args.recipe is None:
584+
# Plain PTQ quantizes only the extracted language model. Recipe and
585+
# AutoQuantize paths keep the outer CausalLM so recipes/search can see
586+
# Qwen3.5/3.6-MoE VLM lm_head.
587+
if args.recipe is None and args.auto_quantize_bits is None:
559588
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
560589
full_model
561590
)
@@ -1081,9 +1110,16 @@ def _is_layerwise(obj):
10811110
"Auto quantization needs multiple quantization format."
10821111
)
10831112

1113+
# For VL models, autoquant must walk submodules of the OUTER CausalLM
1114+
# (which carries lm_head and the LM-head forward path) — otherwise
1115+
# lm_head and any sibling-of-language_model modules are silently
1116+
# invisible to the search. ``forward_step`` also needs the outer model
1117+
# to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
1118+
# Visual tower and MTP siblings are auto-excluded inside
1119+
# ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
10841120
auto_quantize(
10851121
args,
1086-
language_model,
1122+
full_model,
10871123
calib_dataloader,
10881124
auto_quantize_method=args.auto_quantize_method,
10891125
auto_quantize_score_size=args.auto_quantize_score_size,
@@ -1423,6 +1459,24 @@ def parse_args() -> argparse.Namespace:
14231459
"routing; use --moe_calib_experts_ratio to control calibration expert coverage."
14241460
),
14251461
)
1462+
parser.add_argument(
1463+
"--auto_quantize_cost_exclude_patterns",
1464+
nargs="+",
1465+
default=None,
1466+
help=(
1467+
"Wildcard module-name patterns to exclude from AutoQuantize effective-bits cost "
1468+
"accounting. The matched modules can still be disabled from quantization separately; "
1469+
"this flag only changes the budget denominator and selected-cost calculation."
1470+
),
1471+
)
1472+
parser.add_argument(
1473+
"--auto_quantize_cost_exclude_vlm_modules",
1474+
action="store_true",
1475+
help=(
1476+
"Exclude VLM sibling modules matching *visual*, *vision_tower*, and *mtp* from "
1477+
"AutoQuantize effective-bits cost accounting."
1478+
),
1479+
)
14261480
parser.add_argument(
14271481
"--moe_calib_experts_ratio",
14281482
type=float,

modelopt/torch/quantization/_auto_quantize_cost.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Cost models for AutoQuantize effective-bits accounting."""
1717

18+
import fnmatch
1819
from collections.abc import Callable, Iterable, Sequence
1920
from typing import Any, Final
2021

@@ -27,6 +28,7 @@
2728

2829
AUTO_QUANTIZE_CONSTRAINT_KEYS: Final = frozenset({"effective_bits", "cost_model", "cost"})
2930
ACTIVE_MOE_EXPERT_RATIO_KEY: Final = "active_moe_expert_ratio"
31+
EXCLUDED_MODULE_NAME_PATTERNS_KEY: Final = "excluded_module_name_patterns"
3032
COST_MODEL_WEIGHT: Final = "weight"
3133
COST_MODEL_ACTIVE_MOE: Final = "active_moe"
3234

@@ -90,11 +92,31 @@ def is_routed_moe_module_name(name: str) -> bool:
9092
return "shared_expert" not in name and _ROUTED_MOE_EXPERT_NAME_RE.search(name) is not None
9193

9294

95+
def _get_module_weight_numel(module: nn.Module) -> int:
96+
"""Return the parameter count for a module's quantizable weights.
97+
98+
Standard quantized linear modules have a single ``weight`` parameter. Fused
99+
MoE expert containers expose projection tensors directly instead, so both
100+
fused projections contribute to AutoQuantize cost accounting.
101+
"""
102+
weight = getattr(module, "weight", None)
103+
if weight is not None:
104+
return weight.numel()
105+
106+
# Fused MoE expert containers expose projection tensors directly instead of
107+
# a single ``weight`` parameter.
108+
return sum(
109+
param.numel()
110+
for attr in ("gate_up_proj", "down_proj")
111+
if (param := getattr(module, attr, None)) is not None
112+
)
113+
114+
93115
class AutoQuantizeCostModel:
94116
"""Base class for AutoQuantize effective-bits cost accounting."""
95117

96118
name: str
97-
supported_cost_keys: frozenset[str] = frozenset()
119+
supported_cost_keys: frozenset[str] = frozenset({EXCLUDED_MODULE_NAME_PATTERNS_KEY})
98120

99121
def normalize_cost_constraints(
100122
self, model: nn.Module, cost_constraints: dict[str, Any]
@@ -103,12 +125,35 @@ def normalize_cost_constraints(
103125
unknown_cost_keys = set(cost_constraints) - self.supported_cost_keys
104126
if unknown_cost_keys:
105127
raise ValueError(f"Unsupported auto_quantize cost constraints: {unknown_cost_keys}.")
128+
excluded_patterns = cost_constraints.get(EXCLUDED_MODULE_NAME_PATTERNS_KEY)
129+
if excluded_patterns is None:
130+
return cost_constraints
131+
if isinstance(excluded_patterns, str):
132+
excluded_patterns = [excluded_patterns]
133+
if not isinstance(excluded_patterns, Sequence) or not all(
134+
isinstance(pattern, str) for pattern in excluded_patterns
135+
):
136+
raise ValueError(
137+
f"constraints['cost']['{EXCLUDED_MODULE_NAME_PATTERNS_KEY}'] must be a string "
138+
"or a sequence of strings."
139+
)
140+
cost_constraints[EXCLUDED_MODULE_NAME_PATTERNS_KEY] = list(excluded_patterns)
106141
return cost_constraints
107142

108143
def module_cost_weight(
109144
self, module_names: Sequence[str], cost_constraints: dict[str, Any]
110145
) -> float:
111146
"""Return the cost multiplier for a group of modules."""
147+
excluded_patterns = cost_constraints.get(EXCLUDED_MODULE_NAME_PATTERNS_KEY, [])
148+
if (
149+
module_names
150+
and excluded_patterns
151+
and all(
152+
any(fnmatch.fnmatch(name, pattern) for pattern in excluded_patterns)
153+
for name in module_names
154+
)
155+
):
156+
return 0.0
112157
return 1.0
113158

114159
def total_weight_size(
@@ -119,7 +164,7 @@ def total_weight_size(
119164
) -> float:
120165
"""Return the cost denominator for the effective-bits constraint."""
121166
return sum(
122-
module.weight.numel() * self.module_cost_weight([name], cost_constraints)
167+
_get_module_weight_numel(module) * self.module_cost_weight([name], cost_constraints)
123168
for name, module in named_modules
124169
if is_auto_quantize_module(module)
125170
)
@@ -135,7 +180,9 @@ class ActiveMoECostModel(AutoQuantizeCostModel):
135180
"""Scale routed MoE expert weights by the active experts per-token ratio."""
136181

137182
name = COST_MODEL_ACTIVE_MOE
138-
supported_cost_keys = frozenset({ACTIVE_MOE_EXPERT_RATIO_KEY})
183+
supported_cost_keys = frozenset(
184+
{ACTIVE_MOE_EXPERT_RATIO_KEY, EXCLUDED_MODULE_NAME_PATTERNS_KEY}
185+
)
139186

140187
def normalize_cost_constraints(
141188
self, model: nn.Module, cost_constraints: dict[str, Any]
@@ -164,9 +211,12 @@ def normalize_cost_constraints(
164211
def module_cost_weight(
165212
self, module_names: Sequence[str], cost_constraints: dict[str, Any]
166213
) -> float:
214+
base_weight = super().module_cost_weight(module_names, cost_constraints)
215+
if base_weight == 0.0:
216+
return 0.0
167217
if any(is_routed_moe_module_name(n) for n in module_names):
168218
return cost_constraints[ACTIVE_MOE_EXPERT_RATIO_KEY]
169-
return 1.0
219+
return base_weight
170220

171221

172222
_COST_MODELS: Final = {

0 commit comments

Comments
 (0)