From d913fcf22045192855e0e43afb93c0deaf731b65 Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Wed, 20 May 2026 23:33:48 +0000 Subject: [PATCH 1/4] wip: autoquant recipe schema + hf_ptq dispatch Signed-off-by: Juhi Mittal --- examples/llm_ptq/hf_ptq.py | 205 ++++++++++++------ modelopt/recipe/config.py | 106 ++++++++- .../nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml | 42 ++++ tests/unit/recipe/test_loader.py | 91 ++++++++ 4 files changed, 376 insertions(+), 68 deletions(-) create mode 100644 modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 6d27aa593f6..36d4ee943a5 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -55,7 +55,12 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq import modelopt.torch.sparsity as mts -from modelopt.recipe import ModelOptPTQRecipe, load_recipe +from modelopt.recipe import ( + ModelOptAutoQuantizeRecipe, + ModelOptPTQRecipe, + ModelOptRecipeBase, + load_recipe, +) from modelopt.torch.export import ( export_hf_checkpoint, export_hf_vllm_fq_checkpoint, @@ -208,6 +213,7 @@ def make_calib_dataloader( tokenizer: PreTrainedTokenizerBase | None, device: torch.device, model_type: str | None, + recipe: ModelOptRecipeBase | None = None, ) -> tuple[DataLoader | _DeviceDataLoader, str | None]: calib_dataloader = None first_text_speech_dataset = None @@ -271,8 +277,12 @@ def make_calib_dataloader( assert tokenizer is not None and isinstance( tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) ), "The PreTrainedTokenizer must be set" - # Labels are only needed for gradient-based auto_quantize - include_labels = ( + # Labels are only needed for gradient-based auto_quantize (CLI or recipe path). 
+ is_autoquant_recipe_gradient = ( + isinstance(recipe, ModelOptAutoQuantizeRecipe) + and recipe.auto_quantize.method == "gradient" + ) + include_labels = is_autoquant_recipe_gradient or ( args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" ) @@ -292,12 +302,21 @@ def auto_quantize( args: argparse.Namespace, language_model: torch.nn.Module, calib_dataloader: DataLoader, - auto_quantize_method="gradient", - auto_quantize_score_size=128, - auto_quantize_checkpoint=None, full_model: torch.nn.Module | None = None, + *, + auto_quantize_method: str, + auto_quantize_score_size: int, + auto_quantize_checkpoint: str | None, + constraints: dict, + quantization_formats: list[dict], + disabled_layers: list[str], + kv_cache_qformat: str, ): - """Auto search quantization of multiple formats.""" + """Pure orchestrator: build forward_step/loss_func, call mtq.auto_quantize, + run KV cache post-step. All knobs are explicit keyword-only args; the + caller (dispatch site in ``quantize_main``) is responsible for resolving + them from either CLI args or a recipe before invoking this function. + """ if args.calib_with_images: raise NotImplementedError( @@ -305,35 +324,10 @@ def auto_quantize( "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images." 
) - assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( + assert args.inference_pipeline_parallel <= 1, ( "Auto Quantization is not supported for pipeline parallel size > 1" ) - qformat_list = args.qformat.split(",") - assert qformat_list, "No quantization formats provided" - # Check if all provided quantization formats are supported - assert all( - qformat - in [ - "fp8", - "int8_sq", - "int8_wo", - "int4_awq", - "nvfp4", - "nvfp4_awq", - "nvfp4_mse", - "w4a8_awq", - "fp8_pb_wo", - "w4a8_mxfp4_fp8", - "nvfp4_mlp_only", - "nvfp4_experts_only", - "nvfp4_omlp_only", - "nvfp4_local_hessian", - "mxfp8", - ] - for qformat in qformat_list - ), "One or more quantization formats provided are not supported for unified checkpoint export" - # When language_model is a base text model without lm_head (e.g. Gemma4TextModel), # use full_model's lm_head to compute logits/loss from hidden states. is_base_model = ( @@ -384,45 +378,39 @@ def forward_step(model, batch): language_model, _ = mtq.auto_quantize( language_model, - constraints={"effective_bits": args.auto_quantize_bits}, + constraints=constraints, data_loader=calib_dataloader, forward_step=forward_step, loss_func=loss_func, # Only used for gradient-based method # TRTLLM only support one quantization format or None (do not quantize, internally supported) - quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list], + quantization_formats=quantization_formats, # type: ignore[arg-type] num_calib_steps=len(calib_dataloader), # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration. num_score_steps=min( len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1) ), verbose=True, - # Disable all default disabled layers such as lm_head, mlp.gate, router etc. 
- disabled_layers=[ - entry["quantizer_name"] - for entry in _default_disabled_quantizer_cfg - if "parent_class" not in entry - ], + disabled_layers=disabled_layers, method=auto_quantize_method, checkpoint=auto_quantize_checkpoint, ) calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - # We need to explicitly set up KV cache quantization after auto_quantize - enable_quant_kv_cache = args.kv_cache_qformat != "none" + enable_quant_kv_cache = kv_cache_qformat != "none" print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") if enable_quant_kv_cache: kv_cache_quant_cfg = copy.deepcopy( - getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] + getattr(mtq, KV_QUANT_CFG_CHOICES[kv_cache_qformat])["quant_cfg"] ) kv_cache_quant_cfg = [ e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*" ] # keep other quantizers from auto_quantize - if args.kv_cache_qformat in _KV_CAST_FORMATS: + if kv_cache_qformat in _KV_CAST_FORMATS: _set_kv_cache_constant_amax(kv_cache_quant_cfg) mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg) - if args.kv_cache_qformat not in _KV_CAST_FORMATS: + if kv_cache_qformat not in _KV_CAST_FORMATS: # Calibrate only the KV cache quantizers; disable all others. with mtq.set_quantizer_by_cfg_context( language_model, @@ -987,12 +975,20 @@ def quantize_main( ): # Load the recipe up front so we can detect layerwise calibration before batch-size probing. 
recipe = None - if args.recipe is not None and not args.auto_quantize_bits: + if args.recipe is not None: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) - if not isinstance(recipe, ModelOptPTQRecipe): + if not isinstance(recipe, (ModelOptPTQRecipe, ModelOptAutoQuantizeRecipe)): raise TypeError( - f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" + f"Expected PTQ or AutoQuantize recipe, but got {type(recipe).__name__} " + f"from {args.recipe}" + ) + # Fail-fast on conflicting budget sources: a recipe carries its own + # effective_bits, so silently honoring one over the other would be a + # reproducibility hazard. + if args.auto_quantize_bits is not None: + raise ValueError( + "Cannot combine --auto_quantize_bits with --recipe; the recipe owns the budget." ) def _is_layerwise(obj): @@ -1043,7 +1039,9 @@ def _is_layerwise(obj): else: sample_input_single_batch = None - run_auto_quant = args.auto_quantize_bits is not None + run_auto_quant = args.auto_quantize_bits is not None or isinstance( + recipe, ModelOptAutoQuantizeRecipe + ) args.batch_size = get_max_batch_size( language_model, @@ -1057,7 +1055,7 @@ def _is_layerwise(obj): print(f"Use calib batch_size {args.batch_size}") calib_dataloader, first_text_speech_dataset = make_calib_dataloader( - args, language_model, processor, tokenizer, device, model_type + args, language_model, processor, tokenizer, device, model_type, recipe=recipe ) # Detect if this is a Nemotron VL model using architecture-based detection @@ -1067,20 +1065,91 @@ def _is_layerwise(obj): args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) - if args.auto_quantize_bits: - assert len(args.qformat.split(",")) > 1, ( - "Auto quantization needs multiple quantization format." - ) + # All auto_quantize() knobs are resolved here before calling the helper. + # Helper is a leaf orchestrator — it does not know whether inputs came from + # CLI args or a recipe. 
+ if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits: + default_disabled_layers = [ + entry["quantizer_name"] + for entry in _default_disabled_quantizer_cfg + if "parent_class" not in entry + ] - auto_quantize( - args, - language_model, - calib_dataloader, - auto_quantize_method=args.auto_quantize_method, - auto_quantize_score_size=args.auto_quantize_score_size, - auto_quantize_checkpoint=args.auto_quantize_checkpoint, - full_model=full_model, - ) + if isinstance(recipe, ModelOptAutoQuantizeRecipe): + aq = recipe.auto_quantize + + # mtq.auto_quantize labels candidates by upstream identity: dicts that ARE + # an mtq.X_CFG object get the constant's name in logs (e.g. NVFP4_DEFAULT_CFG); + # all other dicts get "CUSTOM_N" plus a "results may not be optimal" warning. + # Recipe candidates come from .model_dump() — equal by value but not identity, + # so we'd lose the friendly names. Substitute the canonical object back when + # the dump matches a known preset, so logs and the warning line up with CLI. + # The match check uses exclude_unset=True so it compares against the + # preset YAML's natural shape (mtq.X_CFG dicts don't carry Pydantic-filled + # defaults). The payload still passes the full dump to upstream. 
+ def _candidate_for_mtq(fmt): + strict = fmt.model_dump(exclude_unset=True) + for cfg in QUANT_CFG_CHOICES.values(): + if cfg == strict: + return cfg + return fmt.model_dump() + + auto_quantize( + args, + language_model, + calib_dataloader, + full_model=full_model, + auto_quantize_method=aq.method, + auto_quantize_score_size=aq.num_score_steps, + auto_quantize_checkpoint=aq.score_checkpoint, + constraints=aq.constraints.model_dump(exclude_none=True), + quantization_formats=[_candidate_for_mtq(fmt) for fmt in aq.candidate_formats], + disabled_layers=aq.disabled_layers or default_disabled_layers, + kv_cache_qformat=( + aq.kv_cache.qformat + if (aq.kv_cache and aq.kv_cache.qformat) + else args.kv_cache_qformat + ), + ) + else: + qformat_list = args.qformat.split(",") + assert len(qformat_list) > 1, "Auto quantization needs multiple quantization format." + assert all( + qformat + in [ + "fp8", + "int8_sq", + "int8_wo", + "int4_awq", + "nvfp4", + "nvfp4_awq", + "nvfp4_mse", + "w4a8_awq", + "fp8_pb_wo", + "w4a8_mxfp4_fp8", + "nvfp4_mlp_only", + "nvfp4_experts_only", + "nvfp4_omlp_only", + "nvfp4_local_hessian", + "mxfp8", + ] + for qformat in qformat_list + ), ( + "One or more quantization formats provided are not supported for unified checkpoint export" + ) + auto_quantize( + args, + language_model, + calib_dataloader, + full_model=full_model, + auto_quantize_method=args.auto_quantize_method, + auto_quantize_score_size=args.auto_quantize_score_size, + auto_quantize_checkpoint=args.auto_quantize_checkpoint, + constraints={"effective_bits": args.auto_quantize_bits}, + quantization_formats=[QUANT_CFG_CHOICES[fmt] for fmt in qformat_list], + disabled_layers=default_disabled_layers, + kv_cache_qformat=args.kv_cache_qformat, + ) else: # mono quantization @@ -1198,9 +1267,11 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--recipe", help=( - "PTQ recipe YAML file or name without suffix (e.g. 
general/ptq/fp8_default-kv_fp8_cast, " - "general/ptq/nvfp4_default-kv_fp8_cast, general/ptq/nvfp4_default-kv_nvfp4_cast). " - "When set, --kv_cache_qformat is ignored; the recipe fully determines KV cache config." + "PTQ or AutoQuantize recipe YAML file or name without suffix " + "(e.g. general/ptq/nvfp4_default-kv_fp8_cast, " + "general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast). " + "PTQ recipes fully own quant config; AutoQuantize recipes own search config " + "and may optionally override --kv_cache_qformat via their kv_cache field." ), default=None, ) diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 749d80a933d..718cd7b758d 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -19,8 +19,9 @@ import warnings from enum import Enum +from typing import Literal -from pydantic import Field, model_validator +from pydantic import Field, field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.quantization.config import QuantizeConfig # noqa: TC001 @@ -36,6 +37,7 @@ class RecipeType(str, Enum): """List of recipe types. See ``RECIPE_TYPE_TO_CLASS`` at the bottom for the schema mapping.""" PTQ = "ptq" + AUTO_QUANTIZE = "auto_quantize" SPECULATIVE_EAGLE = "speculative_eagle" SPECULATIVE_DFLASH = "speculative_dflash" SPECULATIVE_MEDUSA = "speculative_medusa" @@ -104,6 +106,107 @@ class ModelOptPTQRecipe(ModelOptRecipeBase): ) +class AutoQuantizeKVCache(ModeloptBaseConfig): + """KV-cache configuration for an AutoQuantize recipe (optional).""" + + qformat: str | None = ModeloptField( + default=None, + title="KV cache quantization format", + description="One of the entries in KV_QUANT_CFG_CHOICES, or 'none' to disable. " + "If omitted, the runtime --kv_cache_qformat CLI flag is used.", + ) + + +class AutoQuantizeConstraints(ModeloptBaseConfig): + """Constraints passed to ``mtq.auto_quantize`` (matches its dict shape). 
+ + Today only ``effective_bits`` is supported upstream. When new constraint + keys land (e.g., ``cost_model`` / ``cost`` from PR #1497), add them as + fields here so ``.model_dump(exclude_none=True)`` produces the dict + upstream expects. + """ + + effective_bits: float = ModeloptField( + default=4.8, + title="Effective bits per weight", + description="Average weight-storage bits target for the LP, in (0, 16].", + ) + + @field_validator("effective_bits") + @classmethod + def _validate_effective_bits(cls, v: float) -> float: + if not (0 < v <= 16): + raise ValueError(f"effective_bits must be in (0, 16], got {v}") + return v + + +class AutoQuantizeConfig(ModeloptBaseConfig): + """Schema for the ``auto_quantize`` block in an AutoQuantize recipe.""" + + constraints: AutoQuantizeConstraints = Field( + title="Search constraints + cost model", + description="LP budget and cost model.", + ) + + candidate_formats: list[QuantizeConfig] = ModeloptField( + default=[], + title="Candidate quantization formats", + description="Per-layer search space; each entry is a full QuantizeConfig. 
" + "At least 2 entries required.", + ) + + method: Literal["gradient", "kl_div"] = ModeloptField( + default="gradient", + title="Sensitivity scoring method", + description="'gradient' (Taylor + Fisher, needs labels) or 'kl_div' (no labels).", + ) + + num_score_steps: int = ModeloptField( + default=128, + title="Phase-3 scoring sample count", + description="Number of batches for sensitivity scoring.", + ) + + disabled_layers: list[str] = ModeloptField( + default=[], + title="Excluded layer patterns", + description="Glob patterns; matching layers are excluded from the search.", + ) + + score_checkpoint: str | None = ModeloptField( + default=None, + title="Search-state checkpoint path", + description="Path to save/restore search state for resume or cheap re-solve.", + ) + + kv_cache: AutoQuantizeKVCache | None = ModeloptField( + default=None, + title="KV cache override", + description="Optional KV cache config. If omitted, --kv_cache_qformat CLI flag is used.", + ) + + @field_validator("candidate_formats") + @classmethod + def _at_least_two_candidates(cls, v: list[QuantizeConfig]) -> list[QuantizeConfig]: + if len(v) < 2: + raise ValueError( + "auto_quantize requires at least 2 candidate_formats. " + "For uniform quantization, use a PTQ recipe instead." + ) + return v + + +class ModelOptAutoQuantizeRecipe(ModelOptRecipeBase): + """Our config class for AutoQuantize recipes.""" + + metadata: RecipeMetadataConfig = _metadata_field(RecipeType.AUTO_QUANTIZE) + + auto_quantize: AutoQuantizeConfig = Field( + title="AutoQuantize config", + description="AutoQuantize search configuration. Required.", + ) + + class ModelOptSpeculativeRecipeBase(ModelOptRecipeBase): """Base class for speculative-decoding recipes. @@ -199,6 +302,7 @@ class ModelOptMedusaRecipe(ModelOptSpeculativeRecipeBase): # uses this for typed-list ``$import`` resolution; add a new entry when introducing a recipe. 
RECIPE_TYPE_TO_CLASS: dict[RecipeType, type[ModelOptRecipeBase]] = { RecipeType.PTQ: ModelOptPTQRecipe, + RecipeType.AUTO_QUANTIZE: ModelOptAutoQuantizeRecipe, RecipeType.SPECULATIVE_EAGLE: ModelOptEagleRecipe, RecipeType.SPECULATIVE_DFLASH: ModelOptDFlashRecipe, RecipeType.SPECULATIVE_MEDUSA: ModelOptMedusaRecipe, diff --git a/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml b/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml new file mode 100644 index 00000000000..0fc59f68377 --- /dev/null +++ b/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# AutoQuantize recipe: mixed NVFP4 + FP8 per-layer search at 4.8 effective bits, +# FP8 KV cache (cast mode). Gradient-based sensitivity scoring; weight cost model. + +imports: + nvfp4: configs/ptq/presets/model/nvfp4 + fp8: configs/ptq/presets/model/fp8 + +metadata: + recipe_type: auto_quantize + description: Mixed NVFP4 + FP8 at 4.8 effective bits with FP8 KV cache (cast). 
+ +auto_quantize: + constraints: + effective_bits: 4.8 + + candidate_formats: + - $import: nvfp4 + - $import: fp8 + + kv_cache: + qformat: fp8_cast + + method: gradient + num_score_steps: 128 + + disabled_layers: + - "*lm_head*" diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 4c4e2d07ded..5fa9a825c99 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -21,6 +21,7 @@ import pytest from modelopt.recipe.config import ( + ModelOptAutoQuantizeRecipe, ModelOptDFlashRecipe, ModelOptEagleRecipe, ModelOptPTQRecipe, @@ -243,6 +244,96 @@ def test_load_recipe_dir_missing_quantize_raises(tmp_path): load_recipe(tmp_path) +# --------------------------------------------------------------------------- +# load_recipe — AutoQuantize recipes +# --------------------------------------------------------------------------- + + +_AQ_MINIMAL_BODY = ( + "metadata:\n" + " recipe_type: auto_quantize\n" + "auto_quantize:\n" + " constraints:\n" + " effective_bits: 4.8\n" + " candidate_formats:\n" + " - algorithm: max\n" + " quant_cfg: []\n" + " - algorithm: max\n" + " quant_cfg: []\n" +) + + +def test_load_recipe_autoquantize_builtin(): + """load_recipe loads the built-in AutoQuantize recipe.""" + recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") + assert recipe.recipe_type == RecipeType.AUTO_QUANTIZE + assert isinstance(recipe, ModelOptAutoQuantizeRecipe) + aq = recipe.auto_quantize + assert aq.constraints.effective_bits == 4.8 + assert len(aq.candidate_formats) == 2 + assert aq.kv_cache is not None and aq.kv_cache.qformat == "fp8_cast" + + +def test_load_recipe_autoquantize_defaults(): + """Optional AutoQuantize fields use Pydantic defaults when omitted.""" + recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") + aq = recipe.auto_quantize + assert aq.method == "gradient" + assert aq.num_score_steps == 128 + assert aq.score_checkpoint is None + + +def 
test_load_recipe_autoquantize_candidates_match_presets(): + """Built-in AutoQuantize recipe's $imported candidates equal mtq.X_DEFAULT_CFG dicts.""" + import modelopt.torch.quantization as mtq + + recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") + candidates = recipe.auto_quantize.candidate_formats + assert candidates[0].model_dump(exclude_unset=True) == mtq.NVFP4_DEFAULT_CFG + assert candidates[1].model_dump(exclude_unset=True) == mtq.FP8_DEFAULT_CFG + + +def test_load_recipe_autoquantize_missing_section_raises(tmp_path): + """An AutoQuantize recipe missing the ``auto_quantize`` section is rejected.""" + bad = tmp_path / "bad.yml" + bad.write_text("metadata:\n recipe_type: auto_quantize\n") + with pytest.raises(ValueError, match="auto_quantize"): + load_recipe(bad) + + +def test_load_recipe_autoquantize_too_few_candidates_raises(tmp_path): + """candidate_formats with fewer than 2 entries is rejected.""" + bad = tmp_path / "bad.yml" + bad.write_text( + "metadata:\n" + " recipe_type: auto_quantize\n" + "auto_quantize:\n" + " constraints:\n" + " effective_bits: 4.8\n" + " candidate_formats:\n" + " - algorithm: max\n" + " quant_cfg: []\n" + ) + with pytest.raises(ValueError, match="at least 2"): + load_recipe(bad) + + +def test_load_recipe_autoquantize_effective_bits_out_of_range_raises(tmp_path): + """effective_bits outside (0, 16] is rejected.""" + bad = tmp_path / "bad.yml" + bad.write_text(_AQ_MINIMAL_BODY.replace("effective_bits: 4.8", "effective_bits: 20")) + with pytest.raises(ValueError, match="effective_bits"): + load_recipe(bad) + + +def test_load_recipe_autoquantize_kv_cache_optional(tmp_path): + """kv_cache is optional; recipes without it parse fine and aq.kv_cache is None.""" + recipe_file = tmp_path / "aq.yml" + recipe_file.write_text(_AQ_MINIMAL_BODY) + recipe = load_recipe(recipe_file) + assert recipe.auto_quantize.kv_cache is None + + # --------------------------------------------------------------------------- # 
load_recipe — EAGLE speculative decoding # --------------------------------------------------------------------------- From fcee6511dfdc8c7b6991e03d6e1f97a3ab8297fe Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Thu, 21 May 2026 19:16:28 +0000 Subject: [PATCH 2/4] address review comments Signed-off-by: Juhi Mittal --- examples/llm_ptq/hf_ptq.py | 2 +- modelopt/recipe/config.py | 30 +++++++++++++++++++++++++++++- tests/unit/recipe/test_loader.py | 11 +++++++++-- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 36d4ee943a5..9bfdf722fe0 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -1068,7 +1068,7 @@ def _is_layerwise(obj): # All auto_quantize() knobs are resolved here before calling the helper. # Helper is a leaf orchestrator — it does not know whether inputs came from # CLI args or a recipe. - if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits: + if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits is not None: default_disabled_layers = [ entry["quantizer_name"] for entry in _default_disabled_quantizer_cfg diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 718cd7b758d..1777b0f33eb 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -19,7 +19,7 @@ import warnings from enum import Enum -from typing import Literal +from typing import ClassVar, Literal from pydantic import Field, field_validator, model_validator @@ -109,6 +109,24 @@ class ModelOptPTQRecipe(ModelOptRecipeBase): class AutoQuantizeKVCache(ModeloptBaseConfig): """KV-cache configuration for an AutoQuantize recipe (optional).""" + # Mirrors the keys of KV_QUANT_CFG_CHOICES in examples/llm_ptq/hf_ptq.py. + # Kept inline (rather than imported) so the recipe schema stays free of + # example-script dependencies. Update both sides if new KV variants land. 
+ # ClassVar annotation tells Pydantic this is a class-level constant, not a + # private model attribute (which is the default for leading-underscore names). + _SUPPORTED_QFORMATS: ClassVar[frozenset[str]] = frozenset( + { + "none", + "fp8_cast", + "fp8", + "fp8_affine", + "nvfp4_cast", + "nvfp4", + "nvfp4_affine", + "nvfp4_rotate", + } + ) + qformat: str | None = ModeloptField( default=None, title="KV cache quantization format", @@ -116,6 +134,16 @@ class AutoQuantizeKVCache(ModeloptBaseConfig): "If omitted, the runtime --kv_cache_qformat CLI flag is used.", ) + @field_validator("qformat") + @classmethod + def _validate_qformat(cls, v: str | None) -> str | None: + if v is not None and v not in cls._SUPPORTED_QFORMATS: + raise ValueError( + f"Unsupported kv_cache.qformat: {v!r}. " + f"Expected one of {sorted(cls._SUPPORTED_QFORMATS)} or None." + ) + return v + class AutoQuantizeConstraints(ModeloptBaseConfig): """Constraints passed to ``mtq.auto_quantize`` (matches its dict shape). diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 5fa9a825c99..d453f233899 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -20,6 +20,7 @@ import pytest +import modelopt.torch.quantization as mtq from modelopt.recipe.config import ( ModelOptAutoQuantizeRecipe, ModelOptDFlashRecipe, @@ -285,8 +286,6 @@ def test_load_recipe_autoquantize_defaults(): def test_load_recipe_autoquantize_candidates_match_presets(): """Built-in AutoQuantize recipe's $imported candidates equal mtq.X_DEFAULT_CFG dicts.""" - import modelopt.torch.quantization as mtq - recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") candidates = recipe.auto_quantize.candidate_formats assert candidates[0].model_dump(exclude_unset=True) == mtq.NVFP4_DEFAULT_CFG @@ -334,6 +333,14 @@ def test_load_recipe_autoquantize_kv_cache_optional(tmp_path): assert recipe.auto_quantize.kv_cache is None +def 
test_load_recipe_autoquantize_invalid_kv_qformat_raises(tmp_path): + """An unknown kv_cache.qformat is rejected at recipe-load time, not later.""" + bad = tmp_path / "bad.yml" + bad.write_text(_AQ_MINIMAL_BODY + " kv_cache:\n qformat: not_a_real_format\n") + with pytest.raises(ValueError, match="kv_cache.qformat"): + load_recipe(bad) + + # --------------------------------------------------------------------------- # load_recipe — EAGLE speculative decoding # --------------------------------------------------------------------------- From e15dc62f083209138a3ef2837d18735b36098a8d Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Thu, 21 May 2026 22:59:21 +0000 Subject: [PATCH 3/4] address review comments: remove score_checkpoint from the AutoQuantize recipe schema, change the kv_cache Pydantic type from a qformat str to QuantizeConfig and update the hf_ptq.py dispatch to match, add AUTO_QUANTIZE to _REQUIRED_SECTION_PER_RECIPE_TYPE, and fix the recipe-kind label in the loader's missing-section error Signed-off-by: Juhi Mittal --- examples/llm_ptq/hf_ptq.py | 54 +++++++++++------- modelopt/recipe/config.py | 56 +++---------------- modelopt/recipe/loader.py | 7 ++- .../nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml | 4 +- tests/unit/recipe/test_loader.py | 21 +++---- 5 files changed, 58 insertions(+), 84 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 9bfdf722fe0..1e0b243c3f3 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -310,7 +310,7 @@ def auto_quantize( constraints: dict, quantization_formats: list[dict], disabled_layers: list[str], - kv_cache_qformat: str, + kv_cache_quant_cfg: dict | None, ): """Pure orchestrator: build forward_step/loss_func, call mtq.auto_quantize, run KV cache post-step.
All knobs are explicit keyword-only args; the @@ -396,25 +396,24 @@ def forward_step(model, batch): ) calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - enable_quant_kv_cache = kv_cache_qformat != "none" - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - if enable_quant_kv_cache: - kv_cache_quant_cfg = copy.deepcopy( - getattr(mtq, KV_QUANT_CFG_CHOICES[kv_cache_qformat])["quant_cfg"] - ) - kv_cache_quant_cfg = [ - e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*" + print(f"{'Enable' if kv_cache_quant_cfg is not None else 'Disable'} KV cache quantization") + if kv_cache_quant_cfg is not None: + kv_entries = [ + e for e in copy.deepcopy(kv_cache_quant_cfg["quant_cfg"]) if e["quantizer_name"] != "*" ] # keep other quantizers from auto_quantize - if kv_cache_qformat in _KV_CAST_FORMATS: - _set_kv_cache_constant_amax(kv_cache_quant_cfg) - - mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg) - if kv_cache_qformat not in _KV_CAST_FORMATS: + mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_entries) + # Calibrate only when at least one KV entry doesn't pin amax via use_constant_amax. + # Cast-variant presets (kv_fp8_cast, kv_nvfp4_cast) bake this in; data-driven + # variants (kv_fp8, kv_nvfp4, etc.) need a calibration pass. + needs_calibration = not all( + (e.get("cfg") or {}).get("use_constant_amax") is True for e in kv_entries + ) + if needs_calibration: # Calibrate only the KV cache quantizers; disable all others. with mtq.set_quantizer_by_cfg_context( language_model, - [{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg], + [{"quantizer_name": "*", "enable": False}, *kv_entries], ): mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) return language_model @@ -1075,6 +1074,19 @@ def _is_layerwise(obj): if "parent_class" not in entry ] + # Resolve --kv_cache_qformat to a full QuantizeConfig dict (or None). 
Used as the + # CLI fallback when a recipe is silent on KV cache, and as the sole source for the + # CLI autoquant branch. Cast variants get use_constant_amax injected at this layer + # so the helper can stay format-agnostic (it just checks use_constant_amax to + # decide whether to calibrate). + def _cli_kv_cache_quant_cfg(): + if args.kv_cache_qformat == "none": + return None + cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])) + if args.kv_cache_qformat in _KV_CAST_FORMATS: + _set_kv_cache_constant_amax(cfg["quant_cfg"]) + return cfg + if isinstance(recipe, ModelOptAutoQuantizeRecipe): aq = recipe.auto_quantize @@ -1101,14 +1113,14 @@ def _candidate_for_mtq(fmt): full_model=full_model, auto_quantize_method=aq.method, auto_quantize_score_size=aq.num_score_steps, - auto_quantize_checkpoint=aq.score_checkpoint, + auto_quantize_checkpoint=args.auto_quantize_checkpoint, constraints=aq.constraints.model_dump(exclude_none=True), quantization_formats=[_candidate_for_mtq(fmt) for fmt in aq.candidate_formats], disabled_layers=aq.disabled_layers or default_disabled_layers, - kv_cache_qformat=( - aq.kv_cache.qformat - if (aq.kv_cache and aq.kv_cache.qformat) - else args.kv_cache_qformat + kv_cache_quant_cfg=( + aq.kv_cache.model_dump() + if aq.kv_cache is not None + else _cli_kv_cache_quant_cfg() ), ) else: @@ -1148,7 +1160,7 @@ def _candidate_for_mtq(fmt): constraints={"effective_bits": args.auto_quantize_bits}, quantization_formats=[QUANT_CFG_CHOICES[fmt] for fmt in qformat_list], disabled_layers=default_disabled_layers, - kv_cache_qformat=args.kv_cache_qformat, + kv_cache_quant_cfg=_cli_kv_cache_quant_cfg(), ) else: diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 1777b0f33eb..218bac82f40 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -19,7 +19,7 @@ import warnings from enum import Enum -from typing import ClassVar, Literal +from typing import Literal from pydantic import Field, 
field_validator, model_validator @@ -106,45 +106,6 @@ class ModelOptPTQRecipe(ModelOptRecipeBase): ) -class AutoQuantizeKVCache(ModeloptBaseConfig): - """KV-cache configuration for an AutoQuantize recipe (optional).""" - - # Mirrors the keys of KV_QUANT_CFG_CHOICES in examples/llm_ptq/hf_ptq.py. - # Kept inline (rather than imported) so the recipe schema stays free of - # example-script dependencies. Update both sides if new KV variants land. - # ClassVar annotation tells Pydantic this is a class-level constant, not a - # private model attribute (which is the default for leading-underscore names). - _SUPPORTED_QFORMATS: ClassVar[frozenset[str]] = frozenset( - { - "none", - "fp8_cast", - "fp8", - "fp8_affine", - "nvfp4_cast", - "nvfp4", - "nvfp4_affine", - "nvfp4_rotate", - } - ) - - qformat: str | None = ModeloptField( - default=None, - title="KV cache quantization format", - description="One of the entries in KV_QUANT_CFG_CHOICES, or 'none' to disable. " - "If omitted, the runtime --kv_cache_qformat CLI flag is used.", - ) - - @field_validator("qformat") - @classmethod - def _validate_qformat(cls, v: str | None) -> str | None: - if v is not None and v not in cls._SUPPORTED_QFORMATS: - raise ValueError( - f"Unsupported kv_cache.qformat: {v!r}. " - f"Expected one of {sorted(cls._SUPPORTED_QFORMATS)} or None." - ) - return v - - class AutoQuantizeConstraints(ModeloptBaseConfig): """Constraints passed to ``mtq.auto_quantize`` (matches its dict shape). 
@@ -201,16 +162,13 @@ class AutoQuantizeConfig(ModeloptBaseConfig): description="Glob patterns; matching layers are excluded from the search.", ) - score_checkpoint: str | None = ModeloptField( - default=None, - title="Search-state checkpoint path", - description="Path to save/restore search state for resume or cheap re-solve.", - ) - - kv_cache: AutoQuantizeKVCache | None = ModeloptField( + kv_cache: QuantizeConfig | None = ModeloptField( default=None, - title="KV cache override", - description="Optional KV cache config. If omitted, --kv_cache_qformat CLI flag is used.", + title="KV cache QuantizeConfig (optional)", + description="Optional full QuantizeConfig applied as a uniform post-step after the " + "LP search. Typically uses ``$import: configs/ptq/units/kv_*`` for a built-in KV " + "preset, or inlines a custom config. If omitted, the runtime --kv_cache_qformat " + "CLI flag is used as a fallback.", ) @field_validator("candidate_formats") diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 0a9218ff7d0..1e78c9372d8 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -42,6 +42,7 @@ # must contain 'quantize'" instead of pydantic's generic missing-field error. _REQUIRED_SECTION_PER_RECIPE_TYPE: dict[RecipeType, str] = { RecipeType.PTQ: "quantize", + RecipeType.AUTO_QUANTIZE: "auto_quantize", RecipeType.SPECULATIVE_EAGLE: "eagle", RecipeType.SPECULATIVE_DFLASH: "dflash", RecipeType.SPECULATIVE_MEDUSA: "medusa", @@ -171,8 +172,12 @@ def _load_recipe_from_file( raw = yaml.safe_load(recipe_file.read_text()) or {} if not isinstance(raw, dict) or required_section not in raw: + # Speculative recipes use the family suffix ("EAGLE" not "SPECULATIVE_EAGLE"); + # every other multi-word recipe type uses the full value ("AUTO_QUANTIZE", not "QUANTIZE"). 
kind = ( - rtype.value.split("_", 1)[-1].upper() if "_" in rtype.value else rtype.value.upper() + rtype.value.removeprefix("speculative_").upper() + if rtype.value.startswith("speculative_") + else rtype.value.upper() ) raise ValueError(f"{kind} recipe file {recipe_file} must contain {required_section!r}.") diff --git a/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml b/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml index 0fc59f68377..af723d19889 100644 --- a/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml @@ -19,6 +19,7 @@ imports: nvfp4: configs/ptq/presets/model/nvfp4 fp8: configs/ptq/presets/model/fp8 + kv_fp8_cast: configs/ptq/units/kv_fp8_cast metadata: recipe_type: auto_quantize @@ -33,7 +34,8 @@ auto_quantize: - $import: fp8 kv_cache: - qformat: fp8_cast + quant_cfg: + - $import: kv_fp8_cast method: gradient num_score_steps: 128 diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index d453f233899..6a6a7265bcf 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -272,7 +272,10 @@ def test_load_recipe_autoquantize_builtin(): aq = recipe.auto_quantize assert aq.constraints.effective_bits == 4.8 assert len(aq.candidate_formats) == 2 - assert aq.kv_cache is not None and aq.kv_cache.qformat == "fp8_cast" + # kv_cache is a full QuantizeConfig now (not a hardcoded qformat string). 
+ assert aq.kv_cache is not None + assert aq.kv_cache.algorithm == "max" + assert len(aq.kv_cache.quant_cfg) >= 1 def test_load_recipe_autoquantize_defaults(): @@ -281,7 +284,6 @@ def test_load_recipe_autoquantize_defaults(): aq = recipe.auto_quantize assert aq.method == "gradient" assert aq.num_score_steps == 128 - assert aq.score_checkpoint is None def test_load_recipe_autoquantize_candidates_match_presets(): @@ -293,10 +295,13 @@ def test_load_recipe_autoquantize_candidates_match_presets(): def test_load_recipe_autoquantize_missing_section_raises(tmp_path): - """An AutoQuantize recipe missing the ``auto_quantize`` section is rejected.""" + """An AutoQuantize recipe missing the ``auto_quantize`` section is rejected + with the clean loader-level error (not the generic pydantic missing-field one).""" bad = tmp_path / "bad.yml" bad.write_text("metadata:\n recipe_type: auto_quantize\n") - with pytest.raises(ValueError, match="auto_quantize"): + with pytest.raises( + ValueError, match=r"AUTO_QUANTIZE recipe file .* must contain 'auto_quantize'" + ): load_recipe(bad) @@ -333,14 +338,6 @@ def test_load_recipe_autoquantize_kv_cache_optional(tmp_path): assert recipe.auto_quantize.kv_cache is None -def test_load_recipe_autoquantize_invalid_kv_qformat_raises(tmp_path): - """An unknown kv_cache.qformat is rejected at recipe-load time, not later.""" - bad = tmp_path / "bad.yml" - bad.write_text(_AQ_MINIMAL_BODY + " kv_cache:\n qformat: not_a_real_format\n") - with pytest.raises(ValueError, match="kv_cache.qformat"): - load_recipe(bad) - - # --------------------------------------------------------------------------- # load_recipe — EAGLE speculative decoding # --------------------------------------------------------------------------- From a2763fac68d9bc12a22408584feda3111a3d914a Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Fri, 22 May 2026 22:45:29 +0000 Subject: [PATCH 4/4] add effective_bits field to QuantizeConfig to override the num_bits-based cost estimate per recipe 
Signed-off-by: Juhi Mittal --- modelopt/torch/quantization/algorithms.py | 15 ++++++++--- modelopt/torch/quantization/config.py | 19 ++++++++++++++ .../nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml | 4 +++ tests/unit/recipe/test_loader.py | 20 ++++++++++++-- .../unit/torch/quantization/test_autoquant.py | 26 +++++++++++++++++++ 5 files changed, 79 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index e4e633e36ae..ba83139bcbe 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -49,9 +49,16 @@ def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float: """Estimate the compression ratio of a quantization configuration. - Right now, we find the minimum compression ratio across all quantizer attribute configs. - This is not perfect but is a good proxy for the overall compression ratio. We will improve - this in future releases. + If ``quant_cfg.effective_bits`` is set, returns ``effective_bits / 16`` directly. This + is the override path for formats whose true effective bits don't match the per-quantizer + ``num_bits`` heuristic — e.g., NVFP4 has 4 value bits + a per-16-element FP8 scale + (8/16 = 0.5 bits/element), so true effective bits = 4.5, not the heuristic's 4.0. + + Otherwise, falls back to the heuristic: minimum compression ratio across all enabled + quantizer attribute configs (``num_bits / 16`` for ints, ``(E + M + 1) / 16`` for FP + tuples). This is a good proxy for the overall compression ratio of formats without + block-scale overhead, but under-counts block-quantized formats. We will improve this + in future releases. Args: quant_cfg: The quantization configuration to estimate compression for. @@ -59,6 +66,8 @@ def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float: Returns: float: The estimated compression ratio (0.0 to 1.0). 
""" + if quant_cfg.effective_bits is not None: + return quant_cfg.effective_bits / 16.0 def estimate_quant_compression_for_quantizer(quantizer_attr_cfg): if isinstance(quantizer_attr_cfg, list): diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index fd95171ce43..0d8cf476843 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1160,6 +1160,25 @@ class QuantizeConfig(ModeloptBaseConfig): validate_default=True, ) + effective_bits: float | None = ModeloptField( + default=None, + title="Effective bits per element (autoquant cost override)", + description=( + "Optional override for the autoquant LP cost model. If set, replaces the " + "heuristic estimate derived from ``num_bits``. Mainly useful for block-quantized " + "formats where the heuristic under-counts due to per-block scale overhead " + "(e.g., NVFP4 actual=4.5 vs heuristic=4.0). Must be in (0, 16] when set. " + "Read only by autoquant; other quantization paths ignore this field." + ), + ) + + @field_validator("effective_bits") + @classmethod + def _validate_effective_bits(cls, v: float | None) -> float | None: + if v is not None and not (0 < v <= 16): + raise ValueError(f"effective_bits must be in (0, 16], got {v}") + return v + @field_validator("quant_cfg", mode="before") @classmethod def normalize_quant_cfg( diff --git a/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml b/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml index af723d19889..c4b9a71c110 100644 --- a/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml @@ -30,7 +30,11 @@ auto_quantize: effective_bits: 4.8 candidate_formats: + # NVFP4 true effective bits = 4 value bits + 8-bit FP8 scale per 16-element block + # = 4 + 0.5 = 4.5 bits/element. Override the heuristic's 4.0 so the LP cost is accurate. 
- $import: nvfp4 + effective_bits: 4.5 + # FP8 effective bits = 8 (heuristic is correct, per-tensor scale is negligible). - $import: fp8 kv_cache: diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 6a6a7265bcf..1927ce482da 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -287,10 +287,15 @@ def test_load_recipe_autoquantize_defaults(): def test_load_recipe_autoquantize_candidates_match_presets(): - """Built-in AutoQuantize recipe's $imported candidates equal mtq.X_DEFAULT_CFG dicts.""" + """Built-in AutoQuantize recipe's $imported candidates equal preset + inline override.""" recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") candidates = recipe.auto_quantize.candidate_formats - assert candidates[0].model_dump(exclude_unset=True) == mtq.NVFP4_DEFAULT_CFG + + # NVFP4 candidate = canonical preset + inline effective_bits override. + expected_nvfp4 = {**mtq.NVFP4_DEFAULT_CFG, "effective_bits": 4.5} + assert candidates[0].model_dump(exclude_unset=True) == expected_nvfp4 + + # FP8 candidate = canonical preset exactly (no override). assert candidates[1].model_dump(exclude_unset=True) == mtq.FP8_DEFAULT_CFG @@ -338,6 +343,17 @@ def test_load_recipe_autoquantize_kv_cache_optional(tmp_path): assert recipe.auto_quantize.kv_cache is None +def test_load_recipe_autoquantize_effective_bits_inline_override(): + """Inline $import + sibling effective_bits merge applied per candidate.""" + recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") + candidates = recipe.auto_quantize.candidate_formats + + # NVFP4 candidate carries the override. + assert candidates[0].effective_bits == 4.5 + # FP8 candidate has no override; heuristic still applies. 
+ assert candidates[1].effective_bits is None + + # --------------------------------------------------------------------------- # load_recipe — EAGLE speculative decoding # --------------------------------------------------------------------------- diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e7..7ab308079c1 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -375,6 +375,32 @@ def test_estimate_quant_compression(): assert estimate_quant_compression(fp8_affine_kv_cfg) == 0.5 +def test_estimate_quant_compression_effective_bits_override(): + """``QuantizeConfig.effective_bits`` overrides the per-quantizer num_bits heuristic. + + Validates two things: + 1. The override path returns ``effective_bits / 16`` and bypasses the heuristic. + 2. Without the override, the heuristic returns the unchanged baseline value. + """ + # NVFP4 — heuristic returns 4.0 bits / 16 = 0.25, but true effective bits is 4.5. + nvfp4_cfg = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG) + assert nvfp4_cfg.effective_bits is None + assert estimate_quant_compression(nvfp4_cfg) == 0.25 # heuristic baseline + + nvfp4_cfg_overridden = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=4.5) + assert estimate_quant_compression(nvfp4_cfg_overridden) == 4.5 / 16.0 + + # Override can also represent a higher cost (e.g., conservative for a sensitive recipe). + nvfp4_cfg_high = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=16.0) + assert estimate_quant_compression(nvfp4_cfg_high) == 1.0 + + # Out-of-range values are rejected by the Pydantic validator. 
+ with pytest.raises(ValueError, match="effective_bits must be in"): + mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=0.0) + with pytest.raises(ValueError, match="effective_bits must be in"): + mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=17.0) + + @pytest.mark.parametrize("method", ["gradient", "kl_div"]) def test_auto_quantize_checkpoint_resume(method, tmp_path, capsys): """Test that checkpoint can be used to resume an interrupted search."""