-
Notifications
You must be signed in to change notification settings - Fork 453
Add YAML based AutoQuantize recipe (currently only CLI is supported) #1523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d913fcf
fcee651
e15dc62
a2763fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,7 +55,12 @@ | |
| import modelopt.torch.opt as mto | ||
| import modelopt.torch.quantization as mtq | ||
| import modelopt.torch.sparsity as mts | ||
| from modelopt.recipe import ModelOptPTQRecipe, load_recipe | ||
| from modelopt.recipe import ( | ||
| ModelOptAutoQuantizeRecipe, | ||
| ModelOptPTQRecipe, | ||
| ModelOptRecipeBase, | ||
| load_recipe, | ||
| ) | ||
| from modelopt.torch.export import ( | ||
| export_hf_checkpoint, | ||
| export_hf_vllm_fq_checkpoint, | ||
|
|
@@ -208,6 +213,7 @@ def make_calib_dataloader( | |
| tokenizer: PreTrainedTokenizerBase | None, | ||
| device: torch.device, | ||
| model_type: str | None, | ||
| recipe: ModelOptRecipeBase | None = None, | ||
| ) -> tuple[DataLoader | _DeviceDataLoader, str | None]: | ||
| calib_dataloader = None | ||
| first_text_speech_dataset = None | ||
|
|
@@ -271,8 +277,12 @@ def make_calib_dataloader( | |
| assert tokenizer is not None and isinstance( | ||
| tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) | ||
| ), "The PreTrainedTokenizer must be set" | ||
| # Labels are only needed for gradient-based auto_quantize | ||
| include_labels = ( | ||
| # Labels are only needed for gradient-based auto_quantize (CLI or recipe path). | ||
| is_autoquant_recipe_gradient = ( | ||
| isinstance(recipe, ModelOptAutoQuantizeRecipe) | ||
| and recipe.auto_quantize.method == "gradient" | ||
| ) | ||
| include_labels = is_autoquant_recipe_gradient or ( | ||
| args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" | ||
| ) | ||
|
|
||
|
|
@@ -292,48 +302,32 @@ def auto_quantize( | |
| args: argparse.Namespace, | ||
| language_model: torch.nn.Module, | ||
| calib_dataloader: DataLoader, | ||
| auto_quantize_method="gradient", | ||
| auto_quantize_score_size=128, | ||
| auto_quantize_checkpoint=None, | ||
| full_model: torch.nn.Module | None = None, | ||
| *, | ||
| auto_quantize_method: str, | ||
| auto_quantize_score_size: int, | ||
| auto_quantize_checkpoint: str | None, | ||
| constraints: dict, | ||
| quantization_formats: list[dict], | ||
| disabled_layers: list[str], | ||
| kv_cache_quant_cfg: dict | None, | ||
| ): | ||
| """Auto search quantization of multiple formats.""" | ||
| """Pure orchestrator: build forward_step/loss_func, call mtq.auto_quantize, | ||
| run KV cache post-step. All knobs are explicit keyword-only args; the | ||
| caller (dispatch site in ``quantize_main``) is responsible for resolving | ||
| them from either CLI args or a recipe before invoking this function. | ||
| """ | ||
|
|
||
| if args.calib_with_images: | ||
| raise NotImplementedError( | ||
| "AutoQuantize with image-text calibration is not supported yet. " | ||
| "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images." | ||
| ) | ||
|
|
||
| assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( | ||
| assert args.inference_pipeline_parallel <= 1, ( | ||
| "Auto Quantization is not supported for pipeline parallel size > 1" | ||
| ) | ||
|
|
||
| qformat_list = args.qformat.split(",") | ||
| assert qformat_list, "No quantization formats provided" | ||
| # Check if all provided quantization formats are supported | ||
| assert all( | ||
| qformat | ||
| in [ | ||
| "fp8", | ||
| "int8_sq", | ||
| "int8_wo", | ||
| "int4_awq", | ||
| "nvfp4", | ||
| "nvfp4_awq", | ||
| "nvfp4_mse", | ||
| "w4a8_awq", | ||
| "fp8_pb_wo", | ||
| "w4a8_mxfp4_fp8", | ||
| "nvfp4_mlp_only", | ||
| "nvfp4_experts_only", | ||
| "nvfp4_omlp_only", | ||
| "nvfp4_local_hessian", | ||
| "mxfp8", | ||
| ] | ||
| for qformat in qformat_list | ||
| ), "One or more quantization formats provided are not supported for unified checkpoint export" | ||
|
|
||
| # When language_model is a base text model without lm_head (e.g. Gemma4TextModel), | ||
| # use full_model's lm_head to compute logits/loss from hidden states. | ||
| is_base_model = ( | ||
|
|
@@ -384,49 +378,42 @@ def forward_step(model, batch): | |
|
|
||
| language_model, _ = mtq.auto_quantize( | ||
| language_model, | ||
| constraints={"effective_bits": args.auto_quantize_bits}, | ||
| constraints=constraints, | ||
| data_loader=calib_dataloader, | ||
| forward_step=forward_step, | ||
| loss_func=loss_func, # Only used for gradient-based method | ||
| # TRTLLM only support one quantization format or None (do not quantize, internally supported) | ||
| quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list], | ||
| quantization_formats=quantization_formats, # type: ignore[arg-type] | ||
| num_calib_steps=len(calib_dataloader), | ||
| # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration. | ||
| num_score_steps=min( | ||
| len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1) | ||
| ), | ||
| verbose=True, | ||
| # Disable all default disabled layers such as lm_head, mlp.gate, router etc. | ||
| disabled_layers=[ | ||
| entry["quantizer_name"] | ||
| for entry in _default_disabled_quantizer_cfg | ||
| if "parent_class" not in entry | ||
| ], | ||
| disabled_layers=disabled_layers, | ||
| method=auto_quantize_method, | ||
| checkpoint=auto_quantize_checkpoint, | ||
| ) | ||
|
|
||
| calibrate_loop = create_forward_loop(dataloader=calib_dataloader) | ||
| # We need to explicitly set up KV cache quantization after auto_quantize | ||
| enable_quant_kv_cache = args.kv_cache_qformat != "none" | ||
| print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") | ||
| if enable_quant_kv_cache: | ||
| kv_cache_quant_cfg = copy.deepcopy( | ||
| getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] | ||
| ) | ||
| kv_cache_quant_cfg = [ | ||
| e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*" | ||
| print(f"{'Enable' if kv_cache_quant_cfg is not None else 'Disable'} KV cache quantization") | ||
| if kv_cache_quant_cfg is not None: | ||
| kv_entries = [ | ||
| e for e in copy.deepcopy(kv_cache_quant_cfg["quant_cfg"]) if e["quantizer_name"] != "*" | ||
| ] # keep other quantizers from auto_quantize | ||
|
|
||
| if args.kv_cache_qformat in _KV_CAST_FORMATS: | ||
| _set_kv_cache_constant_amax(kv_cache_quant_cfg) | ||
|
|
||
| mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg) | ||
| if args.kv_cache_qformat not in _KV_CAST_FORMATS: | ||
| mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_entries) | ||
| # Calibrate only when at least one KV entry doesn't pin amax via use_constant_amax. | ||
| # Cast-variant presets (kv_fp8_cast, kv_nvfp4_cast) bake this in; data-driven | ||
| # variants (kv_fp8, kv_nvfp4, etc.) need a calibration pass. | ||
| needs_calibration = not all( | ||
| (e.get("cfg") or {}).get("use_constant_amax") is True for e in kv_entries | ||
| ) | ||
| if needs_calibration: | ||
| # Calibrate only the KV cache quantizers; disable all others. | ||
| with mtq.set_quantizer_by_cfg_context( | ||
| language_model, | ||
| [{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg], | ||
| [{"quantizer_name": "*", "enable": False}, *kv_entries], | ||
| ): | ||
| mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) | ||
| return language_model | ||
|
|
@@ -987,12 +974,20 @@ def quantize_main( | |
| ): | ||
| # Load the recipe up front so we can detect layerwise calibration before batch-size probing. | ||
| recipe = None | ||
| if args.recipe is not None and not args.auto_quantize_bits: | ||
| if args.recipe is not None: | ||
| print(f"Use recipe {args.recipe} for quantization") | ||
| recipe = load_recipe(args.recipe) | ||
| if not isinstance(recipe, ModelOptPTQRecipe): | ||
| if not isinstance(recipe, (ModelOptPTQRecipe, ModelOptAutoQuantizeRecipe)): | ||
| raise TypeError( | ||
| f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" | ||
| f"Expected PTQ or AutoQuantize recipe, but got {type(recipe).__name__} " | ||
| f"from {args.recipe}" | ||
| ) | ||
| # Fail-fast on conflicting budget sources: a recipe carries its own | ||
| # effective_bits, so silently honoring one over the other would be a | ||
| # reproducibility hazard. | ||
| if args.auto_quantize_bits is not None: | ||
| raise ValueError( | ||
| "Cannot combine --auto_quantize_bits with --recipe; the recipe owns the budget." | ||
| ) | ||
|
|
||
| def _is_layerwise(obj): | ||
|
|
@@ -1043,7 +1038,9 @@ def _is_layerwise(obj): | |
| else: | ||
| sample_input_single_batch = None | ||
|
|
||
| run_auto_quant = args.auto_quantize_bits is not None | ||
| run_auto_quant = args.auto_quantize_bits is not None or isinstance( | ||
| recipe, ModelOptAutoQuantizeRecipe | ||
| ) | ||
|
|
||
| args.batch_size = get_max_batch_size( | ||
| language_model, | ||
|
|
@@ -1057,7 +1054,7 @@ def _is_layerwise(obj): | |
| print(f"Use calib batch_size {args.batch_size}") | ||
|
|
||
| calib_dataloader, first_text_speech_dataset = make_calib_dataloader( | ||
| args, language_model, processor, tokenizer, device, model_type | ||
| args, language_model, processor, tokenizer, device, model_type, recipe=recipe | ||
| ) | ||
|
|
||
| # Detect if this is a Nemotron VL model using architecture-based detection | ||
|
|
@@ -1067,20 +1064,104 @@ def _is_layerwise(obj): | |
| args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model | ||
| ) | ||
|
|
||
| if args.auto_quantize_bits: | ||
| assert len(args.qformat.split(",")) > 1, ( | ||
| "Auto quantization needs multiple quantization format." | ||
| ) | ||
| # All auto_quantize() knobs are resolved here before calling the helper. | ||
| # Helper is a leaf orchestrator — it does not know whether inputs came from | ||
| # CLI args or a recipe. | ||
| if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, I can remove the auto_quantize specific args that are part of the proposed recipe now, this will help us get rid of the branch when calling the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure sounds good; I am okay with this plan. We should update the README and show the correct auto_quantize usage if we decide to remove the CLI args. |
||
| default_disabled_layers = [ | ||
| entry["quantizer_name"] | ||
| for entry in _default_disabled_quantizer_cfg | ||
| if "parent_class" not in entry | ||
| ] | ||
|
|
||
| auto_quantize( | ||
| args, | ||
| language_model, | ||
| calib_dataloader, | ||
| auto_quantize_method=args.auto_quantize_method, | ||
| auto_quantize_score_size=args.auto_quantize_score_size, | ||
| auto_quantize_checkpoint=args.auto_quantize_checkpoint, | ||
| full_model=full_model, | ||
| ) | ||
| # Resolve --kv_cache_qformat to a full QuantizeConfig dict (or None). Used as the | ||
| # CLI fallback when a recipe is silent on KV cache, and as the sole source for the | ||
| # CLI autoquant branch. Cast variants get use_constant_amax injected at this layer | ||
| # so the helper can stay format-agnostic (it just checks use_constant_amax to | ||
| # decide whether to calibrate). | ||
| def _cli_kv_cache_quant_cfg(): | ||
| if args.kv_cache_qformat == "none": | ||
| return None | ||
| cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])) | ||
| if args.kv_cache_qformat in _KV_CAST_FORMATS: | ||
| _set_kv_cache_constant_amax(cfg["quant_cfg"]) | ||
| return cfg | ||
|
|
||
| if isinstance(recipe, ModelOptAutoQuantizeRecipe): | ||
| aq = recipe.auto_quantize | ||
|
|
||
| # mtq.auto_quantize labels candidates by upstream identity: dicts that ARE | ||
| # an mtq.X_CFG object get the constant's name in logs (e.g. NVFP4_DEFAULT_CFG); | ||
| # all other dicts get "CUSTOM_N" plus a "results may not be optimal" warning. | ||
| # Recipe candidates come from .model_dump() — equal by value but not identity, | ||
| # so we'd lose the friendly names. Substitute the canonical object back when | ||
| # the dump matches a known preset, so logs and the warning line up with CLI. | ||
| # The match check uses exclude_unset=True so it compares against the | ||
| # preset YAML's natural shape (mtq.X_CFG dicts don't carry Pydantic-filled | ||
| # defaults). The payload still passes the full dump to upstream. | ||
|
Comment on lines
+1093
to
+1101
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. verbose comment. Can we simplify and make it concise? |
||
| def _candidate_for_mtq(fmt): | ||
| strict = fmt.model_dump(exclude_unset=True) | ||
| for cfg in QUANT_CFG_CHOICES.values(): | ||
| if cfg == strict: | ||
| return cfg | ||
| return fmt.model_dump() | ||
|
|
||
| auto_quantize( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we pass aq to |
||
| args, | ||
| language_model, | ||
| calib_dataloader, | ||
| full_model=full_model, | ||
| auto_quantize_method=aq.method, | ||
| auto_quantize_score_size=aq.num_score_steps, | ||
| auto_quantize_checkpoint=args.auto_quantize_checkpoint, | ||
| constraints=aq.constraints.model_dump(exclude_none=True), | ||
| quantization_formats=[_candidate_for_mtq(fmt) for fmt in aq.candidate_formats], | ||
| disabled_layers=aq.disabled_layers or default_disabled_layers, | ||
| kv_cache_quant_cfg=( | ||
| aq.kv_cache.model_dump() | ||
| if aq.kv_cache is not None | ||
| else _cli_kv_cache_quant_cfg() | ||
| ), | ||
| ) | ||
| else: | ||
| qformat_list = args.qformat.split(",") | ||
| assert len(qformat_list) > 1, "Auto quantization needs multiple quantization format." | ||
| assert all( | ||
| qformat | ||
| in [ | ||
| "fp8", | ||
| "int8_sq", | ||
| "int8_wo", | ||
| "int4_awq", | ||
| "nvfp4", | ||
| "nvfp4_awq", | ||
| "nvfp4_mse", | ||
| "w4a8_awq", | ||
| "fp8_pb_wo", | ||
| "w4a8_mxfp4_fp8", | ||
| "nvfp4_mlp_only", | ||
| "nvfp4_experts_only", | ||
| "nvfp4_omlp_only", | ||
| "nvfp4_local_hessian", | ||
| "mxfp8", | ||
| ] | ||
| for qformat in qformat_list | ||
| ), ( | ||
| "One or more quantization formats provided are not supported for unified checkpoint export" | ||
| ) | ||
| auto_quantize( | ||
| args, | ||
| language_model, | ||
| calib_dataloader, | ||
| full_model=full_model, | ||
| auto_quantize_method=args.auto_quantize_method, | ||
| auto_quantize_score_size=args.auto_quantize_score_size, | ||
| auto_quantize_checkpoint=args.auto_quantize_checkpoint, | ||
| constraints={"effective_bits": args.auto_quantize_bits}, | ||
| quantization_formats=[QUANT_CFG_CHOICES[fmt] for fmt in qformat_list], | ||
| disabled_layers=default_disabled_layers, | ||
| kv_cache_quant_cfg=_cli_kv_cache_quant_cfg(), | ||
| ) | ||
|
|
||
| else: | ||
| # mono quantization | ||
|
|
@@ -1198,9 +1279,11 @@ def parse_args() -> argparse.Namespace: | |
| parser.add_argument( | ||
| "--recipe", | ||
| help=( | ||
| "PTQ recipe YAML file or name without suffix (e.g. general/ptq/fp8_default-kv_fp8_cast, " | ||
| "general/ptq/nvfp4_default-kv_fp8_cast, general/ptq/nvfp4_default-kv_nvfp4_cast). " | ||
| "When set, --kv_cache_qformat is ignored; the recipe fully determines KV cache config." | ||
| "PTQ or AutoQuantize recipe YAML file or name without suffix " | ||
| "(e.g. general/ptq/nvfp4_default-kv_fp8_cast, " | ||
| "general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast). " | ||
| "PTQ recipes fully own quant config; AutoQuantize recipes own search config " | ||
| "and may optionally override --kv_cache_qformat via their kv_cache field." | ||
| ), | ||
| default=None, | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This seems a verbose comment to me (not aligned with https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md#-coding-standards(