Skip to content

Commit e4dc020

Browse files
authored
[OMNIML-4775] Move built-in PTQ quantization configs to YAML (#1423)
### What does this PR do? Type of change: refactor This PR moves the built-in PTQ quantization config definitions out of hard-coded Python dictionaries and into schema-backed YAML config files, and factors shared blocks into reusable composable snippets. - Adds reusable numeric config snippets under `modelopt_recipes/configs/numerics/`. - Adds YAML presets for the built-in model PTQ configs under `modelopt_recipes/configs/ptq/presets/model/`. - Adds YAML presets for KV-cache quantization configs under `modelopt_recipes/configs/ptq/presets/kv/`. - Adds YAML presets for the Diffusers-specific PTQ configs under `modelopt_recipes/configs/ptq/presets/diffusers/` and re-points `examples/diffusers/quantization/config.py` constants at them via `load_config`. - Adds reusable KV quantization units (`kv_fp8_affine`, `kv_nvfp4`, `kv_nvfp4_affine`, `kv_nvfp4_rotate`, `kv_*_cast` variants) under `modelopt_recipes/configs/ptq/units/`. - Adds reusable model-side units following the `component_numerics[_type]` convention: - `attention_qkv_fp8` — FP8 E4M3 on attention q/k/v bmm and softmax quantizers; shared by `model/` and `diffusers/` `nvfp4_fp8_mha` presets. - `block_sparse_moe_nvfp4` — NVFP4 W4A4 on `*block_sparse_moe*` weight/input quantizers; shared by `nvfp4_mlp_only`, `nvfp4_experts_only`, `nvfp4_omlp_only`. - `experts_nvfp4` — NVFP4 W4A4 on `*.experts.*` weight/input quantizers; shared by `nvfp4_mlp_only` and `nvfp4_experts_only`. - Switches the existing 5 NVFP4 presets (default + awq lite/clip/full + svdquant) and 4 mamba_moe presets to `$import` the existing `w4a4_nvfp4_nvfp4` / `w8a8_fp8_fp8` units instead of re-inlining the same weight+input quantizer pairs. - Moves the recently-added `W4A16_NVFP4_CFG` to YAML (`presets/model/w4a16_nvfp4.yaml`) composed from the existing `units/w4_nvfp4` snippet. - Updates `modelopt.torch.quantization.config` built-in config constants to load `QuantizeConfig` objects from YAML with `load_config(..., schema_type=QuantizeConfig).model_dump(exclude_unset=True)` via a new `_load_quantize_config_dict` helper; the constants remain plain `dict[str, Any]` for backwards compatibility with consumers that do mapping-style mutation (e.g. `entry["cfg"]` assignment). - Simplifies the cfg-list loader (`_load_quantizer_cfg_dict_list`) down to a 4-line list/single normalization now that the three call sites all load schema-typed YAMLs. - Adds/updates recipe loader coverage for built-in schema-backed config snippets. ### Latent-bug fixes surfaced by the refactor Two small correctness fixes are included alongside the mechanical refactor; flagging them explicitly: - **`examples/diffusers/quantization/quantize.py`** — adds an explicit `base_cfg = copy.deepcopy(base_cfg)` before applying runtime overrides. The existing `# Build a fresh config dict so we never mutate the global constants` comment had been aspirational only; in practice `reset_set_int8_config` accumulated `PercentileCalibrator` entries into `mtq.INT8_SMOOTHQUANT_CFG`/`INT8_DEFAULT_CONFIG` across repeated calls, and `set_quant_config_attr` added `trt_high_precision_dtype` keys into globally-shared cfg dicts. The deepcopy makes the code match the comment. - **`choices` set in `modelopt/torch/quantization/config.py`** — adds `MXFP6_DEFAULT_CFG` and `NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG` to the documented public set of valid `mtq.*_CFG` names. Both constants exist on main but were missing from `choices`, so CLIs that gate on `mtq.config.choices` (e.g., `hf_ptq.py --qformat`) couldn't reach them even though the configs themselves were fully supported. ### Usage Existing Python imports continue to work: ```python import modelopt.torch.quantization as mtq cfg = mtq.FP8_DEFAULT_CFG model = mtq.quantize(model, cfg, forward_loop) ``` The built-in constants are plain `dict[str, Any]` (sparse — only explicitly-set fields are present), but their definitions now come from YAML snippets and presets composed through the existing `$import` system. Reusable YAML snippets can be composed through `$import`, for example: ```yaml # modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig imports: base_disable_all: configs/ptq/units/base_disable_all w4a4_nvfp4_nvfp4: configs/ptq/units/w4a4_nvfp4_nvfp4 default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers algorithm: max quant_cfg: - $import: base_disable_all - $import: w4a4_nvfp4_nvfp4 - $import: default_disabled_quantizers ``` ### Testing Local checks run: - `nox -s "unit-3.10(torch_211, tf_latest)"` — 2329 passed, 12 skipped. - `nox -s pre_commit_all` — all hooks pass (ruff check / ruff format / mypy / YAML format / license / bandit / markdownlint). - YAML parse + `$import` resolution sanity check across all changed config files. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ Existing built-in Python config constants keep the same public names and dict semantics. - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: ✅ Adds/updates recipe loader coverage for schema-backed built-in snippets. - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A - Did you get Claude approval on this PR?: ❌ ### Additional Information This PR was previously stacked on #1405, which has since merged to `main`. The branch has been rebased onto `main` and no longer depends on any other open PR. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Many new quantization numeric configs and PTQ presets added (INT4/INT8/MXFP4/MXFP6/MXFP8/MXINT8/NVFP4), plus Diffusers, KV-cache (affine/cast/rotate) and MLP/MoE-targeted presets. * **Refactor** * Presets and shared snippets migrated to schema-backed YAML sources and centralized loading; INT8 percentile calibration avoids mutating shared base configs. * **Tests** * Tests now discover packaged config snippets at runtime and validate import/append behaviors. * **Documentation** * Presets README and numerous header descriptions updated. * **Chores** * Minor typing and script improvements. <!-- review_stack_entry_start --> [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/NVIDIA/Model-Optimizer/pull/1423?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) <!-- review_stack_entry_end --> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent a5bc6f8 commit e4dc020

90 files changed

Lines changed: 2050 additions & 721 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

examples/diffusers/quantization/config.py

Lines changed: 14 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -16,82 +16,21 @@
1616
import torch.nn as nn
1717
from calib.plugin_calib import PercentileCalibrator
1818

19-
FP8_DEFAULT_CONFIG = {
20-
"quant_cfg": [
21-
{"quantizer_name": "*", "enable": False},
22-
{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
23-
{"quantizer_name": "*input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
24-
{"quantizer_name": "*output_quantizer", "enable": False},
25-
{"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
26-
],
27-
"algorithm": "max",
28-
}
19+
from modelopt.torch.opt.config_loader import load_config
20+
from modelopt.torch.quantization.config import QuantizeConfig
2921

30-
INT8_DEFAULT_CONFIG = {
31-
"quant_cfg": [
32-
{"quantizer_name": "*", "enable": False},
33-
{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}},
34-
{"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}},
35-
{"quantizer_name": "*output_quantizer", "enable": False},
36-
],
37-
"algorithm": "max",
38-
}
39-
40-
NVFP4_DEFAULT_CONFIG = {
41-
"quant_cfg": [
42-
{"quantizer_name": "*", "enable": False},
43-
{
44-
"quantizer_name": "*weight_quantizer",
45-
"cfg": {
46-
"num_bits": (2, 1),
47-
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
48-
"axis": None,
49-
},
50-
"enable": True,
51-
},
52-
{
53-
"quantizer_name": "*input_quantizer",
54-
"cfg": {
55-
"num_bits": (2, 1),
56-
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
57-
"axis": None,
58-
},
59-
"enable": True,
60-
},
61-
{"quantizer_name": "*output_quantizer", "enable": False},
62-
{"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
63-
],
64-
"algorithm": "max",
65-
}
66-
67-
NVFP4_FP8_MHA_CONFIG = {
68-
"quant_cfg": [
69-
{"quantizer_name": "*", "enable": False},
70-
{
71-
"quantizer_name": "**weight_quantizer",
72-
"cfg": {
73-
"num_bits": (2, 1),
74-
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
75-
"axis": None,
76-
},
77-
"enable": True,
78-
},
79-
{
80-
"quantizer_name": "**input_quantizer",
81-
"cfg": {
82-
"num_bits": (2, 1),
83-
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
84-
"axis": None,
85-
},
86-
"enable": True,
87-
},
88-
{"quantizer_name": "*output_quantizer", "enable": False},
89-
{"quantizer_name": "*[qkv]_bmm_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
90-
{"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
91-
{"quantizer_name": "*bmm2_output_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
92-
],
93-
"algorithm": {"method": "svdquant", "lowrank": 32},
94-
}
22+
FP8_DEFAULT_CONFIG = load_config(
23+
"configs/ptq/presets/diffusers/fp8", schema_type=QuantizeConfig
24+
).model_dump(exclude_unset=True)
25+
INT8_DEFAULT_CONFIG = load_config(
26+
"configs/ptq/presets/diffusers/int8", schema_type=QuantizeConfig
27+
).model_dump(exclude_unset=True)
28+
NVFP4_DEFAULT_CONFIG = load_config(
29+
"configs/ptq/presets/diffusers/nvfp4", schema_type=QuantizeConfig
30+
).model_dump(exclude_unset=True)
31+
NVFP4_FP8_MHA_CONFIG = load_config(
32+
"configs/ptq/presets/diffusers/nvfp4_fp8_mha", schema_type=QuantizeConfig
33+
).model_dump(exclude_unset=True)
9534

9635

9736
def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, **kwargs):

examples/diffusers/quantization/quantize.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import argparse
17+
import copy
1718
import logging
1819
import sys
1920
import time as time
@@ -114,19 +115,13 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
114115
"""
115116
self.logger.info(f"Building quantization config for {self.config.format.value}")
116117

118+
apply_int8_percentile_calibrator = False
117119
if self.config.format == QuantFormat.INT8:
118120
if self.config.algo == QuantAlgo.SMOOTHQUANT:
119121
base_cfg = mtq.INT8_SMOOTHQUANT_CFG
120122
else:
121123
base_cfg = INT8_DEFAULT_CONFIG
122-
if self.config.collect_method != CollectMethod.DEFAULT:
123-
reset_set_int8_config(
124-
base_cfg,
125-
self.config.percentile,
126-
n_steps,
127-
collect_method=self.config.collect_method.value,
128-
backbone=backbone,
129-
)
124+
apply_int8_percentile_calibrator = self.config.collect_method != CollectMethod.DEFAULT
130125
elif self.config.format == QuantFormat.FP8:
131126
base_cfg = FP8_DEFAULT_CONFIG
132127
elif self.config.format == QuantFormat.FP4:
@@ -137,7 +132,18 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
137132
else:
138133
raise NotImplementedError(f"Unknown format {self.config.format}")
139134

140-
# Build a fresh config dict so we never mutate the global constants.
135+
# Build a fresh config dict so runtime overrides never mutate the global constants.
136+
base_cfg = copy.deepcopy(base_cfg)
137+
138+
if apply_int8_percentile_calibrator:
139+
reset_set_int8_config(
140+
base_cfg,
141+
self.config.percentile,
142+
n_steps,
143+
collect_method=self.config.collect_method.value,
144+
backbone=backbone,
145+
)
146+
141147
quant_cfg_list = list(base_cfg["quant_cfg"])
142148

143149
if self.config.format == QuantFormat.FP4:

examples/llm_autodeploy/run_auto_quantize.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import argparse
1717
from collections import defaultdict
18+
from typing import Any
1819

1920
import torch
2021
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -24,7 +25,7 @@
2425
from modelopt.torch.utils import create_forward_loop
2526
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
2627

27-
SUPPORT_QUANT_FORMAT = {
28+
SUPPORT_QUANT_FORMAT: dict[str, dict[str, Any]] = {
2829
"fp8": mtq.FP8_DEFAULT_CFG,
2930
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
3031
}
@@ -87,7 +88,7 @@ def loss_func(output, data):
8788
data_loader=calib_dataloader,
8889
forward_step=lambda model, batch: model(**batch),
8990
loss_func=loss_func,
90-
quantization_formats=[SUPPORT_QUANT_FORMAT[format] for format in qformat_list],
91+
quantization_formats=[SUPPORT_QUANT_FORMAT[quant_format] for quant_format in qformat_list],
9192
num_calib_steps=len(calib_dataloader),
9293
num_score_steps=min(
9394
len(calib_dataloader), 128 // batch_size

modelopt/torch/opt/config_loader.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,19 @@ def _schema_equal(left: Any | None, right: Any | None) -> bool:
336336
def _list_element_schema(schema_type: Any | None) -> Any | None:
337337
"""Return the element schema for a typed ``list[T]`` annotation."""
338338
schema_type = _unwrap_schema_type(schema_type)
339-
if get_origin(schema_type) is not list:
339+
origin = get_origin(schema_type)
340+
if origin in (UnionType, Union):
341+
element_schemas = []
342+
for arg in get_args(schema_type):
343+
if arg is NoneType:
344+
continue
345+
element_schema = _list_element_schema(arg)
346+
if element_schema is None:
347+
continue
348+
if not any(_schema_equal(element_schema, seen) for seen in element_schemas):
349+
element_schemas.append(element_schema)
350+
return element_schemas[0] if len(element_schemas) == 1 else None
351+
if origin is not list:
340352
return None
341353
args = get_args(schema_type)
342354
if len(args) != 1 or args[0] is Any:
@@ -510,6 +522,12 @@ def _resolve_list_import(
510522
if _schema_equal(imported.schema_type, element_schema):
511523
return [imported.data]
512524

525+
element_schema_unwrapped = _unwrap_schema_type(element_schema)
526+
if isinstance(imported.data, dict) and (
527+
element_schema_unwrapped is dict or get_origin(element_schema_unwrapped) is dict
528+
):
529+
return [imported.data]
530+
513531
raise ValueError(
514532
f"$import {ref_name!r} in list at {context} has schema "
515533
f"{_schema_label(imported.schema_type, imported.schema)!r}; expected either "

0 commit comments

Comments
 (0)