Skip to content

Commit 2a21f66

Browse files
committed
Add AutoQuant support for VLMs
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 902d369 commit 2a21f66

10 files changed

Lines changed: 308 additions & 46 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
save_expert_token_count_table,
6767
)
6868
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
69-
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
69+
from modelopt.torch.quantization.config import need_calibration
7070
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
7171
from modelopt.torch.quantization.utils import is_quantized
7272
from modelopt.torch.speculative.eagle.utils import (
@@ -292,6 +292,7 @@ def auto_quantize(
292292
args: argparse.Namespace,
293293
language_model: torch.nn.Module,
294294
calib_dataloader: DataLoader,
295+
recipe: ModelOptPTQRecipe | None = None,
295296
auto_quantize_method="gradient",
296297
auto_quantize_score_size=128,
297298
auto_quantize_checkpoint=None,
@@ -323,6 +324,7 @@ def auto_quantize(
323324
"nvfp4_awq",
324325
"nvfp4_mse",
325326
"w4a8_awq",
327+
"w4a16_nvfp4",
326328
"fp8_pb_wo",
327329
"w4a8_mxfp4_fp8",
328330
"nvfp4_mlp_only",
@@ -391,6 +393,10 @@ def forward_step(model, batch):
391393
"active_moe_expert_ratio": args.auto_quantize_active_moe_expert_ratio
392394
}
393395

396+
disabled_layers = recipe.quantize.disabled_layers if recipe is not None else None
397+
if disabled_layers:
398+
print(f"AutoQuantize disabled layers from recipe: {disabled_layers}")
399+
394400
language_model, _ = mtq.auto_quantize(
395401
language_model,
396402
constraints=auto_quantize_constraints,
@@ -405,12 +411,7 @@ def forward_step(model, batch):
405411
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
406412
),
407413
verbose=True,
408-
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
409-
disabled_layers=[
410-
entry["quantizer_name"]
411-
for entry in _default_disabled_quantizer_cfg
412-
if "parent_class" not in entry
413-
],
414+
disabled_layers=disabled_layers,
414415
method=auto_quantize_method,
415416
checkpoint=auto_quantize_checkpoint,
416417
)
@@ -550,12 +551,9 @@ def load_model(args: argparse.Namespace):
550551
: len(args.dataset)
551552
]
552553

553-
# We only quantize the language model for VLMs other than the type supported above.
554-
# Recipe mode is the exception: in Qwen3.5/3.6-MoE VLMs, lm_head sits
555-
# on the outer CausalLM, not the inner language backbone. A recipe that targets
556-
# lm_head must therefore quantize against the full model and explicitly keep visual
557-
# and MTP siblings disabled.
558-
if args.recipe is None:
554+
# AutoQuantize walks the outer CausalLM so lm_head is visible to the
555+
# search. Visual/MTP siblings are excluded by disabled-layer patterns.
556+
if args.auto_quantize_bits is None:
559557
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
560558
full_model
561559
)
@@ -994,9 +992,10 @@ def quantize_main(
994992
default_pad_token,
995993
device: torch.device,
996994
):
997-
# Load the recipe up front so we can detect layerwise calibration before batch-size probing.
995+
# Load the recipe up front so we can detect layerwise calibration before batch-size probing
996+
# and read AutoQuantize search metadata such as disabled_layers.
998997
recipe = None
999-
if args.recipe is not None and not args.auto_quantize_bits:
998+
if args.recipe is not None:
1000999
print(f"Use recipe {args.recipe} for quantization")
10011000
recipe = load_recipe(args.recipe)
10021001
if not isinstance(recipe, ModelOptPTQRecipe):
@@ -1081,10 +1080,18 @@ def _is_layerwise(obj):
10811080
"Auto quantization needs multiple quantization format."
10821081
)
10831082

1083+
# For VL models, autoquant must walk submodules of the OUTER CausalLM
1084+
# (which carries lm_head and the LM-head forward path) — otherwise
1085+
# lm_head and any sibling-of-language_model modules are silently
1086+
# invisible to the search. ``forward_step`` also needs the outer model
1087+
# to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
1088+
# Visual tower and MTP siblings are auto-excluded inside
1089+
# ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
10841090
auto_quantize(
10851091
args,
1086-
language_model,
1092+
full_model,
10871093
calib_dataloader,
1094+
recipe=recipe,
10881095
auto_quantize_method=args.auto_quantize_method,
10891096
auto_quantize_score_size=args.auto_quantize_score_size,
10901097
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
@@ -1209,7 +1216,9 @@ def parse_args() -> argparse.Namespace:
12091216
help=(
12101217
"PTQ recipe YAML file or name without suffix (e.g. general/ptq/fp8_default-kv_fp8_cast, "
12111218
"general/ptq/nvfp4_default-kv_fp8_cast, general/ptq/nvfp4_default-kv_nvfp4_cast). "
1212-
"When set, --kv_cache_qformat is ignored; the recipe fully determines KV cache config."
1219+
"For plain PTQ, the recipe fully determines the quantization config and --kv_cache_qformat "
1220+
"is ignored. For AutoQuantize, --qformat still determines the search formats while the "
1221+
"recipe may provide search metadata such as quantize.disabled_layers."
12131222
),
12141223
default=None,
12151224
)
@@ -1299,8 +1308,8 @@ def parse_args() -> argparse.Namespace:
12991308
"Formats ending in '_cast' (fp8_cast, nvfp4_cast) set the amax to FP8 range "
13001309
"without data-driven calibration. "
13011310
"Other formats (fp8, nvfp4, etc.) use data-driven calibration. "
1302-
"Ignored when --recipe is given: the recipe YAML is authoritative for KV "
1303-
"cache config (use the *_cast_kv.yaml recipes for the cast variants)."
1311+
"Ignored for plain PTQ when --recipe is given because the recipe YAML is authoritative "
1312+
"for KV cache config (use the *_cast_kv.yaml recipes for the cast variants)."
13041313
),
13051314
)
13061315
parser.add_argument(

examples/llm_ptq/scripts/parser.sh

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,16 @@ parse_options() {
103103
# Verify required options are provided
104104
if [ -z "$MODEL_PATH" ] || [ -z "$TASKS" ] || ([ -z "$QFORMAT" ] && [ -z "$RECIPE" ]); then
105105
echo "Usage: $0 --model=<MODEL_PATH> (--quant=<QFORMAT> | --recipe=<RECIPE>) --tasks=<TASK,...>"
106+
echo " AutoQuant may use both --quant=<QFORMATS> and --recipe=<RECIPE>."
106107
echo "Optional args: --sparsity=<SPARSITY_FMT> --awq_block_size=<AWQ_BLOCK_SIZE> --calib=<CALIB_SIZE>"
107108
exit 1
108109
fi
109110

110-
# --quant and --recipe are mutually exclusive: --recipe is a full PTQ spec, while
111-
# --quant selects a built-in qformat preset. Pick exactly one.
112-
if [ -n "$QFORMAT" ] && [ -n "$RECIPE" ]; then
113-
echo "Cannot specify both --quant and --recipe; pick one." >&2
111+
# For plain PTQ, --quant and --recipe are mutually exclusive: --recipe is a full PTQ spec,
112+
# while --quant selects a built-in qformat preset. For AutoQuant, --quant selects the search
113+
# candidates and --recipe may provide search metadata such as disabled_layers.
114+
if [ -n "$QFORMAT" ] && [ -n "$RECIPE" ] && [ -z "$AUTO_QUANTIZE_BITS" ]; then
115+
echo "Cannot specify both --quant and --recipe for plain PTQ; pick one." >&2
114116
exit 1
115117
fi
116118

modelopt/torch/quantization/_auto_quantize_cost.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ def is_routed_moe_module_name(name: str) -> bool:
9090
return "shared_expert" not in name and _ROUTED_MOE_EXPERT_NAME_RE.search(name) is not None
9191

9292

93+
def _get_module_weight_numel(module: nn.Module) -> int:
    """Return the parameter count for a module's quantizable weights.

    A module that exposes a non-``None`` ``weight`` attribute is counted by that
    tensor alone. Otherwise the module is treated as a fused MoE expert
    container, and the element counts of its ``gate_up_proj`` and ``down_proj``
    attributes (whichever are present) are added together.

    Args:
        module: Module whose weight elements are being counted.

    Returns:
        ``weight.numel()`` when ``weight`` is set; otherwise the combined
        element count of the fused projection tensors, which is ``0`` when
        neither projection attribute exists.
    """
    weight = getattr(module, "weight", None)
    if weight is not None:
        return weight.numel()

    # Fused MoE expert containers expose projection tensors directly instead of
    # a single ``weight`` parameter.
    return sum(
        param.numel()
        for attr in ("gate_up_proj", "down_proj")
        if (param := getattr(module, attr, None)) is not None
    )
106+
107+
93108
class AutoQuantizeCostModel:
94109
"""Base class for AutoQuantize effective-bits cost accounting."""
95110

@@ -119,7 +134,7 @@ def total_weight_size(
119134
) -> float:
120135
"""Return the cost denominator for the effective-bits constraint."""
121136
return sum(
122-
module.weight.numel() * self.module_cost_weight([name], cost_constraints)
137+
_get_module_weight_numel(module) * self.module_cost_weight([name], cost_constraints)
123138
for name, module in named_modules
124139
if is_auto_quantize_module(module)
125140
)

0 commit comments

Comments
 (0)