Skip to content

Commit a2763fa

Browse files
committed
Add an optional `effective_bits` field to QuantizeConfig that overrides the num_bits-based cost estimate for each recipe
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
1 parent e15dc62 commit a2763fa

5 files changed

Lines changed: 79 additions & 5 deletions

File tree

modelopt/torch/quantization/algorithms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,25 @@
4949
def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
5050
"""Estimate the compression ratio of a quantization configuration.
5151
52-
Right now, we find the minimum compression ratio across all quantizer attribute configs.
53-
This is not perfect but is a good proxy for the overall compression ratio. We will improve
54-
this in future releases.
52+
If ``quant_cfg.effective_bits`` is set, returns ``effective_bits / 16`` directly. This
53+
is the override path for formats whose true effective bits don't match the per-quantizer
54+
``num_bits`` heuristic — e.g., NVFP4 has 4 value bits + a per-16-element FP8 scale
55+
(8/16 = 0.5 bits/element), so true effective bits = 4.5, not the heuristic's 4.0.
56+
57+
Otherwise, falls back to the heuristic: minimum compression ratio across all enabled
58+
quantizer attribute configs (``num_bits / 16`` for ints, ``(E + M + 1) / 16`` for FP
59+
tuples). This is a good proxy for the overall compression ratio of formats without
60+
block-scale overhead, but under-counts block-quantized formats. We will improve this
61+
in future releases.
5562
5663
Args:
5764
quant_cfg: The quantization configuration to estimate compression for.
5865
5966
Returns:
6067
float: The estimated compression ratio (0.0 to 1.0).
6168
"""
69+
if quant_cfg.effective_bits is not None:
70+
return quant_cfg.effective_bits / 16.0
6271

6372
def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
6473
if isinstance(quantizer_attr_cfg, list):

modelopt/torch/quantization/config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,25 @@ class QuantizeConfig(ModeloptBaseConfig):
11601160
validate_default=True,
11611161
)
11621162

1163+
effective_bits: float | None = ModeloptField(
1164+
default=None,
1165+
title="Effective bits per element (autoquant cost override)",
1166+
description=(
1167+
"Optional override for the autoquant LP cost model. If set, replaces the "
1168+
"heuristic estimate derived from ``num_bits``. Mainly useful for block-quantized "
1169+
"formats where the heuristic under-counts due to per-block scale overhead "
1170+
"(e.g., NVFP4 actual=4.5 vs heuristic=4.0). Must be in (0, 16] when set. "
1171+
"Read only by autoquant; other quantization paths ignore this field."
1172+
),
1173+
)
1174+
1175+
@field_validator("effective_bits")
1176+
@classmethod
1177+
def _validate_effective_bits(cls, v: float | None) -> float | None:
1178+
if v is not None and not (0 < v <= 16):
1179+
raise ValueError(f"effective_bits must be in (0, 16], got {v}")
1180+
return v
1181+
11631182
@field_validator("quant_cfg", mode="before")
11641183
@classmethod
11651184
def normalize_quant_cfg(

modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ auto_quantize:
3030
effective_bits: 4.8
3131

3232
candidate_formats:
33+
# NVFP4 true effective bits = 4 value bits + 8-bit FP8 scale per 16-element block
34+
# = 4 + 0.5 = 4.5 bits/element. Override the heuristic's 4.0 so the LP cost is accurate.
3335
- $import: nvfp4
36+
effective_bits: 4.5
37+
# FP8 effective bits = 8 (heuristic is correct, per-tensor scale is negligible).
3438
- $import: fp8
3539

3640
kv_cache:

tests/unit/recipe/test_loader.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,15 @@ def test_load_recipe_autoquantize_defaults():
287287

288288

289289
def test_load_recipe_autoquantize_candidates_match_presets():
290-
"""Built-in AutoQuantize recipe's $imported candidates equal mtq.X_DEFAULT_CFG dicts."""
290+
"""Built-in AutoQuantize recipe's $imported candidates equal preset + inline override."""
291291
recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast")
292292
candidates = recipe.auto_quantize.candidate_formats
293-
assert candidates[0].model_dump(exclude_unset=True) == mtq.NVFP4_DEFAULT_CFG
293+
294+
# NVFP4 candidate = canonical preset + inline effective_bits override.
295+
expected_nvfp4 = {**mtq.NVFP4_DEFAULT_CFG, "effective_bits": 4.5}
296+
assert candidates[0].model_dump(exclude_unset=True) == expected_nvfp4
297+
298+
# FP8 candidate = canonical preset exactly (no override).
294299
assert candidates[1].model_dump(exclude_unset=True) == mtq.FP8_DEFAULT_CFG
295300

296301

@@ -338,6 +343,17 @@ def test_load_recipe_autoquantize_kv_cache_optional(tmp_path):
338343
assert recipe.auto_quantize.kv_cache is None
339344

340345

346+
def test_load_recipe_autoquantize_effective_bits_inline_override():
347+
"""An ``effective_bits`` key written next to an inline ``$import`` is merged into that candidate only."""
348+
recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast")
349+
candidates = recipe.auto_quantize.candidate_formats
350+
351+
# NVFP4 candidate carries the override.
352+
assert candidates[0].effective_bits == 4.5
353+
# FP8 candidate has no override; heuristic still applies.
354+
assert candidates[1].effective_bits is None
355+
356+
341357
# ---------------------------------------------------------------------------
342358
# load_recipe — EAGLE speculative decoding
343359
# ---------------------------------------------------------------------------

tests/unit/torch/quantization/test_autoquant.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,32 @@ def test_estimate_quant_compression():
375375
assert estimate_quant_compression(fp8_affine_kv_cfg) == 0.5
376376

377377

378+
def test_estimate_quant_compression_effective_bits_override():
379+
"""``QuantizeConfig.effective_bits`` overrides the per-quantizer num_bits heuristic.
380+
381+
Validates two things:
382+
1. The override path returns ``effective_bits / 16`` and bypasses the heuristic.
383+
2. Without the override, the heuristic returns the unchanged baseline value.
384+
"""
385+
# NVFP4 — heuristic returns 4.0 bits / 16 = 0.25, but true effective bits is 4.5.
386+
nvfp4_cfg = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG)
387+
assert nvfp4_cfg.effective_bits is None
388+
assert estimate_quant_compression(nvfp4_cfg) == 0.25 # heuristic baseline
389+
390+
nvfp4_cfg_overridden = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=4.5)
391+
assert estimate_quant_compression(nvfp4_cfg_overridden) == 4.5 / 16.0
392+
393+
# Override can also represent a higher cost (e.g., conservative for a sensitive recipe).
394+
nvfp4_cfg_high = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=16.0)
395+
assert estimate_quant_compression(nvfp4_cfg_high) == 1.0
396+
397+
# Out-of-range values are rejected by the Pydantic validator.
398+
with pytest.raises(ValueError, match="effective_bits must be in"):
399+
mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=0.0)
400+
with pytest.raises(ValueError, match="effective_bits must be in"):
401+
mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=17.0)
402+
403+
378404
@pytest.mark.parametrize("method", ["gradient", "kl_div"])
379405
def test_auto_quantize_checkpoint_resume(method, tmp_path, capsys):
380406
"""Test that checkpoint can be used to resume an interrupted search."""

0 commit comments

Comments
 (0)