5555import modelopt .torch .opt as mto
5656import modelopt .torch .quantization as mtq
5757import modelopt .torch .sparsity as mts
58- from modelopt .recipe import ModelOptPTQRecipe , load_recipe
58+ from modelopt .recipe import (
59+ ModelOptAutoQuantizeRecipe ,
60+ ModelOptPTQRecipe ,
61+ ModelOptRecipeBase ,
62+ load_recipe ,
63+ )
5964from modelopt .torch .export import (
6065 export_hf_checkpoint ,
6166 export_hf_vllm_fq_checkpoint ,
@@ -208,6 +213,7 @@ def make_calib_dataloader(
208213 tokenizer : PreTrainedTokenizerBase | None ,
209214 device : torch .device ,
210215 model_type : str | None ,
216+ recipe : ModelOptRecipeBase | None = None ,
211217) -> tuple [DataLoader | _DeviceDataLoader , str | None ]:
212218 calib_dataloader = None
213219 first_text_speech_dataset = None
@@ -271,8 +277,12 @@ def make_calib_dataloader(
271277 assert tokenizer is not None and isinstance (
272278 tokenizer , (PreTrainedTokenizer , PreTrainedTokenizerFast )
273279 ), "The PreTrainedTokenizer must be set"
274- # Labels are only needed for gradient-based auto_quantize
275- include_labels = (
280+ # Labels are only needed for gradient-based auto_quantize (CLI or recipe path).
281+ is_autoquant_recipe_gradient = (
282+ isinstance (recipe , ModelOptAutoQuantizeRecipe )
283+ and recipe .auto_quantize .method == "gradient"
284+ )
285+ include_labels = is_autoquant_recipe_gradient or (
276286 args .auto_quantize_bits is not None and args .auto_quantize_method == "gradient"
277287 )
278288
@@ -292,48 +302,32 @@ def auto_quantize(
292302 args : argparse .Namespace ,
293303 language_model : torch .nn .Module ,
294304 calib_dataloader : DataLoader ,
295- auto_quantize_method = "gradient" ,
296- auto_quantize_score_size = 128 ,
297- auto_quantize_checkpoint = None ,
298305 full_model : torch .nn .Module | None = None ,
306+ * ,
307+ auto_quantize_method : str ,
308+ auto_quantize_score_size : int ,
309+ auto_quantize_checkpoint : str | None ,
310+ constraints : dict ,
311+ quantization_formats : list [dict ],
312+ disabled_layers : list [str ],
313+ kv_cache_qformat : str ,
299314):
300- """Auto search quantization of multiple formats."""
315+ """Pure orchestrator: build forward_step/loss_func, call mtq.auto_quantize,
316+ run KV cache post-step. All knobs are explicit keyword-only args; the
317+ caller (dispatch site in ``quantize_main``) is responsible for resolving
318+ them from either CLI args or a recipe before invoking this function.
319+ """
301320
302321 if args .calib_with_images :
303322 raise NotImplementedError (
304323 "AutoQuantize with image-text calibration is not supported yet. "
305324 "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images."
306325 )
307326
308- assert not ( args .auto_quantize_bits and args . inference_pipeline_parallel > 1 ) , (
327+ assert args .inference_pipeline_parallel <= 1 , (
309328 "Auto Quantization is not supported for pipeline parallel size > 1"
310329 )
311330
312- qformat_list = args .qformat .split ("," )
313- assert qformat_list , "No quantization formats provided"
314- # Check if all provided quantization formats are supported
315- assert all (
316- qformat
317- in [
318- "fp8" ,
319- "int8_sq" ,
320- "int8_wo" ,
321- "int4_awq" ,
322- "nvfp4" ,
323- "nvfp4_awq" ,
324- "nvfp4_mse" ,
325- "w4a8_awq" ,
326- "fp8_pb_wo" ,
327- "w4a8_mxfp4_fp8" ,
328- "nvfp4_mlp_only" ,
329- "nvfp4_experts_only" ,
330- "nvfp4_omlp_only" ,
331- "nvfp4_local_hessian" ,
332- "mxfp8" ,
333- ]
334- for qformat in qformat_list
335- ), "One or more quantization formats provided are not supported for unified checkpoint export"
336-
337331 # When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
338332 # use full_model's lm_head to compute logits/loss from hidden states.
339333 is_base_model = (
@@ -384,45 +378,39 @@ def forward_step(model, batch):
384378
385379 language_model , _ = mtq .auto_quantize (
386380 language_model ,
387- constraints = { "effective_bits" : args . auto_quantize_bits } ,
381+ constraints = constraints ,
388382 data_loader = calib_dataloader ,
389383 forward_step = forward_step ,
390384 loss_func = loss_func , # Only used for gradient-based method
391385 # TRTLLM only support one quantization format or None (do not quantize, internally supported)
392- quantization_formats = [ QUANT_CFG_CHOICES [ format ] for format in qformat_list ],
386+ quantization_formats = quantization_formats , # type: ignore[arg-type]
393387 num_calib_steps = len (calib_dataloader ),
394388 # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration.
395389 num_score_steps = min (
396390 len (calib_dataloader ), max (auto_quantize_score_size // args .batch_size , 1 )
397391 ),
398392 verbose = True ,
399- # Disable all default disabled layers such as lm_head, mlp.gate, router etc.
400- disabled_layers = [
401- entry ["quantizer_name" ]
402- for entry in _default_disabled_quantizer_cfg
403- if "parent_class" not in entry
404- ],
393+ disabled_layers = disabled_layers ,
405394 method = auto_quantize_method ,
406395 checkpoint = auto_quantize_checkpoint ,
407396 )
408397
409398 calibrate_loop = create_forward_loop (dataloader = calib_dataloader )
410- # We need to explicitly set up KV cache quantization after auto_quantize
411- enable_quant_kv_cache = args .kv_cache_qformat != "none"
399+ enable_quant_kv_cache = kv_cache_qformat != "none"
412400 print (f"{ 'Enable' if enable_quant_kv_cache else 'Disable' } KV cache quantization" )
413401 if enable_quant_kv_cache :
414402 kv_cache_quant_cfg = copy .deepcopy (
415- getattr (mtq , KV_QUANT_CFG_CHOICES [args . kv_cache_qformat ])["quant_cfg" ]
403+ getattr (mtq , KV_QUANT_CFG_CHOICES [kv_cache_qformat ])["quant_cfg" ]
416404 )
417405 kv_cache_quant_cfg = [
418406 e for e in kv_cache_quant_cfg if e ["quantizer_name" ] != "*"
419407 ] # keep other quantizers from auto_quantize
420408
421- if args . kv_cache_qformat in _KV_CAST_FORMATS :
409+ if kv_cache_qformat in _KV_CAST_FORMATS :
422410 _set_kv_cache_constant_amax (kv_cache_quant_cfg )
423411
424412 mtq .set_quantizer_by_cfg (language_model , quant_cfg = kv_cache_quant_cfg )
425- if args . kv_cache_qformat not in _KV_CAST_FORMATS :
413+ if kv_cache_qformat not in _KV_CAST_FORMATS :
426414 # Calibrate only the KV cache quantizers; disable all others.
427415 with mtq .set_quantizer_by_cfg_context (
428416 language_model ,
@@ -1003,12 +991,20 @@ def quantize_main(
1003991):
1004992 # Load the recipe up front so we can detect layerwise calibration before batch-size probing.
1005993 recipe = None
1006- if args .recipe is not None and not args . auto_quantize_bits :
994+ if args .recipe is not None :
1007995 print (f"Use recipe { args .recipe } for quantization" )
1008996 recipe = load_recipe (args .recipe )
1009- if not isinstance (recipe , ModelOptPTQRecipe ):
997+ if not isinstance (recipe , ( ModelOptPTQRecipe , ModelOptAutoQuantizeRecipe ) ):
1010998 raise TypeError (
1011- f"Expected PTQ recipe, but got { type (recipe ).__name__ } from { args .recipe } "
999+ f"Expected PTQ or AutoQuantize recipe, but got { type (recipe ).__name__ } "
1000+ f"from { args .recipe } "
1001+ )
1002+ # Fail-fast on conflicting budget sources: a recipe carries its own
1003+ # effective_bits, so silently honoring one over the other would be a
1004+ # reproducibility hazard.
1005+ if args .auto_quantize_bits is not None :
1006+ raise ValueError (
1007+ "Cannot combine --auto_quantize_bits with --recipe; the recipe owns the budget."
10121008 )
10131009
10141010 def _is_layerwise (obj ):
@@ -1059,7 +1055,9 @@ def _is_layerwise(obj):
10591055 else :
10601056 sample_input_single_batch = None
10611057
1062- run_auto_quant = args .auto_quantize_bits is not None
1058+ run_auto_quant = args .auto_quantize_bits is not None or isinstance (
1059+ recipe , ModelOptAutoQuantizeRecipe
1060+ )
10631061
10641062 args .batch_size = get_max_batch_size (
10651063 language_model ,
@@ -1073,7 +1071,7 @@ def _is_layerwise(obj):
10731071 print (f"Use calib batch_size { args .batch_size } " )
10741072
10751073 calib_dataloader , first_text_speech_dataset = make_calib_dataloader (
1076- args , language_model , processor , tokenizer , device , model_type
1074+ args , language_model , processor , tokenizer , device , model_type , recipe = recipe
10771075 )
10781076
10791077 # Detect if this is a Nemotron VL model using architecture-based detection
@@ -1083,20 +1081,91 @@ def _is_layerwise(obj):
10831081 args , full_model , model_type , tokenizer , calib_dataloader , is_nemotron_vl_model
10841082 )
10851083
1086- if args .auto_quantize_bits :
1087- assert len (args .qformat .split ("," )) > 1 , (
1088- "Auto quantization needs multiple quantization format."
1089- )
1084+ # All auto_quantize() knobs are resolved here before calling the helper.
1085+ # Helper is a leaf orchestrator — it does not know whether inputs came from
1086+ # CLI args or a recipe.
1087+ if isinstance (recipe , ModelOptAutoQuantizeRecipe ) or args .auto_quantize_bits :
1088+ default_disabled_layers = [
1089+ entry ["quantizer_name" ]
1090+ for entry in _default_disabled_quantizer_cfg
1091+ if "parent_class" not in entry
1092+ ]
10901093
1091- auto_quantize (
1092- args ,
1093- language_model ,
1094- calib_dataloader ,
1095- auto_quantize_method = args .auto_quantize_method ,
1096- auto_quantize_score_size = args .auto_quantize_score_size ,
1097- auto_quantize_checkpoint = args .auto_quantize_checkpoint ,
1098- full_model = full_model ,
1099- )
1094+ if isinstance (recipe , ModelOptAutoQuantizeRecipe ):
1095+ aq = recipe .auto_quantize
1096+
1097+ # mtq.auto_quantize labels candidates by upstream identity: dicts that ARE
1098+ # an mtq.X_CFG object get the constant's name in logs (e.g. NVFP4_DEFAULT_CFG);
1099+ # all other dicts get "CUSTOM_N" plus a "results may not be optimal" warning.
1100+ # Recipe candidates come from .model_dump() — equal by value but not identity,
1101+ # so we'd lose the friendly names. Substitute the canonical object back when
1102+ # the dump matches a known preset, so logs and the warning line up with CLI.
1103+ # The match check uses exclude_unset=True so it compares against the
1104+ # preset YAML's natural shape (mtq.X_CFG dicts don't carry Pydantic-filled
1105+ # defaults). The payload still passes the full dump to upstream.
def _candidate_for_mtq(fmt):
    """Pick the object to hand to ``mtq.auto_quantize`` for one recipe candidate.

    The candidate's ``model_dump(exclude_unset=True)`` output is compared by
    value against every entry of ``QUANT_CFG_CHOICES``. On the first match the
    entry from ``QUANT_CFG_CHOICES`` itself is returned (the very same object,
    not a copy); otherwise the candidate's full ``model_dump()`` is returned.

    Args:
        fmt: A recipe candidate exposing ``model_dump`` (presumably a Pydantic
            model -- confirm against the recipe schema).

    Returns:
        The matching ``QUANT_CFG_CHOICES`` value, or the full dump of ``fmt``
        when no preset compares equal.
    """
    # Only explicitly-set fields take part in the comparison; the fallback
    # below still returns the complete dump.
    explicit_fields = fmt.model_dump(exclude_unset=True)

    # A private sentinel keeps "no preset matched" distinguishable from any
    # value that could legitimately be stored in QUANT_CFG_CHOICES.
    no_match = object()
    preset = next(
        (known for known in QUANT_CFG_CHOICES.values() if known == explicit_fields),
        no_match,
    )
    if preset is no_match:
        return fmt.model_dump()
    return preset
1112+
1113+ auto_quantize (
1114+ args ,
1115+ language_model ,
1116+ calib_dataloader ,
1117+ full_model = full_model ,
1118+ auto_quantize_method = aq .method ,
1119+ auto_quantize_score_size = aq .num_score_steps ,
1120+ auto_quantize_checkpoint = aq .score_checkpoint ,
1121+ constraints = aq .constraints .model_dump (exclude_none = True ),
1122+ quantization_formats = [_candidate_for_mtq (fmt ) for fmt in aq .candidate_formats ],
1123+ disabled_layers = aq .disabled_layers or default_disabled_layers ,
1124+ kv_cache_qformat = (
1125+ aq .kv_cache .qformat
1126+ if (aq .kv_cache and aq .kv_cache .qformat )
1127+ else args .kv_cache_qformat
1128+ ),
1129+ )
1130+ else :
1131+ qformat_list = args .qformat .split ("," )
1132+ assert len (qformat_list ) > 1 , "Auto quantization needs multiple quantization format."
1133+ assert all (
1134+ qformat
1135+ in [
1136+ "fp8" ,
1137+ "int8_sq" ,
1138+ "int8_wo" ,
1139+ "int4_awq" ,
1140+ "nvfp4" ,
1141+ "nvfp4_awq" ,
1142+ "nvfp4_mse" ,
1143+ "w4a8_awq" ,
1144+ "fp8_pb_wo" ,
1145+ "w4a8_mxfp4_fp8" ,
1146+ "nvfp4_mlp_only" ,
1147+ "nvfp4_experts_only" ,
1148+ "nvfp4_omlp_only" ,
1149+ "nvfp4_local_hessian" ,
1150+ "mxfp8" ,
1151+ ]
1152+ for qformat in qformat_list
1153+ ), (
1154+ "One or more quantization formats provided are not supported for unified checkpoint export"
1155+ )
1156+ auto_quantize (
1157+ args ,
1158+ language_model ,
1159+ calib_dataloader ,
1160+ full_model = full_model ,
1161+ auto_quantize_method = args .auto_quantize_method ,
1162+ auto_quantize_score_size = args .auto_quantize_score_size ,
1163+ auto_quantize_checkpoint = args .auto_quantize_checkpoint ,
1164+ constraints = {"effective_bits" : args .auto_quantize_bits },
1165+ quantization_formats = [QUANT_CFG_CHOICES [fmt ] for fmt in qformat_list ],
1166+ disabled_layers = default_disabled_layers ,
1167+ kv_cache_qformat = args .kv_cache_qformat ,
1168+ )
11001169
11011170 else :
11021171 # mono quantization
@@ -1214,9 +1283,11 @@ def parse_args() -> argparse.Namespace:
12141283 parser .add_argument (
12151284 "--recipe" ,
12161285 help = (
1217- "PTQ recipe YAML file or name without suffix (e.g. general/ptq/fp8_default-kv_fp8_cast, "
1218- "general/ptq/nvfp4_default-kv_fp8_cast, general/ptq/nvfp4_default-kv_nvfp4_cast). "
1219- "When set, --kv_cache_qformat is ignored; the recipe fully determines KV cache config."
1286+ "PTQ or AutoQuantize recipe YAML file or name without suffix "
1287+ "(e.g. general/ptq/nvfp4_default-kv_fp8_cast, "
1288+ "general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast). "
1289+ "PTQ recipes fully own quant config; AutoQuantize recipes own search config "
1290+ "and may optionally override --kv_cache_qformat via their kv_cache field."
12201291 ),
12211292 default = None ,
12221293 )
0 commit comments