Commit 1cceb95
[OMNIML-3689] PTQ quant_cfg semantic correction. Design in doc _quant_cfg.rst (#1094)
### What does this PR do?

#### Summary

Redesigns the `quant_cfg` configuration format in ModelOpt's PyTorch quantization stack, replacing the previous dict-based format with an **ordered list of typed `QuantizerCfgEntry` dicts**.

##### Motivation

The old `quant_cfg` dict had several pain points:

- **Ambiguous precedence**: no explicit way to reason about which entry wins when multiple keys match a quantizer
- **Mixed key namespaces**: wildcard paths and PyTorch class names lived in the same dict level, requiring ad-hoc dispatch
- **Magic `"default"` key**: an implicit, undocumented catch-all that was easy to misuse
- **Poor composability**: merging two configs required dict updates that silently discarded keys
- **No YAML round-trip fidelity**: the nested structure couldn't be expressed cleanly in YAML

##### New format

`quant_cfg` is now an ordered list of `QuantizerCfgEntry` TypedDicts. Each entry has:

- `quantizer_name` *(required)*: `fnmatch` wildcard matched against quantizer module names
- `cfg` *(optional)*: dict (or list of dicts) of `QuantizerAttributeConfig` fields
- `enable` *(optional)*: toggles the quantizer on/off independently of `cfg`
- `parent_class` *(optional)*: restricts the match to quantizers whose parent module is of the given PyTorch class (e.g. `"nn.BatchNorm2d"`)

Entries are applied in list order; later entries override earlier ones. The canonical pattern is deny-all first (`_base_disable_all`), then selectively re-enable and configure, then apply standard exclusions (`_default_disabled_quantizer_cfg`).

##### Changes

**Core library (`modelopt/torch/quantization/`)**

- **`config.py`**:
  - Added the `QuantizerCfgEntry` TypedDict (line 163) and a `find_quant_cfg_entry_by_path()` helper for exact-match lookup of entries by path.
  - Added `normalize_quant_cfg_list()` (line 1539), which converts legacy formats (flat dict, single-key dicts, `nn.*`-scoped dicts, `"default"` key) to canonical `QuantizerCfgEntry` lists. After normalization, every entry is guaranteed to have explicit `quantizer_name`, `enable`, and `cfg` keys.
  - Converted `_default_disabled_quantizer_cfg` and `_mamba_moe_disabled_quantizer_cfg` from dicts to lists of `QuantizerCfgEntry`.
  - Added `_base_disable_all` (line 205): the canonical deny-all entry (`[{"quantizer_name": "*", "enable": False}]`).
  - Converted all ~30 built-in config constants (`INT8_DEFAULT_CFG`, `FP8_DEFAULT_CFG`, `NVFP4_DEFAULT_CFG`, etc.) to list format using `*_base_disable_all` and `*_default_disabled_quantizer_cfg` unpacking.
  - KV-cache configs (`FP8_KV_CFG`, `NVFP4_KV_CFG`, etc.) are now minimal lists designed to be concatenated with a primary config — they intentionally omit `_base_disable_all` and `"algorithm"`.
  - Added two `QuantizeConfig` Pydantic field validators: a `mode="before"` validator that calls `normalize_quant_cfg_list()`, and a `mode="after"` validator that validates `cfg` dicts against `QuantizerAttributeConfig`.
  - Updated `need_calibration()` to iterate the normalized list instead of the old dict.
  - Changed the `QuantizeQuantCfgType` alias from `dict[str | Callable, ...]` to `list[QuantizerCfgEntry]`.
- **`conversion.py`**:
  - Rewrote `set_quantizer_by_cfg()` (line 217) to iterate the list directly. Each entry's `parent_class` is resolved via `QuantModuleRegistry[parent_class_name]` (the existing `_DMRegistryCls` registry).
  - Added `set_quantizer_attributes_full()` (line 314): full replacement of quantizer attributes from a `QuantizerAttributeConfig`. Unspecified fields revert to defaults, enforcing entry atomicity. Can also upgrade `TensorQuantizer` → `SequentialQuantizer` or downgrade the reverse.
  - Added `set_quantizer_attributes_partial()` (line 384): merges a partial `dict` of attributes into the existing quantizer state. Does NOT change quantizer structure. Used for enable-only entries.
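The ordering semantics above (last match wins, with `enable` and `cfg` resolved independently) can be sketched in plain Python. This is an illustrative stand-in, not the library's actual matcher; the `quant_cfg` entries and the `resolve` helper are hypothetical:

```python
from fnmatch import fnmatch

# Hypothetical entries in the new list format: deny-all base, selective
# enable, then a later exclusion (illustrative, not a built-in config).
quant_cfg = [
    {"quantizer_name": "*", "enable": False},
    {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 4}, "enable": True},
    {"quantizer_name": "*lm_head*", "enable": False},
]


def resolve(quantizer_name, entries):
    """Return (cfg, enable) for a quantizer name; later matching entries win."""
    cfg, enable = None, None
    for entry in entries:
        if fnmatch(quantizer_name, entry["quantizer_name"]):
            cfg = entry.get("cfg", cfg)
            enable = entry.get("enable", enable)
    return cfg, enable


# A regular weight quantizer picks up the second entry's cfg and enable=True;
# lm_head keeps that cfg but is re-disabled by the later exclusion entry.
print(resolve("model.layers.0.mlp.weight_quantizer", quant_cfg))
print(resolve("lm_head.weight_quantizer", quant_cfg))
```

Note how the last example demonstrates the `enable`/`cfg` independence described above: the exclusion entry flips only `enable` while the earlier `cfg` survives.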
  - Added the `set_quantizer_by_cfg_context()` context manager (line 447), which temporarily applies a `quant_cfg` list and restores the original quantizer state on exit.
  - Deprecated `set_quantizer_attribute()` (line 525) with a `DeprecationWarning` pointing to the new functions.
- **`tensor_quantizer.py`**:
  - `TensorQuantizer.set_from_attribute_config()`: narrowed the type hint from `dict` to `dict[str, Any]`.
  - Added `_axis_setter` and `_block_sizes_setter` custom setters so that `axis` and `block_sizes` changes properly propagate to the calibrator and maintain mutual exclusivity.
  - `SequentialQuantizer.set_from_attribute_config()`: narrowed the signature to `list[QuantizerAttributeConfig] | list[dict[str, Any]]` (removed the old union with single values).
- **`algorithms.py`**:
  - Updated `_match_quantizer_cfg()` to iterate the list and return a `(matched_cfg, matched_enable)` tuple with last-match-wins semantics.
  - Updated `_cfg_to_dict()`, `estimate_quant_compression()`, and `QuantRecipe` to work with the list-based format.
  - Updated `get_auto_quantize_config()` to emit list-format `quant_cfg`.
- **`model_quant.py`**: `disable_quantizer()` / `enable_quantizer()` now call `set_quantizer_attributes_partial()` directly instead of the deprecated `set_quantizer_attribute()`. Updated docstrings and code examples to show the list format.
- **`utils/core_utils.py`**: `disable_lora_quantizers_in_config()` and `update_quant_cfg_with_kv_cache_quant()` updated to append `QuantizerCfgEntry` dicts to the list.
- **Other**: minor updates to `backends/fp8_per_tensor_gemm.py`, `backends/nvfp4_gemm.py`, `compress.py`, `model_calib.py`, `export/unified_export_hf.py`, and `sparsity/attention_sparsity/conversion.py` to use the list format.
- **`onnx/llm_export_utils/quantization_utils.py`**: updated quantization config construction to use the list format.
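The full/partial split between the two new setters can be illustrated with a toy attribute dict. The real functions operate on `TensorQuantizer` modules; `DEFAULTS`, `set_full`, and `set_partial` here are hypothetical stand-ins for the behavior described above:

```python
# Hypothetical defaults, standing in for QuantizerAttributeConfig's fields.
DEFAULTS = {"num_bits": 8, "axis": None, "enable": True}


def set_full(state, cfg):
    """Full replacement (cf. set_quantizer_attributes_full): unspecified
    fields revert to defaults, so each entry is atomic."""
    return {**DEFAULTS, **cfg}


def set_partial(state, cfg):
    """Partial merge (cf. set_quantizer_attributes_partial): only the given
    fields change; the rest of the existing state is preserved."""
    return {**state, **cfg}


state = {"num_bits": 4, "axis": 0, "enable": True}
print(set_full(state, {"num_bits": 16}))      # axis reverts to the default None
print(set_partial(state, {"enable": False}))  # num_bits and axis are untouched
```

This is why enable-only entries go through the partial path: toggling `enable` must not wipe out a quantizer's previously configured attributes.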
**YAML recipes (`modelopt_recipes/`)**

- Converted all 5 general PTQ recipes to the new list format:
  - `general/ptq/fp8_default-fp8_kv.yml`
  - `general/ptq/nvfp4_default-fp8_kv.yml`
  - `general/ptq/nvfp4_experts_only-fp8_kv.yml`
  - `general/ptq/nvfp4_mlp_only-fp8_kv.yml`
  - `general/ptq/nvfp4_omlp_only-fp8_kv.yml`
- Converted the model-specific recipe: `models/Step3.5-Flash/nvfp4-mlp-only.yaml`

**Documentation (`docs/`)**

- New guide: `docs/source/guides/_quant_cfg.rst` — a comprehensive reference covering the entry format, ordering semantics, entry atomicity, `enable` vs `cfg` independence, `parent_class` filtering, and common patterns (deny-all-then-enable, customizing a built-in config, building from scratch).
- Updated `_pytorch_quantization.rst` code examples to show the list format with `copy.deepcopy` and `.append()`.
- Added `_quant_cfg.rst` to the quantization guide table of contents.

**Examples**

- Updated all quantization examples to use the list format: `deepseek/ptq.py`, `diffusers/quantization/config.py`, `llm_ptq/hf_ptq.py`, `llm_qat/main.py`, `vllm_serve/vllm_ptq_utils.py`, `llm_autodeploy/run_auto_quantize.py`, `llm_eval/quantization_utils.py`, `llm_ptq/example_utils.py`, `windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py`, and 2 notebooks.

**Tests**

- New test file: `tests/unit/torch/quantization/test_config_validation.py` — unit tests for `need_calibration()`, `normalize_quant_cfg_list()` (new format, legacy format conversions, error cases), `find_quant_cfg_entry_by_path()`, `_match_quantizer_cfg()`, and the `QuantizeConfig` Pydantic validators.
- Extended `tests/unit/torch/quantization/test_quantize_cpu.py` with tests for `set_quantizer_attributes_full()` (atomicity, `parent_class` filtering, `SequentialQuantizer` creation), list ordering, enable-only entry behavior, and the end-to-end legacy dict format.
- Updated 20+ existing test files across `tests/unit/`, `tests/gpu/`, `tests/gpu_megatron/`, and `tests/_test_utils/` to use the list format.
##### Backward compatibility

`normalize_quant_cfg_list()` is called automatically by the `QuantizeConfig` Pydantic `mode="before"` validator, so existing code passing the old dict-based format (a flat dict like `{"*weight_quantizer": {"num_bits": 8}}`, single-key dict lists, or `nn.*`-scoped dicts with `parent_class` semantics) continues to work without modification. The legacy `"default"` key is converted to `quantizer_name: "*"`. `set_quantizer_attribute()` is preserved as a deprecated wrapper around `set_quantizer_attributes_partial()`.

#### Test coverage

- **Unit tests**: new `test_config_validation.py` with tests for normalization, validation, path lookup, and cfg matching. Extended `test_quantize_cpu.py` with tests for full/partial attribute setting, ordering, atomicity, and legacy backward compatibility.
- **System testing**:

  ```
  python examples/llm_ptq/hf_ptq.py \
      --model Qwen/Qwen3-8B \
      --recipe general/ptq/fp8_default-fp8_kv \
      --export_path=build/fp8_default-fp8_kv42 \
      --calib_size=16 \
      --batch_size=0 \
      --trust_remote_code \
      --export_fmt=hf
  ```

---------

Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
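The backward-compatibility conversion for the flat-dict case can be approximated in a few lines. This is a simplified model of what `normalize_quant_cfg_list()` does, not the real implementation (which also handles single-key dict lists, `nn.*`-scoped dicts, and validation); `normalize_flat_dict` is a hypothetical name:

```python
def normalize_flat_dict(quant_cfg):
    """Simplified sketch of legacy flat-dict normalization: the magic
    "default" key becomes a "*" wildcard, and every emitted entry carries
    explicit quantizer_name, enable, and cfg keys."""
    entries = []
    for key, attrs in quant_cfg.items():
        entries.append({
            "quantizer_name": "*" if key == "default" else key,
            "enable": attrs.get("enable", True),
            "cfg": {k: v for k, v in attrs.items() if k != "enable"},
        })
    return entries


legacy = {
    "*weight_quantizer": {"num_bits": 8},
    "default": {"enable": False},
}
print(normalize_flat_dict(legacy))
```

Because the result is an ordered list with explicit keys, downstream code only ever has to reason about one canonical shape.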
1 parent c542c09 · commit 1cceb95 · 62 files changed: 3361 additions & 1320 deletions


CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
```diff
@@ -14,6 +14,10 @@ NVIDIA Model Optimizer Changelog
 - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
 
+**Backward Breaking Changes**
+
+- The ``quant_cfg`` field in quantization configs is now an **ordered list** of ``QuantizerCfgEntry`` dicts instead of a flat dictionary. Each entry specifies a ``quantizer_name`` wildcard, an optional ``parent_class`` filter, a ``cfg`` dict of quantizer attributes, and/or an ``enable`` flag. Entries are applied in list order with later entries overriding earlier ones. The old dict-based format is still accepted and automatically converted via ``normalize_quant_cfg_list()``, but now emits a ``DeprecationWarning``; new code should use the list format. All built-in configs (e.g. ``FP8_DEFAULT_CFG``, ``INT4_AWQ_CFG``, ``NVFP4_DEFAULT_CFG``), examples, and YAML recipes have been updated. See the :ref:`quant-cfg` documentation for the new format reference and migration guide.
+
 **Bug Fixes**
 
 - Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this.
```

docs/source/guides/1_quantization.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,6 +19,7 @@ Below, you can find the documentation for the quantization toolkit in ModelOpt:
    ./_basic_quantization.rst
    ./_choosing_quant_methods.rst
    ./_pytorch_quantization.rst
+   ./_quant_cfg.rst
    ./_customized_model_quantization.rst
    ./_compress_quantized_models.rst
    ./_onnx_quantization.rst
```

docs/source/guides/_pytorch_quantization.rst

Lines changed: 21 additions & 12 deletions
```diff
@@ -237,14 +237,16 @@ For debugging purposes or simple customizations, you can modify an existing conf
 
 .. code-block:: python
 
-    # Create a copy of the default INT8 configuration
-    config = mtq.INT8_DEFAULT_CFG.copy()
+    import copy
 
-    # Disable input quantizers for all layers
-    config["quant_cfg"]["*input_quantizer"]["enable"] = False
+    # Create a deep copy of the default INT8 configuration
+    config = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
+
+    # Disable input quantizers for all layers (appended last, so it takes precedence)
+    config["quant_cfg"].append({"quantizer_name": "*input_quantizer", "enable": False})
 
     # Disable all quantizers for layers matching the pattern "layer1.*"
-    config["quant_cfg"]["*layer1.*"] = {"enable": False}
+    config["quant_cfg"].append({"quantizer_name": "*layer1.*", "enable": False})
 
 Advanced Configuration Creation
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -253,18 +255,23 @@ For exploring new quantization recipes, you can compose a completely new configu
 
 .. code-block:: python
 
+    from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg
+
     # Custom configuration for INT4 block-wise weights and INT8 dynamic activations
     MY_CUSTOM_CONFIG = {
-        "quant_cfg": {
+        "quant_cfg": [
+            # Disable all quantizers by default, then enable selectively
+            {"quantizer_name": "*", "enable": False},
+
             # Configure weight quantizers with 4-bit precision and 128-element blocks
-            "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
+            {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, "enable": True},
 
             # Configure input quantizers with 8-bit dynamic quantization
-            "*input_quantizer": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}},
+            {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}},
 
             # Include default disabled quantizer configurations
-            **_default_disabled_quantizer_cfg,
-        },
+            *_default_disabled_quantizer_cfg,
+        ],
         "algorithm": "max",
     }
 
@@ -394,8 +401,10 @@ You can specify ``custom_calib`` as ``algorithm`` in ``quant_cfg`` to use it. He
 
     # create quantization configuration with "custom_calib" method
     quant_cfg = {
-        'quant_cfg': {'*weight_quantizer': ..},
-        'algorithm': {"method": 'custom_calib'},
+        'quant_cfg': [
+            {"quantizer_name": "*weight_quantizer", "cfg": {...}},
+        ],
+        'algorithm': {"method": 'custom_calib'},
     }
 
```