Skip to content

Commit 5877887

Browse files
committed
address review comments
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
1 parent cda4150 commit 5877887

3 files changed

Lines changed: 39 additions & 4 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1084,7 +1084,7 @@ def _is_layerwise(obj):
10841084
# All auto_quantize() knobs are resolved here before calling the helper.
10851085
# Helper is a leaf orchestrator — it does not know whether inputs came from
10861086
# CLI args or a recipe.
1087-
if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits:
1087+
if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits is not None:
10881088
default_disabled_layers = [
10891089
entry["quantizer_name"]
10901090
for entry in _default_disabled_quantizer_cfg

modelopt/recipe/config.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import warnings
2121
from enum import Enum
22-
from typing import Literal
22+
from typing import ClassVar, Literal
2323

2424
from pydantic import Field, field_validator, model_validator
2525

@@ -109,13 +109,41 @@ class ModelOptPTQRecipe(ModelOptRecipeBase):
109109
class AutoQuantizeKVCache(ModeloptBaseConfig):
110110
"""KV-cache configuration for an AutoQuantize recipe (optional)."""
111111

112+
# Mirrors the keys of KV_QUANT_CFG_CHOICES in examples/llm_ptq/hf_ptq.py.
113+
# Kept inline (rather than imported) so the recipe schema stays free of
114+
# example-script dependencies. Update both sides if new KV variants land.
115+
# ClassVar annotation tells Pydantic this is a class-level constant, not a
116+
# private model attribute (which is the default for leading-underscore names).
117+
_SUPPORTED_QFORMATS: ClassVar[frozenset[str]] = frozenset(
118+
{
119+
"none",
120+
"fp8_cast",
121+
"fp8",
122+
"fp8_affine",
123+
"nvfp4_cast",
124+
"nvfp4",
125+
"nvfp4_affine",
126+
"nvfp4_rotate",
127+
}
128+
)
129+
112130
qformat: str | None = ModeloptField(
113131
default=None,
114132
title="KV cache quantization format",
115133
description="One of the entries in KV_QUANT_CFG_CHOICES, or 'none' to disable. "
116134
"If omitted, the runtime --kv_cache_qformat CLI flag is used.",
117135
)
118136

137+
@field_validator("qformat")
138+
@classmethod
139+
def _validate_qformat(cls, v: str | None) -> str | None:
140+
if v is not None and v not in cls._SUPPORTED_QFORMATS:
141+
raise ValueError(
142+
f"Unsupported kv_cache.qformat: {v!r}. "
143+
f"Expected one of {sorted(cls._SUPPORTED_QFORMATS)} or None."
144+
)
145+
return v
146+
119147

120148
class AutoQuantizeConstraints(ModeloptBaseConfig):
121149
"""Constraints passed to ``mtq.auto_quantize`` (matches its dict shape).

tests/unit/recipe/test_loader.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import pytest
2222

23+
import modelopt.torch.quantization as mtq
2324
from modelopt.recipe.config import (
2425
ModelOptAutoQuantizeRecipe,
2526
ModelOptDFlashRecipe,
@@ -285,8 +286,6 @@ def test_load_recipe_autoquantize_defaults():
285286

286287
def test_load_recipe_autoquantize_candidates_match_presets():
287288
"""Built-in AutoQuantize recipe's $imported candidates equal mtq.X_DEFAULT_CFG dicts."""
288-
import modelopt.torch.quantization as mtq
289-
290289
recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast")
291290
candidates = recipe.auto_quantize.candidate_formats
292291
assert candidates[0].model_dump(exclude_unset=True) == mtq.NVFP4_DEFAULT_CFG
@@ -334,6 +333,14 @@ def test_load_recipe_autoquantize_kv_cache_optional(tmp_path):
334333
assert recipe.auto_quantize.kv_cache is None
335334

336335

336+
def test_load_recipe_autoquantize_invalid_kv_qformat_raises(tmp_path):
337+
"""An unknown kv_cache.qformat is rejected at recipe-load time, not later."""
338+
bad = tmp_path / "bad.yml"
339+
bad.write_text(_AQ_MINIMAL_BODY + " kv_cache:\n qformat: not_a_real_format\n")
340+
with pytest.raises(ValueError, match="kv_cache.qformat"):
341+
load_recipe(bad)
342+
343+
337344
# ---------------------------------------------------------------------------
338345
# load_recipe — EAGLE speculative decoding
339346
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)