Skip to content

Commit 5baba0b

Browse files
committed
quant config
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent dc67001 commit 5baba0b

5 files changed

Lines changed: 173 additions & 148 deletions

File tree

modelopt/recipe/_config_loader.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def _parse_exmy(s: str) -> tuple[int, int] | str:
6262
return s
6363

6464

65-
def load_config(config_file: str | Path | Traversable) -> dict[str, Any] | list[Any]:
66-
"""Load a config yaml.
65+
def _load_raw_config(config_file: str | Path | Traversable) -> dict[str, Any] | list[Any]:
66+
"""Load a config YAML without resolving ``$import`` references.
6767
6868
config_file: Path to a config yaml file. The path suffix can be omitted.
6969
"""
@@ -141,3 +141,127 @@ def load_config(config_file: str | Path | Traversable) -> dict[str, Any] | list[
141141
f"Config file {config_path} must contain a YAML mapping or list, got {type(_raw).__name__}"
142142
)
143143
return _parse_exmy_num_bits(_raw)
144+
145+
146+
# ---------------------------------------------------------------------------
147+
# $import resolution
148+
# ---------------------------------------------------------------------------
149+
150+
_IMPORT_KEY = "$import"
151+
152+
153+
def _resolve_imports(
154+
data: dict[str, Any], _loading: frozenset[str] | None = None
155+
) -> dict[str, Any]:
156+
"""Resolve the ``imports`` section and ``$import`` references.
157+
158+
See ``modelopt.recipe.loader`` module docstring for the full specification.
159+
This function lives in ``_config_loader`` (not ``loader``) so that it can be
160+
used from ``modelopt.torch.quantization.config`` without circular imports.
161+
"""
162+
imports_dict = data.pop("imports", None)
163+
if not imports_dict:
164+
return data
165+
166+
if not isinstance(imports_dict, dict):
167+
raise ValueError(
168+
f"'imports' must be a dict mapping names to config paths, got: {type(imports_dict).__name__}"
169+
)
170+
171+
if _loading is None:
172+
_loading = frozenset()
173+
174+
# Build name → config mapping (recursively resolve nested imports)
175+
import_map: dict[str, Any] = {}
176+
for name, config_path in imports_dict.items():
177+
if not config_path:
178+
raise ValueError(f"Import {name!r} has an empty config path.")
179+
if config_path in _loading:
180+
raise ValueError(
181+
f"Circular import detected: {config_path!r} is already being loaded. "
182+
f"Import chain: {sorted(_loading)}"
183+
)
184+
snippet = _load_raw_config(config_path)
185+
if isinstance(snippet, dict) and "imports" in snippet:
186+
snippet = _resolve_imports(snippet, _loading | {config_path})
187+
# Unwrap _list_content (multi-document YAML: imports + list content)
188+
if isinstance(snippet, dict) and "_list_content" in snippet:
189+
snippet = snippet["_list_content"]
190+
import_map[name] = snippet
191+
192+
def _lookup(ref_name: str, context: str) -> Any:
193+
if ref_name not in import_map:
194+
raise ValueError(
195+
f"Unknown $import reference {ref_name!r} in {context}. "
196+
f"Available imports: {list(import_map.keys())}"
197+
)
198+
return import_map[ref_name]
199+
200+
def _resolve_list(entries: list[Any]) -> list[Any]:
201+
"""Resolve $import markers in a list of entries."""
202+
resolved: list[Any] = []
203+
for entry in entries:
204+
if isinstance(entry, dict) and _IMPORT_KEY in entry:
205+
if len(entry) > 1:
206+
raise ValueError(
207+
f"$import must be the only key in the dict, got extra keys: "
208+
f"{sorted(k for k in entry if k != _IMPORT_KEY)}"
209+
)
210+
imported = _lookup(entry[_IMPORT_KEY], "list entry")
211+
if not isinstance(imported, list):
212+
raise ValueError(
213+
f"$import {entry[_IMPORT_KEY]!r} in list must resolve to a "
214+
f"list, got {type(imported).__name__}."
215+
)
216+
resolved.extend(imported)
217+
elif (
218+
isinstance(entry, dict)
219+
and isinstance(entry.get("cfg"), dict)
220+
and _IMPORT_KEY in entry["cfg"]
221+
):
222+
ref = entry["cfg"].pop(_IMPORT_KEY)
223+
inline_keys = dict(entry["cfg"])
224+
ref_names = ref if isinstance(ref, list) else [ref]
225+
226+
merged: dict[str, Any] = {}
227+
for rname in ref_names:
228+
snippet = _lookup(rname, f"cfg of {entry}")
229+
if not isinstance(snippet, dict):
230+
raise ValueError(
231+
f"$import {rname!r} in cfg must resolve to a dict, "
232+
f"got {type(snippet).__name__}."
233+
)
234+
merged.update(snippet)
235+
236+
merged.update(inline_keys)
237+
entry["cfg"] = merged
238+
resolved.append(entry)
239+
else:
240+
resolved.append(entry)
241+
return resolved
242+
243+
# Resolve in quant_cfg (top-level or nested under quantize)
244+
for container in [data, data.get("quantize", {})]:
245+
if isinstance(container, dict):
246+
quant_cfg = container.get("quant_cfg")
247+
if isinstance(quant_cfg, list):
248+
container["quant_cfg"] = _resolve_list(quant_cfg)
249+
250+
# Resolve in _list_content (multi-document snippets)
251+
if "_list_content" in data:
252+
data["_list_content"] = _resolve_list(data["_list_content"])
253+
254+
return data
255+
256+
257+
def load_config(config_path: str | Path | Traversable) -> dict[str, Any] | list[Any]:
    """Load a YAML config and resolve all ``$import`` references.

    This is the primary config loading entry point: it reads the YAML file
    and, when the result is a mapping that declares an ``imports`` section,
    resolves every ``imports`` / ``$import`` directive before returning the
    final config dict or list.
    """
    data = _load_raw_config(config_path)
    has_imports = isinstance(data, dict) and "imports" in data
    return _resolve_imports(data) if has_imports else data

modelopt/recipe/loader.py

Lines changed: 4 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -22,139 +22,12 @@
2222
from pathlib import Path
2323
from typing import Any
2424

25-
from ._config_loader import BUILTIN_RECIPES_LIB, load_config
25+
from ._config_loader import BUILTIN_RECIPES_LIB, _load_raw_config, _resolve_imports, load_config
2626
from .config import ModelOptPTQRecipe, ModelOptRecipeBase, RecipeType
2727

2828
__all__ = ["load_config", "load_recipe"]
2929

3030

31-
_IMPORT_KEY = "$import"
32-
33-
34-
def _resolve_imports(
35-
data: dict[str, Any], _loading: frozenset[str] | None = None
36-
) -> dict[str, Any]:
37-
"""Resolve the ``imports`` section and ``$import`` references in a recipe.
38-
39-
An ``imports`` block is a dict mapping short names to config file paths::
40-
41-
imports:
42-
fp8: configs/numerics/fp8
43-
nvfp4: configs/numerics/nvfp4_dynamic
44-
45-
References use the explicit ``$import`` marker so they are never confused
46-
with literal string values::
47-
48-
quant_cfg:
49-
- $import: base_disable_all # entire entry replaced (or list spliced)
50-
- quantizer_name: '*weight_quantizer'
51-
cfg:
52-
$import: fp8 # cfg value replaced
53-
54-
Resolution is **recursive**: an imported snippet may itself contain an
55-
``imports`` section. Circular imports are detected and raise ``ValueError``.
56-
"""
57-
imports_dict = data.pop("imports", None)
58-
if not imports_dict:
59-
return data
60-
61-
if not isinstance(imports_dict, dict):
62-
raise ValueError(
63-
f"'imports' must be a dict mapping names to config paths, got: {type(imports_dict).__name__}"
64-
)
65-
66-
if _loading is None:
67-
_loading = frozenset()
68-
69-
# Build name → config mapping (recursively resolve nested imports)
70-
import_map: dict[str, Any] = {}
71-
for name, config_path in imports_dict.items():
72-
if not config_path:
73-
raise ValueError(f"Import {name!r} has an empty config path.")
74-
if config_path in _loading:
75-
raise ValueError(
76-
f"Circular import detected: {config_path!r} is already being loaded. "
77-
f"Import chain: {sorted(_loading)}"
78-
)
79-
snippet = load_config(config_path)
80-
if isinstance(snippet, dict) and "imports" in snippet:
81-
snippet = _resolve_imports(snippet, _loading | {config_path})
82-
# Unwrap _list_content (multi-document YAML: imports + list content)
83-
if isinstance(snippet, dict) and "_list_content" in snippet:
84-
snippet = snippet["_list_content"]
85-
import_map[name] = snippet
86-
87-
def _lookup(ref_name: str, context: str) -> Any:
88-
if ref_name not in import_map:
89-
raise ValueError(
90-
f"Unknown $import reference {ref_name!r} in {context}. "
91-
f"Available imports: {list(import_map.keys())}"
92-
)
93-
return import_map[ref_name]
94-
95-
def _resolve_list(entries: list[Any]) -> list[Any]:
96-
"""Resolve $import markers in a list of quant_cfg-style entries."""
97-
resolved: list[Any] = []
98-
for entry in entries:
99-
if isinstance(entry, dict) and _IMPORT_KEY in entry:
100-
# {$import: name} → splice imported list
101-
if len(entry) > 1:
102-
raise ValueError(
103-
f"$import must be the only key in the dict, got extra keys: "
104-
f"{sorted(k for k in entry if k != _IMPORT_KEY)}"
105-
)
106-
imported = _lookup(entry[_IMPORT_KEY], "list entry")
107-
if not isinstance(imported, list):
108-
raise ValueError(
109-
f"$import {entry[_IMPORT_KEY]!r} in list must resolve to a "
110-
f"list, got {type(imported).__name__}."
111-
)
112-
resolved.extend(imported)
113-
elif (
114-
isinstance(entry, dict)
115-
and isinstance(entry.get("cfg"), dict)
116-
and _IMPORT_KEY in entry["cfg"]
117-
):
118-
# cfg: {$import: name_or_list, ...inline} → import then override
119-
#
120-
# Precedence (lowest → highest):
121-
# 1. Imports in list order (later imports override earlier)
122-
# 2. Inline keys (override all imports)
123-
ref = entry["cfg"].pop(_IMPORT_KEY)
124-
inline_keys = dict(entry["cfg"])
125-
ref_names = ref if isinstance(ref, list) else [ref]
126-
127-
merged: dict[str, Any] = {}
128-
for name in ref_names:
129-
snippet = _lookup(name, f"cfg of {entry}")
130-
if not isinstance(snippet, dict):
131-
raise ValueError(
132-
f"$import {name!r} in cfg must resolve to a dict, "
133-
f"got {type(snippet).__name__}."
134-
)
135-
merged.update(snippet)
136-
137-
merged.update(inline_keys)
138-
entry["cfg"] = merged
139-
resolved.append(entry)
140-
else:
141-
resolved.append(entry)
142-
return resolved
143-
144-
# Resolve $import references in quant_cfg entries
145-
quantize = data.get("quantize")
146-
if isinstance(quantize, dict):
147-
quant_cfg = quantize.get("quant_cfg")
148-
if isinstance(quant_cfg, list):
149-
quantize["quant_cfg"] = _resolve_list(quant_cfg)
150-
151-
# Resolve $import references in _list_content (multi-document snippets)
152-
if "_list_content" in data:
153-
data["_list_content"] = _resolve_list(data["_list_content"])
154-
155-
return data
156-
157-
15831
def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traversable:
15932
"""Resolve a recipe path, checking the built-in library first then the filesystem.
16033
@@ -214,7 +87,7 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas
21487
The file must contain a ``metadata`` section with at least ``recipe_type``,
21588
plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes.
21689
"""
217-
raw = load_config(recipe_file)
90+
raw = _load_raw_config(recipe_file)
21891
assert isinstance(raw, dict), f"Recipe file {recipe_file} must be a YAML mapping."
21992
data = _resolve_imports(raw)
22093

@@ -247,7 +120,7 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase:
247120
f"Cannot find a recipe descriptor in {recipe_dir}. Looked for: recipe.yml, recipe.yaml"
248121
)
249122

250-
recipe_data = load_config(recipe_file)
123+
recipe_data = _load_raw_config(recipe_file)
251124
assert isinstance(recipe_data, dict), f"Recipe file {recipe_file} must be a YAML mapping."
252125
metadata = recipe_data.get("metadata", {})
253126
recipe_type = metadata.get("recipe_type")
@@ -266,7 +139,7 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase:
266139
f"Cannot find quantize in {recipe_dir}. Looked for: quantize.yml, quantize.yaml"
267140
)
268141
# Resolve imports: imports are in recipe.yml, quantize data is separate
269-
quantize_data = load_config(quantize_file)
142+
quantize_data = _load_raw_config(quantize_file)
270143
assert isinstance(quantize_data, dict), f"{quantize_file} must be a YAML mapping."
271144
combined: dict[str, Any] = {"quantize": quantize_data}
272145
imports = recipe_data.get("imports")

modelopt/torch/quantization/config.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
from pydantic import ValidationInfo, field_validator, model_validator
158158
from typing_extensions import Required, TypedDict
159159

160+
from modelopt.recipe._config_loader import load_config
160161
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
161162
from modelopt.torch.utils.network import ConstructorLike
162163

@@ -272,21 +273,7 @@ def find_quant_cfg_entry_by_path(
272273
"algorithm": "max",
273274
}
274275

275-
FP8_DEFAULT_CFG = {
276-
"quant_cfg": [
277-
*_base_disable_all,
278-
{
279-
"quantizer_name": "*weight_quantizer",
280-
"cfg": {"num_bits": (4, 3), "axis": None},
281-
},
282-
{
283-
"quantizer_name": "*input_quantizer",
284-
"cfg": {"num_bits": (4, 3), "axis": None},
285-
},
286-
*_default_disabled_quantizer_cfg,
287-
],
288-
"algorithm": "max",
289-
}
276+
FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/fp8_default")
290277

291278
MAMBA_MOE_FP8_AGGRESSIVE_CFG = {
292279
"quant_cfg": [
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# PTQ Preset Configs
2+
3+
This directory holds preset quantization configurations that serve as the
4+
single source of truth for the hardcoded `*_CFG` dicts in
5+
`modelopt.torch.quantization.config` (e.g., `FP8_DEFAULT_CFG`).
6+
7+
Each preset is a complete, self-contained config with `algorithm` and
8+
`quant_cfg` — ready to pass directly to `mtq.quantize()`. Presets compose
9+
from the reusable snippets in `configs/numerics/` and `configs/ptq/` via
10+
the `$import` system.
11+
12+
When adding a new preset, use existing snippets where possible and keep
13+
the YAML as the authoritative definition — the Python config should load
14+
from here rather than hardcoding the dict.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# FP8 per-tensor weight and activation (W8A8), max calibration.
# Equivalent to the hardcoded FP8_DEFAULT_CFG in config.py.
#
# The `imports` section maps short names to reusable config snippets; each
# `$import` entry under quant_cfg splices the named snippet's list in place.
imports:
  base_disable_all: configs/ptq/base_disable_all
  w8a8: configs/ptq/w8a8_fp8_fp8
  default_disabled: configs/ptq/default_disabled_quantizers

algorithm: max
# Spliced in order: disable-all baseline, FP8 W8A8 quantizer entries, then
# the default-disabled quantizer entries.
quant_cfg:
  - $import: base_disable_all
  - $import: w8a8
  - $import: default_disabled

0 commit comments

Comments
 (0)