6666 save_expert_token_count_table ,
6767)
6868from modelopt .torch .export .model_utils import get_language_model_from_vl , is_multimodal_model
69+ from modelopt .torch .quantization ._auto_quantize_cost import EXCLUDED_MODULE_NAME_PATTERNS_KEY
6970from modelopt .torch .quantization .config import _default_disabled_quantizer_cfg , need_calibration
7071from modelopt .torch .quantization .plugins .accelerate import init_quantized_weights
7172from modelopt .torch .quantization .utils import is_quantized
@@ -140,6 +141,36 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
140141mto .enable_huggingface_checkpointing ()
141142
142143
144+ # TODO: To be refacored into config system.
145+ _QWEN36_AUTOQ_DISABLED_LAYERS = (
146+ "*shared_expert_gate*" ,
147+ "*linear_attn.in_proj_a*" ,
148+ "*linear_attn.in_proj_b*" ,
149+ )
150+ _VLM_AUTOQ_DISABLED_LAYERS = ("*visual*" , "*mtp*" , "*vision_tower*" )
151+
152+
153+ def get_auto_quantize_disabled_layers (model ) -> list [str ]:
154+ """Return layer patterns that should be excluded from AutoQuantize search."""
155+ disabled_layers = [
156+ entry ["quantizer_name" ]
157+ for entry in _default_disabled_quantizer_cfg
158+ if "parent_class" not in entry and entry ["quantizer_name" ] != "*lm_head*"
159+ ]
160+ disabled_layers .extend (p for p in _QWEN36_AUTOQ_DISABLED_LAYERS if p not in disabled_layers )
161+ if is_multimodal_model (model ):
162+ disabled_layers .extend (p for p in _VLM_AUTOQ_DISABLED_LAYERS if p not in disabled_layers )
163+ return disabled_layers
164+
165+
166+ def get_auto_quantize_cost_excluded_patterns (args , model ) -> list [str ]:
167+ """Return layer patterns excluded only from AutoQuantize cost accounting."""
168+ excluded_patterns = list (args .auto_quantize_cost_exclude_patterns or [])
169+ if args .auto_quantize_cost_exclude_vlm_modules and is_multimodal_model (model ):
170+ excluded_patterns .extend (_VLM_AUTOQ_DISABLED_LAYERS )
171+ return list (dict .fromkeys (excluded_patterns ))
172+
173+
143174def extract_and_prepare_language_model_from_vl (full_model ):
144175 """Extract language model from VL model and disable quantization for non-language components.
145176
@@ -323,6 +354,7 @@ def auto_quantize(
323354 "nvfp4_awq" ,
324355 "nvfp4_mse" ,
325356 "w4a8_awq" ,
357+ "w4a16_nvfp4" ,
326358 "fp8_pb_wo" ,
327359 "w4a8_mxfp4_fp8" ,
328360 "nvfp4_mlp_only" ,
@@ -386,10 +418,14 @@ def forward_step(model, batch):
386418 "effective_bits" : args .auto_quantize_bits ,
387419 "cost_model" : args .auto_quantize_cost_model ,
388420 }
421+ auto_quantize_cost = {}
389422 if args .auto_quantize_active_moe_expert_ratio is not None :
390- auto_quantize_constraints ["cost" ] = {
391- "active_moe_expert_ratio" : args .auto_quantize_active_moe_expert_ratio
392- }
423+ auto_quantize_cost ["active_moe_expert_ratio" ] = args .auto_quantize_active_moe_expert_ratio
424+ cost_excluded_patterns = get_auto_quantize_cost_excluded_patterns (args , language_model )
425+ if cost_excluded_patterns :
426+ auto_quantize_cost [EXCLUDED_MODULE_NAME_PATTERNS_KEY ] = cost_excluded_patterns
427+ if auto_quantize_cost :
428+ auto_quantize_constraints ["cost" ] = auto_quantize_cost
393429
394430 language_model , _ = mtq .auto_quantize (
395431 language_model ,
@@ -405,12 +441,7 @@ def forward_step(model, batch):
405441 len (calib_dataloader ), max (auto_quantize_score_size // args .batch_size , 1 )
406442 ),
407443 verbose = True ,
408- # Disable all default disabled layers such as lm_head, mlp.gate, router etc.
409- disabled_layers = [
410- entry ["quantizer_name" ]
411- for entry in _default_disabled_quantizer_cfg
412- if "parent_class" not in entry
413- ],
444+ disabled_layers = get_auto_quantize_disabled_layers (language_model ),
414445 method = auto_quantize_method ,
415446 checkpoint = auto_quantize_checkpoint ,
416447 )
@@ -550,12 +581,10 @@ def load_model(args: argparse.Namespace):
550581 : len (args .dataset )
551582 ]
552583
553- # We only quantize the language model for VLMs other than the type supported above.
554- # Recipe mode is the exception: in Qwen3.5/3.6-MoE VLMs, lm_head sits
555- # on the outer CausalLM, not the inner language backbone. A recipe that targets
556- # lm_head must therefore quantize against the full model and explicitly keep visual
557- # and MTP siblings disabled.
558- if args .recipe is None :
584+ # Plain PTQ quantizes only the extracted language model. Recipe and
585+ # AutoQuantize paths keep the outer CausalLM so recipes/search can see
586+ # Qwen3.5/3.6-MoE VLM lm_head.
587+ if args .recipe is None and args .auto_quantize_bits is None :
559588 extracted_lm , extracted_model_type = extract_and_prepare_language_model_from_vl (
560589 full_model
561590 )
@@ -1081,9 +1110,16 @@ def _is_layerwise(obj):
10811110 "Auto quantization needs multiple quantization format."
10821111 )
10831112
1113+ # For VL models, autoquant must walk submodules of the OUTER CausalLM
1114+ # (which carries lm_head and the LM-head forward path) — otherwise
1115+ # lm_head and any sibling-of-language_model modules are silently
1116+ # invisible to the search. ``forward_step`` also needs the outer model
1117+ # to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
1118+ # Visual tower and MTP siblings are auto-excluded inside
1119+ # ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
10841120 auto_quantize (
10851121 args ,
1086- language_model ,
1122+ full_model ,
10871123 calib_dataloader ,
10881124 auto_quantize_method = args .auto_quantize_method ,
10891125 auto_quantize_score_size = args .auto_quantize_score_size ,
@@ -1423,6 +1459,24 @@ def parse_args() -> argparse.Namespace:
14231459 "routing; use --moe_calib_experts_ratio to control calibration expert coverage."
14241460 ),
14251461 )
1462+ parser .add_argument (
1463+ "--auto_quantize_cost_exclude_patterns" ,
1464+ nargs = "+" ,
1465+ default = None ,
1466+ help = (
1467+ "Wildcard module-name patterns to exclude from AutoQuantize effective-bits cost "
1468+ "accounting. The matched modules can still be disabled from quantization separately; "
1469+ "this flag only changes the budget denominator and selected-cost calculation."
1470+ ),
1471+ )
1472+ parser .add_argument (
1473+ "--auto_quantize_cost_exclude_vlm_modules" ,
1474+ action = "store_true" ,
1475+ help = (
1476+ "Exclude VLM sibling modules matching *visual*, *vision_tower*, and *mtp* from "
1477+ "AutoQuantize effective-bits cost accounting."
1478+ ),
1479+ )
14261480 parser .add_argument (
14271481 "--moe_calib_experts_ratio" ,
14281482 type = float ,
0 commit comments