Skip to content

Commit 2828faa

Browse files
committed
Add AutoQuant support for VLMs
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 902d369 commit 2828faa

4 files changed

Lines changed: 241 additions & 36 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,28 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
140140
mto.enable_huggingface_checkpointing()
141141

142142

143+
# TODO: To be refacored into config system.
144+
_QWEN36_AUTOQ_DISABLED_LAYERS = (
145+
"*shared_expert_gate*",
146+
"*linear_attn.in_proj_a*",
147+
"*linear_attn.in_proj_b*",
148+
)
149+
_VLM_AUTOQ_DISABLED_LAYERS = ("*visual*", "*mtp*", "*vision_tower*")
150+
151+
152+
def get_auto_quantize_disabled_layers(model) -> list[str]:
153+
"""Return layer patterns that should be excluded from AutoQuantize search."""
154+
disabled_layers = [
155+
entry["quantizer_name"]
156+
for entry in _default_disabled_quantizer_cfg
157+
if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
158+
]
159+
disabled_layers.extend(p for p in _QWEN36_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
160+
if is_multimodal_model(model):
161+
disabled_layers.extend(p for p in _VLM_AUTOQ_DISABLED_LAYERS if p not in disabled_layers)
162+
return disabled_layers
163+
164+
143165
def extract_and_prepare_language_model_from_vl(full_model):
144166
"""Extract language model from VL model and disable quantization for non-language components.
145167
@@ -323,6 +345,7 @@ def auto_quantize(
323345
"nvfp4_awq",
324346
"nvfp4_mse",
325347
"w4a8_awq",
348+
"w4a16_nvfp4",
326349
"fp8_pb_wo",
327350
"w4a8_mxfp4_fp8",
328351
"nvfp4_mlp_only",
@@ -405,12 +428,7 @@ def forward_step(model, batch):
405428
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
406429
),
407430
verbose=True,
408-
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
409-
disabled_layers=[
410-
entry["quantizer_name"]
411-
for entry in _default_disabled_quantizer_cfg
412-
if "parent_class" not in entry
413-
],
431+
disabled_layers=get_auto_quantize_disabled_layers(language_model),
414432
method=auto_quantize_method,
415433
checkpoint=auto_quantize_checkpoint,
416434
)
@@ -550,12 +568,9 @@ def load_model(args: argparse.Namespace):
550568
: len(args.dataset)
551569
]
552570

553-
# We only quantize the language model for VLMs other than the type supported above.
554-
# Recipe mode is the exception: in Qwen3.5/3.6-MoE VLMs, lm_head sits
555-
# on the outer CausalLM, not the inner language backbone. A recipe that targets
556-
# lm_head must therefore quantize against the full model and explicitly keep visual
557-
# and MTP siblings disabled.
558-
if args.recipe is None:
571+
# AutoQuantize walks the outer CausalLM so lm_head is visible to the
572+
# search. Visual/MTP siblings are excluded by disabled-layer patterns.
573+
if args.auto_quantize_bits is None:
559574
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
560575
full_model
561576
)
@@ -1081,9 +1096,16 @@ def _is_layerwise(obj):
10811096
"Auto quantization needs multiple quantization format."
10821097
)
10831098

1099+
# For VL models, autoquant must walk submodules of the OUTER CausalLM
1100+
# (which carries lm_head and the LM-head forward path) — otherwise
1101+
# lm_head and any sibling-of-language_model modules are silently
1102+
# invisible to the search. ``forward_step`` also needs the outer model
1103+
# to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
1104+
# Visual tower and MTP siblings are auto-excluded inside
1105+
# ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
10841106
auto_quantize(
10851107
args,
1086-
language_model,
1108+
full_model,
10871109
calib_dataloader,
10881110
auto_quantize_method=args.auto_quantize_method,
10891111
auto_quantize_score_size=args.auto_quantize_score_size,

modelopt/torch/quantization/_auto_quantize_cost.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ def is_routed_moe_module_name(name: str) -> bool:
9090
return "shared_expert" not in name and _ROUTED_MOE_EXPERT_NAME_RE.search(name) is not None
9191

9292

93+
def _get_module_weight_numel(module: nn.Module) -> int:
94+
"""Return the parameter count for a module's quantizable weights."""
95+
weight = getattr(module, "weight", None)
96+
if weight is not None:
97+
return weight.numel()
98+
99+
# Fused MoE expert containers expose projection tensors directly instead of
100+
# a single ``weight`` parameter.
101+
return sum(
102+
param.numel()
103+
for attr in ("gate_up_proj", "down_proj")
104+
if (param := getattr(module, attr, None)) is not None
105+
)
106+
107+
93108
class AutoQuantizeCostModel:
94109
"""Base class for AutoQuantize effective-bits cost accounting."""
95110

@@ -119,7 +134,7 @@ def total_weight_size(
119134
) -> float:
120135
"""Return the cost denominator for the effective-bits constraint."""
121136
return sum(
122-
module.weight.numel() * self.module_cost_weight([name], cost_constraints)
137+
_get_module_weight_numel(module) * self.module_cost_weight([name], cost_constraints)
123138
for name, module in named_modules
124139
if is_auto_quantize_module(module)
125140
)

modelopt/torch/quantization/algorithms.py

Lines changed: 140 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,77 @@
5454
from .utils import is_quantized_linear
5555

5656

57+
def _is_fused_experts_module(module: nn.Module) -> bool:
58+
"""Return True if ``module`` is a quantized fused-MoE-experts container.
59+
60+
These modules expose plural ``*_input_quantizer`` and ``*_weight_quantizers``
61+
(an ``nn.ModuleList`` of per-expert quantizers) instead of the singular
62+
``input_quantizer`` / ``weight_quantizer`` attrs found on standard
63+
``nn.Linear``-derived QuantModules. AutoQuantize hparam discovery and cost
64+
accounting need to recognize this layout to enumerate fused experts as
65+
search dimensions.
66+
"""
67+
# Late import to avoid a circular import at module load time.
68+
try:
69+
from .plugins.huggingface import _QuantFusedExperts
70+
except ImportError:
71+
return False
72+
return isinstance(module, _QuantFusedExperts)
73+
74+
75+
# Quantizer attribute names that participate in AutoQuantize snapshot/restore.
76+
_STD_QUANTIZER_ATTRS = ("input_quantizer", "weight_quantizer", "output_quantizer")
77+
_FUSED_EXPERTS_QUANTIZER_ATTRS = (
78+
"gate_up_proj_input_quantizer",
79+
"gate_up_proj_weight_quantizers",
80+
"down_proj_input_quantizer",
81+
"down_proj_weight_quantizers",
82+
)
83+
84+
85+
def _get_quantizer_attrs(module: nn.Module) -> tuple[str, ...]:
86+
"""Return the quantizer attribute names that AutoQuantize must snapshot/restore.
87+
88+
For fused MoE experts, this returns the four plural quantizer attrs (two
89+
shared input quantizers + two ``ModuleList`` of per-expert weight quantizers).
90+
For standard Linear-derived QuantModules, returns the canonical trio.
91+
"""
92+
if _is_fused_experts_module(module):
93+
return _FUSED_EXPERTS_QUANTIZER_ATTRS
94+
return _STD_QUANTIZER_ATTRS
95+
96+
97+
def _make_fresh_quantizer_for_attr(module: nn.Module, attr_name: str) -> nn.Module:
98+
"""Return a fresh, default quantizer object suitable to overwrite ``module.<attr_name>``.
99+
100+
For ModuleList attrs (per-expert quantizers on fused-experts modules), the
101+
returned ModuleList preserves the original list length so per-expert
102+
enumeration stays consistent across recipes.
103+
"""
104+
current = getattr(module, attr_name, None)
105+
if isinstance(current, nn.ModuleList):
106+
return nn.ModuleList(TensorQuantizer() for _ in range(len(current)))
107+
return TensorQuantizer()
108+
109+
110+
def _get_module_weight_numel(module: nn.Module) -> int:
111+
"""Return the total parameter count of a module's quantizable weights.
112+
113+
Standard QuantLinear modules have a single ``weight`` parameter. Fused
114+
experts modules have two 3-D fused parameters (``gate_up_proj`` and
115+
``down_proj``) instead — both contribute to the cost accounting.
116+
"""
117+
if _is_fused_experts_module(module):
118+
total = 0
119+
for attr in ("gate_up_proj", "down_proj"):
120+
param = getattr(module, attr, None)
121+
if param is not None:
122+
total += param.numel()
123+
return total
124+
weight = getattr(module, "weight", None)
125+
return weight.numel() if weight is not None else 0
126+
127+
57128
def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
58129
"""Estimate the compression ratio of a quantization configuration.
59130
@@ -231,26 +302,26 @@ def __init__(
231302
# This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
232303
# a dynamic attribute for backward compatibility with the model_calib.py
233304
# TODO: Make input_quantizer, weight_quantizer, output_quantizer a dynamic attribute and get rid of this hack
305+
# NOTE: For fused-experts modules, the relevant attrs are plural
306+
# (``*_input_quantizer`` + ``*_weight_quantizers`` ModuleList) — see
307+
# ``_get_quantizer_attrs``. Both layouts share the same snapshot dict
308+
# shape so ``active.setter`` swaps the right child modules.
234309
self._all_quantizer_choices = {quant_recipe: {} for quant_recipe in self.choices}
235310

236311
quant_recipe: QuantRecipe
237312
for quant_recipe in self.choices:
238313
for quant_module in self.quant_modules:
239-
for quantizer_attr_name in [
240-
"input_quantizer",
241-
"weight_quantizer",
242-
"output_quantizer",
243-
]:
244-
setattr(quant_module, quantizer_attr_name, TensorQuantizer())
314+
attr_names = _get_quantizer_attrs(quant_module)
315+
for attr_name in attr_names:
316+
setattr(
317+
quant_module,
318+
attr_name,
319+
_make_fresh_quantizer_for_attr(quant_module, attr_name),
320+
)
245321

246322
set_quantizer_by_cfg(quant_module, quant_recipe.config.quant_cfg)
247323
self._all_quantizer_choices[quant_recipe][quant_module] = {
248-
quantizer_attr_name: getattr(quant_module, quantizer_attr_name)
249-
for quantizer_attr_name in [
250-
"input_quantizer",
251-
"weight_quantizer",
252-
"output_quantizer",
253-
]
324+
attr_name: getattr(quant_module, attr_name) for attr_name in attr_names
254325
}
255326

256327
self.active = self.original
@@ -360,6 +431,20 @@ def attrs(self) -> list[str]:
360431
return ["name", "cost_weight", *super().attrs]
361432

362433

434+
_LINEAR_ATTN_QKVZ_RE = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_qkv|in_proj_z)$")
435+
_LINEAR_ATTN_BA_RE = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_a|in_proj_b)$")
436+
437+
438+
def _linear_attn_qkvz_group_key(_model, name: str) -> str | None:
439+
m = _LINEAR_ATTN_QKVZ_RE.match(name)
440+
return f"{m.group(1)}/qkvz" if m else None
441+
442+
443+
def _linear_attn_ba_group_key(_model, name: str) -> str | None:
444+
m = _LINEAR_ATTN_BA_RE.match(name)
445+
return f"{m.group(1)}/ba" if m else None
446+
447+
363448
class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
364449
"""Base searcher for AutoQuantize algorithm."""
365450

@@ -381,6 +466,13 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
381466
r"^(.*?)\.(gate_proj|up_proj)$", # gate_proj, up_proj for llama like models
382467
r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts
383468
r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts
469+
# Qwen3.5/3.6 hybrid linear_attn: vLLM fuses (in_proj_qkv, in_proj_z)
470+
# into ``in_proj_qkvz`` and (in_proj_a, in_proj_b) into ``in_proj_ba`` and
471+
# requires fused shards to share quant_algo. Two callables (not one
472+
# regex) so qkv+z and a+b produce DIFFERENT group keys; each pair
473+
# stays with its own fusion partner.
474+
_linear_attn_qkvz_group_key,
475+
_linear_attn_ba_group_key,
384476
]
385477

386478
score_module_rules = []
@@ -411,6 +503,7 @@ def default_state_dict(self) -> SearchStateDict:
411503
"cost": {},
412504
"active_moe_expert_ratio": None,
413505
"cost_denominator": None,
506+
"disabled_layers": None,
414507
"candidate_stats": defaultdict(dict),
415508
"quantizer_states": {},
416509
"best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
@@ -433,9 +526,15 @@ def load_search_checkpoint(self) -> bool:
433526

434527
@staticmethod
435528
def _is_auto_quantize_module(module):
436-
return (
437-
is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
438-
) and isinstance(module, QuantModule)
529+
if (is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)) and isinstance(
530+
module, QuantModule
531+
):
532+
return True
533+
# Fused MoE experts: a single ``QuantModule`` that owns N per-expert
534+
# weight quantizers in an ``nn.ModuleList`` plus shared input quantizers.
535+
# All N experts in a layer share one search dimension (one recipe per
536+
# fused module).
537+
return _is_fused_experts_module(module) and isinstance(module, QuantModule)
439538

440539
@staticmethod
441540
def _get_search_recipes(quantization_formats):
@@ -677,6 +776,7 @@ def before_search(self):
677776
self.cost_model = self.config["cost_model"]
678777
self.cost = self.config["cost"]
679778
self.active_moe_expert_ratio = self.config["active_moe_expert_ratio"]
779+
self.disabled_layers = self.config["disabled_layers"]
680780
self.cost_denominator = getattr(self, "cost_denominator", None)
681781

682782
search_recipes = self._get_search_recipes(self.config["quantization_formats"])
@@ -765,11 +865,9 @@ def _print_recipe_summary(best_recipe, total_cost, total_weight_size, prefix="Au
765865
@staticmethod
766866
def _get_total_weight_size(modules):
767867
return sum(
768-
(
769-
module.weight.numel()
770-
if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module)
771-
else 0
772-
)
868+
_get_module_weight_numel(module)
869+
if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module)
870+
else 0
773871
for module in modules
774872
)
775873

@@ -1372,6 +1470,16 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
13721470
AutoQuantizeSearcher = AutoQuantizeGradientSearcher
13731471

13741472

1473+
def _as_list(value) -> list:
1474+
if value is None:
1475+
return []
1476+
if isinstance(value, list):
1477+
return value
1478+
if isinstance(value, tuple):
1479+
return list(value)
1480+
return [value]
1481+
1482+
13751483
def get_auto_quantize_config(search_state, constraints=None, verbose=False):
13761484
"""Build a flat quant config dict from auto_quantize search_state.
13771485
@@ -1401,6 +1509,11 @@ def _cfg_to_dict(v):
14011509
return v
14021510

14031511
quant_cfg: list[dict] = [{"quantizer_name": "*", "enable": False}]
1512+
quant_cfg.extend(
1513+
{"quantizer_name": pattern, "enable": False}
1514+
for pattern in _as_list(search_state.get("disabled_layers"))
1515+
)
1516+
per_module_entries: list[dict] = []
14041517
_per_module_attrs = ("input_quantizer", "weight_quantizer", "output_quantizer")
14051518
# Track global (non per-module) recipe entries. Last recipe wins for each pattern.
14061519
global_entries: dict[str, dict] = {}
@@ -1421,7 +1534,7 @@ def _cfg_to_dict(v):
14211534
}
14221535
if matched_cfg is not None:
14231536
entry["cfg"] = _cfg_to_dict(matched_cfg)
1424-
quant_cfg.append(entry)
1537+
per_module_entries.append(entry)
14251538

14261539
# Collect non-per-module entries (e.g. *[kv]_bmm_quantizer) from winning recipes.
14271540
for recipe_entry in recipe.config.quant_cfg:
@@ -1438,7 +1551,10 @@ def _cfg_to_dict(v):
14381551
ge["cfg"] = _cfg_to_dict(cfg)
14391552
global_entries[pattern] = ge
14401553

1554+
# Keep path-scoped recipe entries before explicit module entries so selected
1555+
# modules override default disables such as ``*lm_head*``.
14411556
quant_cfg.extend(global_entries.values())
1557+
quant_cfg.extend(per_module_entries)
14421558
warnings.warn(
14431559
"get_auto_quantize_config: returned config uses algorithm='max'. "
14441560
"Per-recipe calibration algorithms (e.g. smoothquant, awq) are not preserved. "
@@ -1502,6 +1618,9 @@ def _match_quantizer_cfg(quant_cfg, quantizer_attr):
15021618
matched = None
15031619
matched_enable = None
15041620
for entry in quant_cfg:
1621+
parent_class = entry.get("parent_class") if hasattr(entry, "get") else entry.parent_class
1622+
if parent_class is not None:
1623+
continue
15051624
pattern = entry["quantizer_name"]
15061625
cfg = entry.get("cfg")
15071626
enable = entry.get("enable", True)

0 commit comments

Comments
 (0)