Skip to content

Commit 00fa5bd

Browse files
authored
ModelOpt Framework, Recipe Lib, converting subset of existing recipes 1/N (#1000)
### What does this PR do? 1. start a new config system using yaml/yml files. 2. add a new top level package: modelopt_recipes I want it to be a top level package so we can make it clear that the modelopt package holds the code, this new package holds recipes 3. implement some of the existing quantization recipes using the new config system as model-agnostic general recipes, but not actually in use. these recipes sit inside modelopt_recipes/general/ptq/... 4. make sure the configs from the new config system match the existing configs 5. extend the hf_ptq script to enable recipe based PTQ 6. tested hf_ptq using both built-in and external config files. example script: ### Usage ```bash python examples/llm_ptq/hf_ptq.py \ --model Qwen/Qwen3-8B \ --recipe general/ptq/fp8_default-fp8_kv \ ... ``` ### Testing ```bash python examples/llm_ptq/hf_ptq.py \ --model Qwen/Qwen3-8B \ --recipe general/ptq/fp8_default-fp8_kv \ --export_path=fp8_default-fp8_kv \ --calib_size=16 \ --batch_size=0 \ --trust_remote_code \ --export_fmt=hf ``` ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain why. --> - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A <!--- Mandatory --> - Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory for new features or examples. 
--> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes or backward incompatible changes. --> ### Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Recipe-driven PTQ workflows via YAML recipes and new recipe loader; CLI gains a --recipe option and --pyt_ckpt_path renamed to --model. * Many new PTQ recipe and config presets (FP8, INT4/INT8, NVFP4, MXFPx, KV-cache variants) and improved runtime config loading/merging. * **Documentation** * Added READMEs describing recipe/config layout. * **Tests** * New unit tests covering config loading, inheritance and recipe loading. * **Chores** * Added YAML/OmegaConf runtime support and packaging of recipe YAMLs. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent cb1ff32 commit 00fa5bd

File tree

21 files changed

+1015
-71
lines changed

21 files changed

+1015
-71
lines changed

examples/llm_ptq/example_utils.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
except ImportError:
4646
snapshot_download = None
4747

48-
import modelopt.torch.quantization as mtq
4948
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
5049

5150
logger = logging.getLogger(__name__)
@@ -199,22 +198,13 @@ def calibrate_loop(_model):
199198

200199
def build_quant_cfg(
201200
qformat,
202-
kv_cache_qformat,
201+
quant_cfg,
203202
awq_block_size,
204203
model_type,
205-
quant_cfg_choices,
206-
kv_quant_cfg_choices,
207204
moe_calib_experts_ratio: float | None = None,
208205
) -> dict[str, Any]:
209-
quant_cfg = {}
210-
assert qformat in quant_cfg_choices, (
211-
f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
212-
)
213-
214-
quant_cfg = quant_cfg_choices[qformat]
215-
216-
if "awq" in qformat:
217-
quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
206+
quant_cfg = copy.deepcopy(quant_cfg)
207+
if "awq" in str(quant_cfg.get("algorithm")):
218208
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
219209
if isinstance(weight_quantizer, list):
220210
weight_quantizer = weight_quantizer[0]
@@ -226,16 +216,6 @@ def build_quant_cfg(
226216
if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
227217
quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
228218

229-
enable_quant_kv_cache = kv_cache_qformat != "none"
230-
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
231-
232-
# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
233-
if enable_quant_kv_cache:
234-
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
235-
quant_cfg,
236-
getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
237-
)
238-
239219
if moe_calib_experts_ratio:
240220
assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
241221
if isinstance(quant_cfg["algorithm"], str):

examples/llm_ptq/hf_ptq.py

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import modelopt.torch.opt as mto
5151
import modelopt.torch.quantization as mtq
5252
import modelopt.torch.sparsity as mts
53+
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
5354
from modelopt.torch.export import (
5455
export_hf_checkpoint,
5556
export_speculative_decoding,
@@ -262,7 +263,7 @@ def auto_quantize(
262263
assert qformat_list, "No quantization formats provided"
263264
# Check if all provided quantization formats are supported
264265
assert all(
265-
args.qformat
266+
qformat
266267
in [
267268
"fp8",
268269
"int8_sq",
@@ -277,7 +278,7 @@ def auto_quantize(
277278
"nvfp4_omlp_only",
278279
"mxfp8",
279280
]
280-
for args.qformat in qformat_list
281+
for qformat in qformat_list
281282
), "One or more quantization formats provided are not supported for unified checkpoint export"
282283

283284
def loss_func(output, data):
@@ -548,9 +549,6 @@ def mono_quantize(
548549
print("Quantization will only be applied to the decoder (text generation) component")
549550

550551
if not model_is_already_quantized or calibration_only:
551-
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
552-
print("Applying nvfp4 quantization (MoE only) for gpt-oss")
553-
554552
# quantize the model
555553

556554
use_calibration = need_calibration(quant_cfg)
@@ -746,8 +744,6 @@ def pre_quantize(
746744
)
747745
else:
748746
generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
749-
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
750-
print("Applying nvfp4 quantization (MoE only) for gpt-oss")
751747

752748
return preview_input_ids, generated_ids_before_ptq
753749

@@ -923,38 +919,42 @@ def quantize_main(
923919

924920
else:
925921
# mono quantization
926-
assert len(args.qformat.split(",")) == 1, (
927-
"Plain quantization supports only one quantization format."
928-
)
929922

930-
assert (
931-
args.qformat
932-
in [
933-
"int8_wo",
934-
"int4_awq",
935-
"fp8",
936-
"nvfp4",
937-
"nvfp4_awq",
938-
"nvfp4_mse",
939-
"w4a8_awq",
940-
"fp8_pb_wo",
941-
"w4a8_mxfp4_fp8",
942-
"nvfp4_mlp_only",
943-
"nvfp4_omlp_only",
944-
"mxfp8",
945-
]
946-
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
947-
), f"Plain quantization format {args.qformat} not supported for HF export path"
948-
949-
quant_cfg = build_quant_cfg(
950-
args.qformat,
951-
args.kv_cache_qformat,
952-
args.awq_block_size,
953-
model_type,
954-
QUANT_CFG_CHOICES,
955-
KV_QUANT_CFG_CHOICES,
956-
args.moe_calib_experts_ratio,
957-
)
923+
if args.recipe is not None:
924+
print(f"Use recipe {args.recipe} for quantization")
925+
recipe = load_recipe(args.recipe)
926+
assert isinstance(recipe, ModelOptPTQRecipe), (
927+
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
928+
)
929+
quant_cfg = recipe.ptq_cfg
930+
931+
else:
932+
assert len(args.qformat.split(",")) == 1, (
933+
"Plain quantization supports only one quantization format."
934+
)
935+
936+
assert args.qformat in QUANT_CFG_CHOICES, (
937+
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
938+
)
939+
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
940+
941+
quant_cfg = build_quant_cfg(
942+
args.qformat,
943+
quant_cfg,
944+
args.awq_block_size,
945+
model_type,
946+
args.moe_calib_experts_ratio,
947+
)
948+
949+
enable_quant_kv_cache = args.kv_cache_qformat != "none"
950+
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
951+
952+
# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
953+
if enable_quant_kv_cache:
954+
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
955+
quant_cfg,
956+
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
957+
)
958958

959959
# Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92)
960960
# These layers are typically speculative decoding layers that should be exported as-is
@@ -1013,9 +1013,21 @@ def parse_args() -> argparse.Namespace:
10131013
parser = argparse.ArgumentParser(description=__doc__)
10141014
parser.add_argument(
10151015
"--pyt_ckpt_path",
1016-
help="Specify where the PyTorch checkpoint path is",
1016+
"--model",
1017+
help=(
1018+
"Model name or path to the PyTorch checkpoint to be quantized. "
1019+
"Can be a local path or a Huggingface model name."
1020+
),
10171021
required=True,
10181022
)
1023+
parser.add_argument(
1024+
"--recipe",
1025+
help=(
1026+
"PTQ recipe YAML file or name without suffix (e.g. general/ptq/nvfp4_default-fp8_kv)."
1027+
),
1028+
default=None,
1029+
)
1030+
10191031
parser.add_argument("--device", default="cuda")
10201032
parser.add_argument(
10211033
"--qformat",

examples/llm_ptq/multinode_ptq.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,16 +327,25 @@ def main(args):
327327
trust_remote_code=args.trust_remote_code,
328328
)
329329

330-
# Build quantization config
330+
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
331+
331332
quant_cfg = build_quant_cfg(
332333
args.qformat,
333-
args.kv_cache_qformat,
334+
quant_cfg,
334335
args.awq_block_size,
335336
model_type,
336-
QUANT_CFG_CHOICES,
337-
KV_QUANT_CFG_CHOICES,
338337
)
339338

339+
enable_quant_kv_cache = args.kv_cache_qformat != "none"
340+
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
341+
342+
# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
343+
if enable_quant_kv_cache:
344+
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
345+
quant_cfg,
346+
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
347+
)
348+
340349
# Quantize the model
341350
if accelerator.is_main_process:
342351
print("Starting quantization...")

modelopt/recipe/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Module for the ModelOpt recipe lib.
17+
18+
``modelopt.recipe`` contains tooling to:
19+
20+
* load and store model optimization recipes
21+
* (TODO) utilities to manipulate the recipes, such as merging multiple recipes together, or
22+
overriding some fields in a recipe with user-provided values.
23+
24+
"""
25+
26+
from .config import *
27+
from .loader import *

modelopt/recipe/_config_loader.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""YAML config loading utilities.
17+
18+
This module is intentionally free of ``modelopt.torch`` imports so that
19+
``modelopt.torch.quantization.config`` can import :func:`load_config` without
20+
triggering a circular import through ``modelopt.recipe.loader``.
21+
"""
22+
23+
from importlib.resources import files
24+
25+
try:
26+
from importlib.resources.abc import Traversable
27+
except ImportError: # Python < 3.11
28+
from importlib.abc import Traversable
29+
import re
30+
from pathlib import Path
31+
from typing import Any
32+
33+
import yaml
34+
35+
# Root of all built-in recipes. Users can create their own recipes.
36+
BUILTIN_RECIPES_LIB = files("modelopt_recipes")
37+
38+
_EXMY_RE = re.compile(r"^[Ee](\d+)[Mm](\d+)$")
39+
_EXMY_KEYS = frozenset({"num_bits", "scale_bits"})
40+
41+
42+
def _parse_exmy_num_bits(obj: Any) -> Any:
43+
"""Recursively convert ``ExMy`` strings in ``num_bits`` / ``scale_bits`` to ``(x, y)`` tuples."""
44+
if isinstance(obj, dict):
45+
return {
46+
k: (
47+
_parse_exmy(v)
48+
if k in _EXMY_KEYS and isinstance(v, str)
49+
else _parse_exmy_num_bits(v)
50+
)
51+
for k, v in obj.items()
52+
}
53+
if isinstance(obj, list):
54+
return [_parse_exmy_num_bits(item) for item in obj]
55+
return obj
56+
57+
58+
def _parse_exmy(s: str) -> tuple[int, int] | str:
59+
m = _EXMY_RE.match(s)
60+
if m:
61+
return (int(m.group(1)), int(m.group(2)))
62+
return s
63+
64+
65+
def load_config(config_file: str | Path | Traversable) -> dict[str, Any]:
66+
"""Load a config yaml.
67+
68+
config_file: Path to a config yaml file. The path suffix can be omitted.
69+
"""
70+
paths_to_check: list[Path | Traversable] = []
71+
if isinstance(config_file, str):
72+
if not config_file.endswith(".yml") and not config_file.endswith(".yaml"):
73+
paths_to_check.append(Path(f"{config_file}.yml"))
74+
paths_to_check.append(Path(f"{config_file}.yaml"))
75+
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yml"))
76+
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yaml"))
77+
else:
78+
paths_to_check.append(Path(config_file))
79+
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(config_file))
80+
elif isinstance(config_file, Path):
81+
if config_file.suffix in (".yml", ".yaml"):
82+
paths_to_check.append(config_file)
83+
if not config_file.is_absolute():
84+
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(str(config_file)))
85+
else:
86+
paths_to_check.append(Path(f"{config_file}.yml"))
87+
paths_to_check.append(Path(f"{config_file}.yaml"))
88+
if not config_file.is_absolute():
89+
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yml"))
90+
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yaml"))
91+
elif isinstance(config_file, Traversable):
92+
paths_to_check.append(config_file)
93+
else:
94+
raise ValueError(f"Invalid config file of {config_file}")
95+
96+
config_path = None
97+
for path in paths_to_check:
98+
if path.is_file():
99+
config_path = path
100+
break
101+
if not config_path:
102+
raise ValueError(
103+
f"Cannot find config file of {config_file}, paths checked: {paths_to_check}"
104+
)
105+
106+
_raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
107+
if _raw is None:
108+
return {}
109+
if not isinstance(_raw, dict):
110+
raise ValueError(
111+
f"Config file {config_path} must contain a YAML mapping, got {type(_raw).__name__}"
112+
)
113+
return _parse_exmy_num_bits(_raw)

0 commit comments

Comments
 (0)