Skip to content

Commit 6bca67e

Browse files
committed
audoquantize for VLM
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 077e29a commit 6bca67e

3 files changed

Lines changed: 272 additions & 34 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 141 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import argparse
1717
import copy
18+
import os
1819
import random
1920
import time
2021
import warnings
@@ -137,6 +138,43 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
137138
mto.enable_huggingface_checkpointing()
138139

139140

141+
NVFP4_W4A16_CFG = {
142+
"quant_cfg": [
143+
{"quantizer_name": "*", "enable": False},
144+
{
145+
"quantizer_name": "*weight_quantizer",
146+
"cfg": {
147+
"num_bits": (2, 1),
148+
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
149+
},
150+
},
151+
{"quantizer_name": "*input_quantizer", "enable": False},
152+
*_default_disabled_quantizer_cfg,
153+
],
154+
"algorithm": "max",
155+
}
156+
157+
FP8_W8A16_CFG = {
158+
"quant_cfg": [
159+
{"quantizer_name": "*", "enable": False},
160+
{
161+
"quantizer_name": "*weight_quantizer",
162+
"cfg": {"num_bits": (4, 3), "axis": None},
163+
},
164+
{"quantizer_name": "*input_quantizer", "enable": False},
165+
*_default_disabled_quantizer_cfg,
166+
],
167+
"algorithm": "max",
168+
}
169+
170+
QUANT_CFG_CHOICES.update(
171+
{
172+
"nvfp4_w4a16": NVFP4_W4A16_CFG,
173+
"fp8_w8a16": FP8_W8A16_CFG,
174+
}
175+
)
176+
177+
140178
def extract_and_prepare_language_model_from_vl(full_model):
141179
"""Extract language model from VL model and disable quantization for non-language components.
142180
@@ -326,6 +364,8 @@ def auto_quantize(
326364
"nvfp4_omlp_only",
327365
"nvfp4_local_hessian",
328366
"mxfp8",
367+
"nvfp4_w4a16",
368+
"fp8_w8a16",
329369
]
330370
for qformat in qformat_list
331371
), "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -348,6 +388,38 @@ def forward_step(model, batch):
348388
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
349389
)
350390

391+
# Let AutoQuantize search lm_head, but keep modules out that vLLM either
392+
# constructs as BF16-only paths today or has known unsafe fused dispatch for.
393+
disabled_layers = [
394+
entry["quantizer_name"]
395+
for entry in _default_disabled_quantizer_cfg
396+
if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
397+
]
398+
enable_linear_attn_big3 = os.environ.get("MODELOPT_AUTOQ_ENABLE_LINEAR_ATTN_BIG3") == "1"
399+
enable_shared_expert = os.environ.get("MODELOPT_AUTOQ_ENABLE_SHARED_EXPERT") == "1"
400+
autoq_extra_disabled = [
401+
"*shared_expert_gate*",
402+
"*linear_attn.in_proj_a*",
403+
"*linear_attn.in_proj_b*",
404+
]
405+
if not enable_shared_expert:
406+
autoq_extra_disabled.append("*mlp.shared_expert*")
407+
if not enable_linear_attn_big3:
408+
autoq_extra_disabled.extend(
409+
[
410+
"*linear_attn.in_proj_qkv*",
411+
"*linear_attn.in_proj_z*",
412+
"*linear_attn.out_proj*",
413+
]
414+
)
415+
for pat in autoq_extra_disabled:
416+
if pat not in disabled_layers:
417+
disabled_layers.append(pat)
418+
if is_multimodal_model(language_model):
419+
for pat in ("*visual*", "*mtp*", "*vision_tower*"):
420+
if pat not in disabled_layers:
421+
disabled_layers.append(pat)
422+
351423
language_model, _ = mtq.auto_quantize(
352424
language_model,
353425
constraints={"effective_bits": args.auto_quantize_bits},
@@ -362,12 +434,7 @@ def forward_step(model, batch):
362434
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
363435
),
364436
verbose=True,
365-
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
366-
disabled_layers=[
367-
entry["quantizer_name"]
368-
for entry in _default_disabled_quantizer_cfg
369-
if "parent_class" not in entry
370-
],
437+
disabled_layers=disabled_layers,
371438
method=auto_quantize_method,
372439
checkpoint=auto_quantize_checkpoint,
373440
)
@@ -507,12 +574,26 @@ def load_model(args: argparse.Namespace):
507574
]
508575

509576
# We only quantize the language model for VLMs other than the type supported above.
510-
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
511-
full_model
512-
)
513-
if extracted_lm is not None:
514-
language_model = extracted_lm
515-
model_type = extracted_model_type
577+
# For AutoQuantize, skip the eager visual-disable side-effect: it
578+
# registers ``modelopt`` state on each visual sibling, which
579+
# ``mtq.auto_quantize → apply_mode → is_converted`` then trips on
580+
# ("Model has multiple modelopt states!"). AutoQuantize handles
581+
# visual/mtp via ``disabled_layers`` patterns instead, so the
582+
# extraction is unnecessary for that path.
583+
#
584+
# For ``--recipe`` mode on a VLM, lm_head sits on the OUTER
585+
# CausalLM. Recipe rules can't see it via the inner language
586+
# backbone, so we keep ``language_model = full_model`` here and
587+
# let ``quantize_main`` strip visual/mtp siblings around
588+
# ``mtq.quantize`` (so registration/calibration stays fast and
589+
# batch_size auto-detect doesn't collapse to 1).
590+
if args.auto_quantize_bits is None and args.recipe is None:
591+
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
592+
full_model
593+
)
594+
if extracted_lm is not None:
595+
language_model = extracted_lm
596+
model_type = extracted_model_type
516597

517598
tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
518599

@@ -628,13 +709,52 @@ def mono_quantize(
628709
else None,
629710
)
630711

712+
# When ``--recipe`` is given on a VLM we keep ``language_model =
713+
# full_model`` (so recipe rules can match lm_head) but ``mtq.quantize``
714+
# would otherwise walk and register quantizers on every Linear in the
715+
# visual encoder + MTP head.
716+
stripped_vlm_modules: dict[str, torch.nn.Module] = {}
717+
if args.recipe is not None and language_model is full_model:
718+
for path in ("model.visual", "mtp"):
719+
parts = path.split(".")
720+
parent = full_model
721+
ok = True
722+
for p in parts[:-1]:
723+
if not hasattr(parent, p):
724+
ok = False
725+
break
726+
parent = getattr(parent, p)
727+
if ok and hasattr(parent, parts[-1]):
728+
mod = getattr(parent, parts[-1])
729+
if mod is not None and isinstance(mod, torch.nn.Module):
730+
stripped_vlm_modules[path] = mod
731+
setattr(parent, parts[-1], None)
732+
if stripped_vlm_modules:
733+
print(
734+
"[recipe] stripped VLM siblings before mtq.quantize: "
735+
+ ", ".join(stripped_vlm_modules.keys())
736+
)
737+
631738
if calibration_only:
632739
language_model = mtq.calibrate(
633740
language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop
634741
)
635742
else:
636743
language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)
637744

745+
# Restore stripped VLM siblings so export sees the full model.
746+
for path, mod in stripped_vlm_modules.items():
747+
parts = path.split(".")
748+
parent = full_model
749+
for p in parts[:-1]:
750+
parent = getattr(parent, p)
751+
setattr(parent, parts[-1], mod)
752+
if stripped_vlm_modules:
753+
print(
754+
"[recipe] restored VLM siblings after mtq.quantize: "
755+
+ ", ".join(stripped_vlm_modules.keys())
756+
)
757+
638758
# For VL models, update full_model to use the quantized language model
639759
if is_nemotron_vl_model:
640760
language_model_lineage = get_language_model_from_vl(full_model)
@@ -1018,10 +1138,18 @@ def quantize_main(
10181138
"Auto quantization needs multiple quantization format."
10191139
)
10201140

1141+
# For VL models, autoquant must walk submodules of the OUTER CausalLM
1142+
# (which carries lm_head and the LM-head forward path) — otherwise
1143+
# lm_head and any sibling-of-language_model modules are silently
1144+
# invisible to the search. ``forward_step`` also needs the outer model
1145+
# to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
1146+
# Visual tower and MTP siblings are auto-excluded inside
1147+
# ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
10211148
auto_quantize(
10221149
args,
1023-
language_model,
1150+
full_model,
10241151
calib_dataloader,
1152+
auto_quantize_method=args.auto_quantize_method,
10251153
)
10261154

10271155
else:

modelopt/torch/export/quant_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1354,9 +1354,23 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
13541354
"""Preprocess the quantized linears that we plan to fuse.
13551355
13561356
Use resmooth_only for MOE experts as each individual expert is not fused.
1357+
1358+
When the modules carry mismatched quantization formats — most often after
1359+
AutoQuantize picks different formats for layers that share input but were
1360+
not coalesced into a single search group — we cannot coalesce them into a
1361+
fused linear. In that case, fall back to skipping the fusion so each linear
1362+
exports independently with its own format, instead of asserting.
13571363
"""
13581364
quantization_format_list = [get_quantization_format(module) for module in modules]
1359-
assert all_items_same(quantization_format_list), "Modules have different quantization formats"
1365+
if not all_items_same(quantization_format_list):
1366+
warn(
1367+
"preprocess_linear_fusion: modules in this fusion group have mixed "
1368+
f"quantization formats {quantization_format_list}. Skipping fusion; "
1369+
"each linear will export with its own format. Common cause: "
1370+
"AutoQuantize assigned different formats to fusion-mate linears.",
1371+
stacklevel=2,
1372+
)
1373+
return
13601374

13611375
# Activation
13621376
if hasattr(modules[0], "input_quantizer"):

0 commit comments

Comments
 (0)