Skip to content

Commit ec69be9

Browse files
committed
Add AutoQuant support for VLMs
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 902d369 commit ec69be9

5 files changed

Lines changed: 250 additions & 36 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,28 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
140140
mto.enable_huggingface_checkpointing()
141141

142142

143+
# TODO: To be refacored into config system.
144+
_QWEN36_AUTOQ_DISABLED_LAYERS = (
145+
"*shared_expert_gate*",
146+
"*linear_attn.in_proj_a*",
147+
"*linear_attn.in_proj_b*",
148+
)
149+
_VLM_AUTOQ_DISABLED_LAYERS = ("*visual*", "*mtp*", "*vision_tower*")
150+
151+
152+
def get_auto_quantize_disabled_layers(model) -> list[str]:
153+
"""Return layer patterns that should be excluded from AutoQuantize search."""
154+
disabled_layers = [
155+
entry["quantizer_name"]
156+
for entry in _default_disabled_quantizer_cfg
157+
if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
158+
]
159+
disabled_layers.extend(p for p in _QWEN36_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
160+
if is_multimodal_model(model):
161+
disabled_layers.extend(p for p in _VLM_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
162+
return disabled_layers
163+
164+
143165
def extract_and_prepare_language_model_from_vl(full_model):
144166
"""Extract language model from VL model and disable quantization for non-language components.
145167
@@ -323,6 +345,7 @@ def auto_quantize(
323345
"nvfp4_awq",
324346
"nvfp4_mse",
325347
"w4a8_awq",
348+
"w4a16_nvfp4",
326349
"fp8_pb_wo",
327350
"w4a8_mxfp4_fp8",
328351
"nvfp4_mlp_only",
@@ -405,12 +428,7 @@ def forward_step(model, batch):
405428
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
406429
),
407430
verbose=True,
408-
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
409-
disabled_layers=[
410-
entry["quantizer_name"]
411-
for entry in _default_disabled_quantizer_cfg
412-
if "parent_class" not in entry
413-
],
431+
disabled_layers=get_auto_quantize_disabled_layers(language_model),
414432
method=auto_quantize_method,
415433
checkpoint=auto_quantize_checkpoint,
416434
)
@@ -550,12 +568,10 @@ def load_model(args: argparse.Namespace):
550568
: len(args.dataset)
551569
]
552570

553-
# We only quantize the language model for VLMs other than the type supported above.
554-
# Recipe mode is the exception: in Qwen3.5/3.6-MoE VLMs, lm_head sits
555-
# on the outer CausalLM, not the inner language backbone. A recipe that targets
556-
# lm_head must therefore quantize against the full model and explicitly keep visual
557-
# and MTP siblings disabled.
558-
if args.recipe is None:
571+
# Plain PTQ quantizes only the extracted language model. Recipe and
572+
# AutoQuantize paths keep the outer CausalLM so recipes/search can see
573+
# Qwen3.5/3.6-MoE VLM lm_head.
574+
if args.recipe is None and args.auto_quantize_bits is None:
559575
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
560576
full_model
561577
)
@@ -1081,9 +1097,16 @@ def _is_layerwise(obj):
10811097
"Auto quantization needs multiple quantization format."
10821098
)
10831099

1100+
# For VL models, autoquant must walk submodules of the OUTER CausalLM
1101+
# (which carries lm_head and the LM-head forward path) — otherwise
1102+
# lm_head and any sibling-of-language_model modules are silently
1103+
# invisible to the search. ``forward_step`` also needs the outer model
1104+
# to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
1105+
# Visual tower and MTP siblings are auto-excluded inside
1106+
# ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
10841107
auto_quantize(
10851108
args,
1086-
language_model,
1109+
full_model,
10871110
calib_dataloader,
10881111
auto_quantize_method=args.auto_quantize_method,
10891112
auto_quantize_score_size=args.auto_quantize_score_size,

modelopt/torch/quantization/_auto_quantize_cost.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,26 @@ def is_routed_moe_module_name(name: str) -> bool:
9090
return "shared_expert" not in name and _ROUTED_MOE_EXPERT_NAME_RE.search(name) is not None
9191

9292

93+
def _get_module_weight_numel(module: nn.Module) -> int:
94+
"""Return the parameter count for a module's quantizable weights.
95+
96+
Standard quantized linear modules have a single ``weight`` parameter. Fused
97+
MoE expert containers expose projection tensors directly instead, so both
98+
fused projections contribute to AutoQuantize cost accounting.
99+
"""
100+
weight = getattr(module, "weight", None)
101+
if weight is not None:
102+
return weight.numel()
103+
104+
# Fused MoE expert containers expose projection tensors directly instead of
105+
# a single ``weight`` parameter.
106+
return sum(
107+
param.numel()
108+
for attr in ("gate_up_proj", "down_proj")
109+
if (param := getattr(module, attr, None)) is not None
110+
)
111+
112+
93113
class AutoQuantizeCostModel:
94114
"""Base class for AutoQuantize effective-bits cost accounting."""
95115

@@ -119,7 +139,7 @@ def total_weight_size(
119139
) -> float:
120140
"""Return the cost denominator for the effective-bits constraint."""
121141
return sum(
122-
module.weight.numel() * self.module_cost_weight([name], cost_constraints)
142+
_get_module_weight_numel(module) * self.module_cost_weight([name], cost_constraints)
123143
for name, module in named_modules
124144
if is_auto_quantize_module(module)
125145
)

modelopt/torch/quantization/algorithms.py

Lines changed: 123 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
AUTO_QUANTIZE_CONSTRAINT_KEYS,
4646
COST_MODEL_ACTIVE_MOE,
4747
COST_MODEL_WEIGHT,
48+
_get_module_weight_numel,
4849
get_auto_quantize_cost_model,
4950
normalize_auto_quantize_constraints,
5051
)
@@ -54,6 +55,59 @@
5455
from .utils import is_quantized_linear
5556

5657

58+
def _is_fused_experts_module(module: nn.Module) -> bool:
59+
"""Return True if ``module`` is a quantized fused-MoE-experts container.
60+
61+
These modules expose plural ``*_input_quantizer`` and ``*_weight_quantizers``
62+
(an ``nn.ModuleList`` of per-expert quantizers) instead of the singular
63+
``input_quantizer`` / ``weight_quantizer`` attrs found on standard
64+
``nn.Linear``-derived QuantModules. AutoQuantize hparam discovery and cost
65+
accounting need to recognize this layout to enumerate fused experts as
66+
search dimensions.
67+
"""
68+
# Late import to avoid a circular import at module load time.
69+
try:
70+
from .plugins.huggingface import _QuantFusedExperts
71+
except ImportError:
72+
return False
73+
return isinstance(module, _QuantFusedExperts)
74+
75+
76+
# Quantizer attribute names that participate in AutoQuantize snapshot/restore.
77+
_STD_QUANTIZER_ATTRS = ("input_quantizer", "weight_quantizer", "output_quantizer")
78+
_FUSED_EXPERTS_QUANTIZER_ATTRS = (
79+
"gate_up_proj_input_quantizer",
80+
"gate_up_proj_weight_quantizers",
81+
"down_proj_input_quantizer",
82+
"down_proj_weight_quantizers",
83+
)
84+
85+
86+
def _get_quantizer_attrs(module: nn.Module) -> tuple[str, ...]:
87+
"""Return the quantizer attribute names that AutoQuantize must snapshot/restore.
88+
89+
For fused MoE experts, this returns the four plural quantizer attrs (two
90+
shared input quantizers + two ``ModuleList`` of per-expert weight quantizers).
91+
For standard Linear-derived QuantModules, returns the canonical trio.
92+
"""
93+
if _is_fused_experts_module(module):
94+
return _FUSED_EXPERTS_QUANTIZER_ATTRS
95+
return _STD_QUANTIZER_ATTRS
96+
97+
98+
def _make_fresh_quantizer_for_attr(module: nn.Module, attr_name: str) -> nn.Module:
99+
"""Return a fresh, default quantizer object suitable to overwrite ``module.<attr_name>``.
100+
101+
For ModuleList attrs (per-expert quantizers on fused-experts modules), the
102+
returned ModuleList preserves the original list length so per-expert
103+
enumeration stays consistent across recipes.
104+
"""
105+
current = getattr(module, attr_name, None)
106+
if isinstance(current, nn.ModuleList):
107+
return nn.ModuleList(TensorQuantizer() for _ in range(len(current)))
108+
return TensorQuantizer()
109+
110+
57111
def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
58112
"""Estimate the compression ratio of a quantization configuration.
59113
@@ -231,26 +285,26 @@ def __init__(
231285
# This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
232286
# a dynamic attribute for backward compatibility with the model_calib.py
233287
# TODO: Make input_quantizer, weight_quantizer, output_quantizer a dynamic attribute and get rid of this hack
288+
# NOTE: For fused-experts modules, the relevant attrs are plural
289+
# (``*_input_quantizer`` + ``*_weight_quantizers`` ModuleList) — see
290+
# ``_get_quantizer_attrs``. Both layouts share the same snapshot dict
291+
# shape so ``active.setter`` swaps the right child modules.
234292
self._all_quantizer_choices = {quant_recipe: {} for quant_recipe in self.choices}
235293

236294
quant_recipe: QuantRecipe
237295
for quant_recipe in self.choices:
238296
for quant_module in self.quant_modules:
239-
for quantizer_attr_name in [
240-
"input_quantizer",
241-
"weight_quantizer",
242-
"output_quantizer",
243-
]:
244-
setattr(quant_module, quantizer_attr_name, TensorQuantizer())
297+
attr_names = _get_quantizer_attrs(quant_module)
298+
for attr_name in attr_names:
299+
setattr(
300+
quant_module,
301+
attr_name,
302+
_make_fresh_quantizer_for_attr(quant_module, attr_name),
303+
)
245304

246305
set_quantizer_by_cfg(quant_module, quant_recipe.config.quant_cfg)
247306
self._all_quantizer_choices[quant_recipe][quant_module] = {
248-
quantizer_attr_name: getattr(quant_module, quantizer_attr_name)
249-
for quantizer_attr_name in [
250-
"input_quantizer",
251-
"weight_quantizer",
252-
"output_quantizer",
253-
]
307+
attr_name: getattr(quant_module, attr_name) for attr_name in attr_names
254308
}
255309

256310
self.active = self.original
@@ -360,6 +414,20 @@ def attrs(self) -> list[str]:
360414
return ["name", "cost_weight", *super().attrs]
361415

362416

417+
_LINEAR_ATTN_QKVZ_RE = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_qkv|in_proj_z)$")
418+
_LINEAR_ATTN_BA_RE = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_a|in_proj_b)$")
419+
420+
421+
def _linear_attn_qkvz_group_key(_model, name: str) -> str | None:
422+
m = _LINEAR_ATTN_QKVZ_RE.match(name)
423+
return f"{m.group(1)}/qkvz" if m else None
424+
425+
426+
def _linear_attn_ba_group_key(_model, name: str) -> str | None:
427+
m = _LINEAR_ATTN_BA_RE.match(name)
428+
return f"{m.group(1)}/ba" if m else None
429+
430+
363431
class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
364432
"""Base searcher for AutoQuantize algorithm."""
365433

@@ -381,6 +449,13 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
381449
r"^(.*?)\.(gate_proj|up_proj)$", # gate_proj, up_proj for llama like models
382450
r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts
383451
r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts
452+
# Qwen3.5/3.6 hybrid linear_attn: vLLM fuses (in_proj_qkv, in_proj_z)
453+
# into ``in_proj_qkvz`` and (in_proj_a, in_proj_b) into ``in_proj_ba`` and
454+
# requires fused shards to share quant_algo. Two callables (not one
455+
# regex) so qkv+z and a+b produce DIFFERENT group keys; each pair
456+
# stays with its own fusion partner.
457+
_linear_attn_qkvz_group_key,
458+
_linear_attn_ba_group_key,
384459
]
385460

386461
score_module_rules = []
@@ -411,6 +486,7 @@ def default_state_dict(self) -> SearchStateDict:
411486
"cost": {},
412487
"active_moe_expert_ratio": None,
413488
"cost_denominator": None,
489+
"disabled_layers": None,
414490
"candidate_stats": defaultdict(dict),
415491
"quantizer_states": {},
416492
"best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
@@ -433,9 +509,15 @@ def load_search_checkpoint(self) -> bool:
433509

434510
@staticmethod
435511
def _is_auto_quantize_module(module):
436-
return (
437-
is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
438-
) and isinstance(module, QuantModule)
512+
if (is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)) and isinstance(
513+
module, QuantModule
514+
):
515+
return True
516+
# Fused MoE experts: a single ``QuantModule`` that owns N per-expert
517+
# weight quantizers in an ``nn.ModuleList`` plus shared input quantizers.
518+
# All N experts in a layer share one search dimension (one recipe per
519+
# fused module).
520+
return _is_fused_experts_module(module) and isinstance(module, QuantModule)
439521

440522
@staticmethod
441523
def _get_search_recipes(quantization_formats):
@@ -677,6 +759,7 @@ def before_search(self):
677759
self.cost_model = self.config["cost_model"]
678760
self.cost = self.config["cost"]
679761
self.active_moe_expert_ratio = self.config["active_moe_expert_ratio"]
762+
self.disabled_layers = self.config["disabled_layers"]
680763
self.cost_denominator = getattr(self, "cost_denominator", None)
681764

682765
search_recipes = self._get_search_recipes(self.config["quantization_formats"])
@@ -765,11 +848,9 @@ def _print_recipe_summary(best_recipe, total_cost, total_weight_size, prefix="Au
765848
@staticmethod
766849
def _get_total_weight_size(modules):
767850
return sum(
768-
(
769-
module.weight.numel()
770-
if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module)
771-
else 0
772-
)
851+
_get_module_weight_numel(module)
852+
if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module)
853+
else 0
773854
for module in modules
774855
)
775856

@@ -1372,6 +1453,16 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
13721453
AutoQuantizeSearcher = AutoQuantizeGradientSearcher
13731454

13741455

1456+
def _as_list(value) -> list:
1457+
if value is None:
1458+
return []
1459+
if isinstance(value, list):
1460+
return value
1461+
if isinstance(value, tuple):
1462+
return list(value)
1463+
return [value]
1464+
1465+
13751466
def get_auto_quantize_config(search_state, constraints=None, verbose=False):
13761467
"""Build a flat quant config dict from auto_quantize search_state.
13771468
@@ -1401,6 +1492,11 @@ def _cfg_to_dict(v):
14011492
return v
14021493

14031494
quant_cfg: list[dict] = [{"quantizer_name": "*", "enable": False}]
1495+
quant_cfg.extend(
1496+
{"quantizer_name": pattern, "enable": False}
1497+
for pattern in _as_list(search_state.get("disabled_layers"))
1498+
)
1499+
per_module_entries: list[dict] = []
14041500
_per_module_attrs = ("input_quantizer", "weight_quantizer", "output_quantizer")
14051501
# Track global (non per-module) recipe entries. Last recipe wins for each pattern.
14061502
global_entries: dict[str, dict] = {}
@@ -1421,7 +1517,7 @@ def _cfg_to_dict(v):
14211517
}
14221518
if matched_cfg is not None:
14231519
entry["cfg"] = _cfg_to_dict(matched_cfg)
1424-
quant_cfg.append(entry)
1520+
per_module_entries.append(entry)
14251521

14261522
# Collect non-per-module entries (e.g. *[kv]_bmm_quantizer) from winning recipes.
14271523
for recipe_entry in recipe.config.quant_cfg:
@@ -1438,7 +1534,10 @@ def _cfg_to_dict(v):
14381534
ge["cfg"] = _cfg_to_dict(cfg)
14391535
global_entries[pattern] = ge
14401536

1537+
# Keep path-scoped recipe entries before explicit module entries so selected
1538+
# modules override default disables such as ``*lm_head*``.
14411539
quant_cfg.extend(global_entries.values())
1540+
quant_cfg.extend(per_module_entries)
14421541
warnings.warn(
14431542
"get_auto_quantize_config: returned config uses algorithm='max'. "
14441543
"Per-recipe calibration algorithms (e.g. smoothquant, awq) are not preserved. "
@@ -1502,6 +1601,9 @@ def _match_quantizer_cfg(quant_cfg, quantizer_attr):
15021601
matched = None
15031602
matched_enable = None
15041603
for entry in quant_cfg:
1604+
parent_class = entry.get("parent_class") if hasattr(entry, "get") else entry.parent_class
1605+
if parent_class is not None:
1606+
continue
15051607
pattern = entry["quantizer_name"]
15061608
cfg = entry.get("cfg")
15071609
enable = entry.get("enable", True)

0 commit comments

Comments
 (0)