1515
1616import argparse
1717import copy
18+ import os
1819import random
1920import time
2021import warnings
@@ -137,6 +138,43 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
137138mto .enable_huggingface_checkpointing ()
138139
139140
141+ NVFP4_W4A16_CFG = {
142+ "quant_cfg" : [
143+ {"quantizer_name" : "*" , "enable" : False },
144+ {
145+ "quantizer_name" : "*weight_quantizer" ,
146+ "cfg" : {
147+ "num_bits" : (2 , 1 ),
148+ "block_sizes" : {- 1 : 16 , "type" : "dynamic" , "scale_bits" : (4 , 3 )},
149+ },
150+ },
151+ {"quantizer_name" : "*input_quantizer" , "enable" : False },
152+ * _default_disabled_quantizer_cfg ,
153+ ],
154+ "algorithm" : "max" ,
155+ }
156+
157+ FP8_W8A16_CFG = {
158+ "quant_cfg" : [
159+ {"quantizer_name" : "*" , "enable" : False },
160+ {
161+ "quantizer_name" : "*weight_quantizer" ,
162+ "cfg" : {"num_bits" : (4 , 3 ), "axis" : None },
163+ },
164+ {"quantizer_name" : "*input_quantizer" , "enable" : False },
165+ * _default_disabled_quantizer_cfg ,
166+ ],
167+ "algorithm" : "max" ,
168+ }
169+
170+ QUANT_CFG_CHOICES .update (
171+ {
172+ "nvfp4_w4a16" : NVFP4_W4A16_CFG ,
173+ "fp8_w8a16" : FP8_W8A16_CFG ,
174+ }
175+ )
176+
177+
140178def extract_and_prepare_language_model_from_vl (full_model ):
141179 """Extract language model from VL model and disable quantization for non-language components.
142180
@@ -326,6 +364,8 @@ def auto_quantize(
326364 "nvfp4_omlp_only" ,
327365 "nvfp4_local_hessian" ,
328366 "mxfp8" ,
367+ "nvfp4_w4a16" ,
368+ "fp8_w8a16" ,
329369 ]
330370 for qformat in qformat_list
331371 ), "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -348,6 +388,38 @@ def forward_step(model, batch):
348388 f"Invalid auto_quantize_method: { auto_quantize_method } . Must be 'gradient' or 'kl_div'"
349389 )
350390
391+ # Let AutoQuantize search lm_head, but keep modules out that vLLM either
392+ # constructs as BF16-only paths today or has known unsafe fused dispatch for.
393+ disabled_layers = [
394+ entry ["quantizer_name" ]
395+ for entry in _default_disabled_quantizer_cfg
396+ if "parent_class" not in entry and entry ["quantizer_name" ] != "*lm_head*"
397+ ]
398+ enable_linear_attn_big3 = os .environ .get ("MODELOPT_AUTOQ_ENABLE_LINEAR_ATTN_BIG3" ) == "1"
399+ enable_shared_expert = os .environ .get ("MODELOPT_AUTOQ_ENABLE_SHARED_EXPERT" ) == "1"
400+ autoq_extra_disabled = [
401+ "*shared_expert_gate*" ,
402+ "*linear_attn.in_proj_a*" ,
403+ "*linear_attn.in_proj_b*" ,
404+ ]
405+ if not enable_shared_expert :
406+ autoq_extra_disabled .append ("*mlp.shared_expert*" )
407+ if not enable_linear_attn_big3 :
408+ autoq_extra_disabled .extend (
409+ [
410+ "*linear_attn.in_proj_qkv*" ,
411+ "*linear_attn.in_proj_z*" ,
412+ "*linear_attn.out_proj*" ,
413+ ]
414+ )
415+ for pat in autoq_extra_disabled :
416+ if pat not in disabled_layers :
417+ disabled_layers .append (pat )
418+ if is_multimodal_model (language_model ):
419+ for pat in ("*visual*" , "*mtp*" , "*vision_tower*" ):
420+ if pat not in disabled_layers :
421+ disabled_layers .append (pat )
422+
351423 language_model , _ = mtq .auto_quantize (
352424 language_model ,
353425 constraints = {"effective_bits" : args .auto_quantize_bits },
@@ -362,12 +434,7 @@ def forward_step(model, batch):
362434 len (calib_dataloader ), max (auto_quantize_score_size // args .batch_size , 1 )
363435 ),
364436 verbose = True ,
365- # Disable all default disabled layers such as lm_head, mlp.gate, router etc.
366- disabled_layers = [
367- entry ["quantizer_name" ]
368- for entry in _default_disabled_quantizer_cfg
369- if "parent_class" not in entry
370- ],
437+ disabled_layers = disabled_layers ,
371438 method = auto_quantize_method ,
372439 checkpoint = auto_quantize_checkpoint ,
373440 )
@@ -507,12 +574,26 @@ def load_model(args: argparse.Namespace):
507574 ]
508575
509576 # We only quantize the language model for VLMs other than the type supported above.
510- extracted_lm , extracted_model_type = extract_and_prepare_language_model_from_vl (
511- full_model
512- )
513- if extracted_lm is not None :
514- language_model = extracted_lm
515- model_type = extracted_model_type
577+ # For AutoQuantize, skip the eager visual-disable side-effect: it
578+ # registers ``modelopt`` state on each visual sibling, which
579+ # ``mtq.auto_quantize → apply_mode → is_converted`` then trips on
580+ # ("Model has multiple modelopt states!"). AutoQuantize handles
581+ # visual/mtp via ``disabled_layers`` patterns instead, so the
582+ # extraction is unnecessary for that path.
583+ #
584+ # For ``--recipe`` mode on a VLM, lm_head sits on the OUTER
585+ # CausalLM. Recipe rules can't see it via the inner language
586+ # backbone, so we keep ``language_model = full_model`` here and
587+ # let ``quantize_main`` strip visual/mtp siblings around
588+ # ``mtq.quantize`` (so registration/calibration stays fast and
589+ # batch_size auto-detect doesn't collapse to 1).
590+ if args .auto_quantize_bits is None and args .recipe is None :
591+ extracted_lm , extracted_model_type = extract_and_prepare_language_model_from_vl (
592+ full_model
593+ )
594+ if extracted_lm is not None :
595+ language_model = extracted_lm
596+ model_type = extracted_model_type
516597
517598 tokenizer = get_tokenizer (args .pyt_ckpt_path , trust_remote_code = args .trust_remote_code )
518599
@@ -628,13 +709,52 @@ def mono_quantize(
628709 else None ,
629710 )
630711
712+ # When ``--recipe`` is given on a VLM we keep ``language_model =
713+ # full_model`` (so recipe rules can match lm_head) but ``mtq.quantize``
714+ # would otherwise walk and register quantizers on every Linear in the
715+ # visual encoder + MTP head.
716+ stripped_vlm_modules : dict [str , torch .nn .Module ] = {}
717+ if args .recipe is not None and language_model is full_model :
718+ for path in ("model.visual" , "mtp" ):
719+ parts = path .split ("." )
720+ parent = full_model
721+ ok = True
722+ for p in parts [:- 1 ]:
723+ if not hasattr (parent , p ):
724+ ok = False
725+ break
726+ parent = getattr (parent , p )
727+ if ok and hasattr (parent , parts [- 1 ]):
728+ mod = getattr (parent , parts [- 1 ])
729+ if mod is not None and isinstance (mod , torch .nn .Module ):
730+ stripped_vlm_modules [path ] = mod
731+ setattr (parent , parts [- 1 ], None )
732+ if stripped_vlm_modules :
733+ print (
734+ "[recipe] stripped VLM siblings before mtq.quantize: "
735+ + ", " .join (stripped_vlm_modules .keys ())
736+ )
737+
631738 if calibration_only :
632739 language_model = mtq .calibrate (
633740 language_model , quant_cfg ["algorithm" ], forward_loop = calibrate_loop
634741 )
635742 else :
636743 language_model = mtq .quantize (language_model , quant_cfg , forward_loop = calibrate_loop )
637744
745+ # Restore stripped VLM siblings so export sees the full model.
746+ for path , mod in stripped_vlm_modules .items ():
747+ parts = path .split ("." )
748+ parent = full_model
749+ for p in parts [:- 1 ]:
750+ parent = getattr (parent , p )
751+ setattr (parent , parts [- 1 ], mod )
752+ if stripped_vlm_modules :
753+ print (
754+ "[recipe] restored VLM siblings after mtq.quantize: "
755+ + ", " .join (stripped_vlm_modules .keys ())
756+ )
757+
638758 # For VL models, update full_model to use the quantized language model
639759 if is_nemotron_vl_model :
640760 language_model_lineage = get_language_model_from_vl (full_model )
@@ -1018,10 +1138,18 @@ def quantize_main(
10181138 "Auto quantization needs multiple quantization format."
10191139 )
10201140
1141+ # For VL models, autoquant must walk submodules of the OUTER CausalLM
1142+ # (which carries lm_head and the LM-head forward path) — otherwise
1143+ # lm_head and any sibling-of-language_model modules are silently
1144+ # invisible to the search. ``forward_step`` also needs the outer model
1145+ # to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
1146+ # Visual tower and MTP siblings are auto-excluded inside
1147+ # ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
10211148 auto_quantize (
10221149 args ,
1023- language_model ,
1150+ full_model ,
10241151 calib_dataloader ,
1152+ auto_quantize_method = args .auto_quantize_method ,
10251153 )
10261154
10271155 else :
0 commit comments