Skip to content

Commit 8b1d3c6

Browse files
committed
address review comments- remove score_checkpoint from Autoquant_YAML, update the kv_cache pydantic type in YAML str -> QuantizeConfig, also update the dispatch in hf_ptq.py now, also add REQUIRED_SECTION_PER_RECIPE_TYPE for Autoquantize and fix a minor bug there
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
1 parent e5953d9 commit 8b1d3c6

5 files changed

Lines changed: 58 additions & 84 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def auto_quantize(
310310
constraints: dict,
311311
quantization_formats: list[dict],
312312
disabled_layers: list[str],
313-
kv_cache_qformat: str,
313+
kv_cache_quant_cfg: dict | None,
314314
):
315315
"""Pure orchestrator: build forward_step/loss_func, call mtq.auto_quantize,
316316
run KV cache post-step. All knobs are explicit keyword-only args; the
@@ -396,25 +396,24 @@ def forward_step(model, batch):
396396
)
397397

398398
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
399-
enable_quant_kv_cache = kv_cache_qformat != "none"
400-
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
401-
if enable_quant_kv_cache:
402-
kv_cache_quant_cfg = copy.deepcopy(
403-
getattr(mtq, KV_QUANT_CFG_CHOICES[kv_cache_qformat])["quant_cfg"]
404-
)
405-
kv_cache_quant_cfg = [
406-
e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*"
399+
print(f"{'Enable' if kv_cache_quant_cfg is not None else 'Disable'} KV cache quantization")
400+
if kv_cache_quant_cfg is not None:
401+
kv_entries = [
402+
e for e in copy.deepcopy(kv_cache_quant_cfg["quant_cfg"]) if e["quantizer_name"] != "*"
407403
] # keep other quantizers from auto_quantize
408404

409-
if kv_cache_qformat in _KV_CAST_FORMATS:
410-
_set_kv_cache_constant_amax(kv_cache_quant_cfg)
411-
412-
mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
413-
if kv_cache_qformat not in _KV_CAST_FORMATS:
405+
mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_entries)
406+
# Calibrate only when at least one KV entry doesn't pin amax via use_constant_amax.
407+
# Cast-variant presets (kv_fp8_cast, kv_nvfp4_cast) bake this in; data-driven
408+
# variants (kv_fp8, kv_nvfp4, etc.) need a calibration pass.
409+
needs_calibration = not all(
410+
(e.get("cfg") or {}).get("use_constant_amax") is True for e in kv_entries
411+
)
412+
if needs_calibration:
414413
# Calibrate only the KV cache quantizers; disable all others.
415414
with mtq.set_quantizer_by_cfg_context(
416415
language_model,
417-
[{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg],
416+
[{"quantizer_name": "*", "enable": False}, *kv_entries],
418417
):
419418
mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)
420419
return language_model
@@ -1075,6 +1074,19 @@ def _is_layerwise(obj):
10751074
if "parent_class" not in entry
10761075
]
10771076

1077+
# Resolve --kv_cache_qformat to a full QuantizeConfig dict (or None). Used as the
1078+
# CLI fallback when a recipe is silent on KV cache, and as the sole source for the
1079+
# CLI autoquant branch. Cast variants get use_constant_amax injected at this layer
1080+
# so the helper can stay format-agnostic (it just checks use_constant_amax to
1081+
# decide whether to calibrate).
1082+
def _cli_kv_cache_quant_cfg():
1083+
if args.kv_cache_qformat == "none":
1084+
return None
1085+
cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]))
1086+
if args.kv_cache_qformat in _KV_CAST_FORMATS:
1087+
_set_kv_cache_constant_amax(cfg["quant_cfg"])
1088+
return cfg
1089+
10781090
if isinstance(recipe, ModelOptAutoQuantizeRecipe):
10791091
aq = recipe.auto_quantize
10801092

@@ -1101,14 +1113,14 @@ def _candidate_for_mtq(fmt):
11011113
full_model=full_model,
11021114
auto_quantize_method=aq.method,
11031115
auto_quantize_score_size=aq.num_score_steps,
1104-
auto_quantize_checkpoint=aq.score_checkpoint,
1116+
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
11051117
constraints=aq.constraints.model_dump(exclude_none=True),
11061118
quantization_formats=[_candidate_for_mtq(fmt) for fmt in aq.candidate_formats],
11071119
disabled_layers=aq.disabled_layers or default_disabled_layers,
1108-
kv_cache_qformat=(
1109-
aq.kv_cache.qformat
1110-
if (aq.kv_cache and aq.kv_cache.qformat)
1111-
else args.kv_cache_qformat
1120+
kv_cache_quant_cfg=(
1121+
aq.kv_cache.model_dump()
1122+
if aq.kv_cache is not None
1123+
else _cli_kv_cache_quant_cfg()
11121124
),
11131125
)
11141126
else:
@@ -1148,7 +1160,7 @@ def _candidate_for_mtq(fmt):
11481160
constraints={"effective_bits": args.auto_quantize_bits},
11491161
quantization_formats=[QUANT_CFG_CHOICES[fmt] for fmt in qformat_list],
11501162
disabled_layers=default_disabled_layers,
1151-
kv_cache_qformat=args.kv_cache_qformat,
1163+
kv_cache_quant_cfg=_cli_kv_cache_quant_cfg(),
11521164
)
11531165

11541166
else:

modelopt/recipe/config.py

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import warnings
2121
from enum import Enum
22-
from typing import ClassVar, Literal
22+
from typing import Literal
2323

2424
from pydantic import Field, field_validator, model_validator
2525

@@ -106,45 +106,6 @@ class ModelOptPTQRecipe(ModelOptRecipeBase):
106106
)
107107

108108

109-
class AutoQuantizeKVCache(ModeloptBaseConfig):
110-
"""KV-cache configuration for an AutoQuantize recipe (optional)."""
111-
112-
# Mirrors the keys of KV_QUANT_CFG_CHOICES in examples/llm_ptq/hf_ptq.py.
113-
# Kept inline (rather than imported) so the recipe schema stays free of
114-
# example-script dependencies. Update both sides if new KV variants land.
115-
# ClassVar annotation tells Pydantic this is a class-level constant, not a
116-
# private model attribute (which is the default for leading-underscore names).
117-
_SUPPORTED_QFORMATS: ClassVar[frozenset[str]] = frozenset(
118-
{
119-
"none",
120-
"fp8_cast",
121-
"fp8",
122-
"fp8_affine",
123-
"nvfp4_cast",
124-
"nvfp4",
125-
"nvfp4_affine",
126-
"nvfp4_rotate",
127-
}
128-
)
129-
130-
qformat: str | None = ModeloptField(
131-
default=None,
132-
title="KV cache quantization format",
133-
description="One of the entries in KV_QUANT_CFG_CHOICES, or 'none' to disable. "
134-
"If omitted, the runtime --kv_cache_qformat CLI flag is used.",
135-
)
136-
137-
@field_validator("qformat")
138-
@classmethod
139-
def _validate_qformat(cls, v: str | None) -> str | None:
140-
if v is not None and v not in cls._SUPPORTED_QFORMATS:
141-
raise ValueError(
142-
f"Unsupported kv_cache.qformat: {v!r}. "
143-
f"Expected one of {sorted(cls._SUPPORTED_QFORMATS)} or None."
144-
)
145-
return v
146-
147-
148109
class AutoQuantizeConstraints(ModeloptBaseConfig):
149110
"""Constraints passed to ``mtq.auto_quantize`` (matches its dict shape).
150111
@@ -201,16 +162,13 @@ class AutoQuantizeConfig(ModeloptBaseConfig):
201162
description="Glob patterns; matching layers are excluded from the search.",
202163
)
203164

204-
score_checkpoint: str | None = ModeloptField(
205-
default=None,
206-
title="Search-state checkpoint path",
207-
description="Path to save/restore search state for resume or cheap re-solve.",
208-
)
209-
210-
kv_cache: AutoQuantizeKVCache | None = ModeloptField(
165+
kv_cache: QuantizeConfig | None = ModeloptField(
211166
default=None,
212-
title="KV cache override",
213-
description="Optional KV cache config. If omitted, --kv_cache_qformat CLI flag is used.",
167+
title="KV cache QuantizeConfig (optional)",
168+
description="Optional full QuantizeConfig applied as a uniform post-step after the "
169+
"LP search. Typically uses ``$import: configs/ptq/units/kv_*`` for a built-in KV "
170+
"preset, or inlines a custom config. If omitted, the runtime --kv_cache_qformat "
171+
"CLI flag is used as a fallback.",
214172
)
215173

216174
@field_validator("candidate_formats")

modelopt/recipe/loader.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
# must contain 'quantize'" instead of pydantic's generic missing-field error.
4343
_REQUIRED_SECTION_PER_RECIPE_TYPE: dict[RecipeType, str] = {
4444
RecipeType.PTQ: "quantize",
45+
RecipeType.AUTO_QUANTIZE: "auto_quantize",
4546
RecipeType.SPECULATIVE_EAGLE: "eagle",
4647
RecipeType.SPECULATIVE_DFLASH: "dflash",
4748
RecipeType.SPECULATIVE_MEDUSA: "medusa",
@@ -171,8 +172,12 @@ def _load_recipe_from_file(
171172

172173
raw = yaml.safe_load(recipe_file.read_text()) or {}
173174
if not isinstance(raw, dict) or required_section not in raw:
175+
# Speculative recipes use the family suffix ("EAGLE" not "SPECULATIVE_EAGLE");
176+
# every other multi-word recipe type uses the full value ("AUTO_QUANTIZE", not "QUANTIZE").
174177
kind = (
175-
rtype.value.split("_", 1)[-1].upper() if "_" in rtype.value else rtype.value.upper()
178+
rtype.value.removeprefix("speculative_").upper()
179+
if rtype.value.startswith("speculative_")
180+
else rtype.value.upper()
176181
)
177182
raise ValueError(f"{kind} recipe file {recipe_file} must contain {required_section!r}.")
178183

modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
imports:
2020
nvfp4: configs/ptq/presets/model/nvfp4
2121
fp8: configs/ptq/presets/model/fp8
22+
kv_fp8_cast: configs/ptq/units/kv_fp8_cast
2223

2324
metadata:
2425
recipe_type: auto_quantize
@@ -33,7 +34,8 @@ auto_quantize:
3334
- $import: fp8
3435

3536
kv_cache:
36-
qformat: fp8_cast
37+
quant_cfg:
38+
- $import: kv_fp8_cast
3739

3840
method: gradient
3941
num_score_steps: 128

tests/unit/recipe/test_loader.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,10 @@ def test_load_recipe_autoquantize_builtin():
272272
aq = recipe.auto_quantize
273273
assert aq.constraints.effective_bits == 4.8
274274
assert len(aq.candidate_formats) == 2
275-
assert aq.kv_cache is not None and aq.kv_cache.qformat == "fp8_cast"
275+
# kv_cache is a full QuantizeConfig now (not a hardcoded qformat string).
276+
assert aq.kv_cache is not None
277+
assert aq.kv_cache.algorithm == "max"
278+
assert len(aq.kv_cache.quant_cfg) >= 1
276279

277280

278281
def test_load_recipe_autoquantize_defaults():
@@ -281,7 +284,6 @@ def test_load_recipe_autoquantize_defaults():
281284
aq = recipe.auto_quantize
282285
assert aq.method == "gradient"
283286
assert aq.num_score_steps == 128
284-
assert aq.score_checkpoint is None
285287

286288

287289
def test_load_recipe_autoquantize_candidates_match_presets():
@@ -293,10 +295,13 @@ def test_load_recipe_autoquantize_candidates_match_presets():
293295

294296

295297
def test_load_recipe_autoquantize_missing_section_raises(tmp_path):
296-
"""An AutoQuantize recipe missing the ``auto_quantize`` section is rejected."""
298+
"""An AutoQuantize recipe missing the ``auto_quantize`` section is rejected
299+
with the clean loader-level error (not the generic pydantic missing-field one)."""
297300
bad = tmp_path / "bad.yml"
298301
bad.write_text("metadata:\n recipe_type: auto_quantize\n")
299-
with pytest.raises(ValueError, match="auto_quantize"):
302+
with pytest.raises(
303+
ValueError, match=r"AUTO_QUANTIZE recipe file .* must contain 'auto_quantize'"
304+
):
300305
load_recipe(bad)
301306

302307

@@ -333,14 +338,6 @@ def test_load_recipe_autoquantize_kv_cache_optional(tmp_path):
333338
assert recipe.auto_quantize.kv_cache is None
334339

335340

336-
def test_load_recipe_autoquantize_invalid_kv_qformat_raises(tmp_path):
337-
"""An unknown kv_cache.qformat is rejected at recipe-load time, not later."""
338-
bad = tmp_path / "bad.yml"
339-
bad.write_text(_AQ_MINIMAL_BODY + " kv_cache:\n qformat: not_a_real_format\n")
340-
with pytest.raises(ValueError, match="kv_cache.qformat"):
341-
load_recipe(bad)
342-
343-
344341
# ---------------------------------------------------------------------------
345342
# load_recipe — EAGLE speculative decoding
346343
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)