Skip to content

Commit f5650bd

Browse files
authored
Schematize config loading and quantizer config entries (#1405)
### What does this PR do? This PR makes ModelOpt config loading schema-aware and moves `quant_cfg` entries from `TypedDict` validation to Pydantic validation. Key changes: - `load_config()` now returns validated schema instances when a schema is provided through `schema_type=...` or declared with `# modelopt-schema:`. - Without a schema, it still returns the raw resolved dict/list. - With a schema, it returns the validated schema object, such as `QuantizeConfig` or `list[QuantizerCfgEntry]`. - `QuantizerCfgEntry` is now a `ModeloptBaseConfig` Pydantic model. - The “must specify `cfg`, `enable`, or both” rule is enforced wherever entries are constructed. - Enabled entries must provide non-empty `cfg` values when `cfg` is present. - Existing mapping-style access like `entry["cfg"]` and `entry.get("enable")` continues to work. - `RecipeMetadataConfig` is now a `ModeloptBaseConfig`, so recipe metadata uses the same schema validation path. - Recipe loading now delegates shape validation to `load_config()` instead of manually checking loaded dicts. - `normalize_quant_cfg_list()` now accepts: - `Sequence[QuantizerCfgEntry]` - `Sequence[Mapping[str, Any]]` - legacy flat mapping configs, with deprecation warnings - Public quantization config constants remain plain dict/list structures for backward compatibility, even when they are built from schema-validated YAML snippets. ### Behavior changes - `load_config()` may now return a schema instance instead of a raw `dict`/`list` when a schema is available. Callers that checked `isinstance(result, dict)` should use `isinstance(result, Mapping)` or check the expected schema type. - Normalized `quant_cfg` entries are now `QuantizerCfgEntry` objects internally. Mapping-style access is preserved, but direct equality checks against literal dicts should use `entry.model_dump()`. ### Testing - `python -m pytest tests/unit/recipe/test_loader.py` - `python -m pytest tests/unit/torch/quantization/test_config_validation.py` - `ruff check` / `ruff format --check` on touched files - Commit-time pre-commit hooks passed --------- Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent 7038dec commit f5650bd

11 files changed

Lines changed: 510 additions & 291 deletions

File tree

modelopt/onnx/llm_export_utils/quantization_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"):
6969
else:
7070
raise ValueError(f"Unsupported precision: {precision}")
7171

72-
quant_cfg_list: list = [
73-
e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e
74-
]
72+
quant_cfg_list: list = [e for e in quant_cfg["quant_cfg"] if "quantizer_name" in e]
7573

7674
if lm_head_precision == "fp8":
7775
quant_cfg_list.append(

modelopt/recipe/config.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
import warnings
2121
from enum import Enum
2222

23-
from pydantic import field_validator, model_validator
24-
from typing_extensions import NotRequired, TypedDict
23+
from pydantic import Field, model_validator
2524

2625
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
27-
from modelopt.torch.quantization.config import QuantizeConfig
26+
from modelopt.torch.quantization.config import QuantizeConfig # noqa: TC001
2827
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig
2928
from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs
3029
from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs
@@ -43,14 +42,21 @@ class RecipeType(str, Enum):
4342
# QAT = "qat" # Not implemented yet, will be added in the future.
4443

4544

46-
class RecipeMetadataConfig(TypedDict):
47-
"""YAML shape of the recipe metadata section."""
45+
_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe."
4846

49-
recipe_type: RecipeType
50-
description: NotRequired[str]
5147

48+
class RecipeMetadataConfig(ModeloptBaseConfig):
49+
"""YAML shape of the recipe metadata section."""
5250

53-
_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe."
51+
recipe_type: RecipeType = Field(
52+
title="Recipe type",
53+
description="The type of the recipe (e.g. PTQ).",
54+
)
55+
description: str = ModeloptField(
56+
default=_DEFAULT_RECIPE_DESCRIPTION,
57+
title="Description",
58+
description="Human-readable description of the recipe.",
59+
)
5460

5561

5662
def _metadata_field(recipe_type: RecipeType):
@@ -69,45 +75,32 @@ class ModelOptRecipeBase(ModeloptBaseConfig):
6975
If a layer name matches ``"*output_layer*"``, the attributes will be replaced with ``{"enable": False}``.
7076
"""
7177

72-
metadata: RecipeMetadataConfig = ModeloptField(
73-
default={"recipe_type": RecipeType.PTQ, "description": _DEFAULT_RECIPE_DESCRIPTION},
78+
metadata: RecipeMetadataConfig = Field(
7479
title="Metadata",
75-
description="Recipe metadata containing the recipe type and description.",
76-
validate_default=True,
80+
description="Recipe metadata containing the recipe type and description. "
81+
"Required: a recipe without a ``metadata`` section is rejected so that a "
82+
"missing section can't silently fall back to a default recipe type.",
7783
)
7884

79-
@field_validator("metadata")
80-
@classmethod
81-
def validate_metadata(cls, metadata: RecipeMetadataConfig) -> RecipeMetadataConfig:
82-
"""Validate recipe metadata and fill defaults for optional fields."""
83-
if metadata["recipe_type"] not in RecipeType:
84-
raise ValueError(
85-
f"Unsupported recipe type: {metadata['recipe_type']}. "
86-
f"Only {list(RecipeType)} are currently supported."
87-
)
88-
return {"description": _DEFAULT_RECIPE_DESCRIPTION, **metadata}
89-
9085
@property
9186
def recipe_type(self) -> RecipeType:
9287
"""Return the recipe type from metadata."""
93-
return self.metadata["recipe_type"]
88+
return self.metadata.recipe_type
9489

9590
@property
9691
def description(self) -> str:
9792
"""Return the recipe description from metadata."""
98-
return self.metadata.get("description", _DEFAULT_RECIPE_DESCRIPTION)
93+
return self.metadata.description
9994

10095

10196
class ModelOptPTQRecipe(ModelOptRecipeBase):
10297
"""Our config class for PTQ recipes."""
10398

104-
metadata: RecipeMetadataConfig = _metadata_field(RecipeType.PTQ)
105-
106-
quantize: QuantizeConfig = ModeloptField(
107-
default=QuantizeConfig(),
99+
quantize: QuantizeConfig = Field(
108100
title="PTQ config",
109-
description="PTQ config containing quant_cfg and algorithm.",
110-
validate_default=True,
101+
description="PTQ config containing quant_cfg and algorithm. Required: a PTQ "
102+
"recipe without a ``quantize`` section is rejected so that a missing section "
103+
"can't silently fall back to the default INT8 config.",
111104
)
112105

113106

modelopt/recipe/loader.py

Lines changed: 51 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929

3030
from .config import (
3131
RECIPE_TYPE_TO_CLASS,
32-
ModelOptDFlashRecipe,
33-
ModelOptEagleRecipe,
34-
ModelOptMedusaRecipe,
3532
ModelOptPTQRecipe,
3633
ModelOptRecipeBase,
3734
RecipeMetadataConfig,
@@ -40,6 +37,16 @@
4037

4138
__all__ = ["load_config", "load_recipe"]
4239

40+
# Each recipe type's mandatory top-level body section. Checked at the loader level (on the
41+
# raw YAML, before pydantic fills in defaults) so the user sees a clear "PTQ recipe file X
42+
# must contain 'quantize'" instead of pydantic's generic missing-field error.
43+
_REQUIRED_SECTION_PER_RECIPE_TYPE: dict[RecipeType, str] = {
44+
RecipeType.PTQ: "quantize",
45+
RecipeType.SPECULATIVE_EAGLE: "eagle",
46+
RecipeType.SPECULATIVE_DFLASH: "dflash",
47+
RecipeType.SPECULATIVE_MEDUSA: "medusa",
48+
}
49+
4350

4451
def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traversable:
4552
"""Resolve a recipe path, checking the built-in library first then the filesystem.
@@ -148,63 +155,48 @@ def _load_recipe_from_file(
148155
plus the algorithm-specific section (``quantize`` / ``eagle`` / ``dflash`` / ``medusa``).
149156
"""
150157
rtype = _peek_recipe_type(recipe_file)
151-
schema_type = RECIPE_TYPE_TO_CLASS.get(rtype) if rtype is not None else None
152-
data = load_config(recipe_file, schema_type=schema_type)
153-
if not isinstance(data, dict):
154-
raise ValueError(
155-
f"Recipe file {recipe_file} must be a YAML mapping, got {type(data).__name__}."
156-
)
158+
if rtype is None:
159+
raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.")
160+
schema_class = RECIPE_TYPE_TO_CLASS.get(rtype)
161+
if schema_class is None:
162+
raise ValueError(f"Unsupported recipe type: {rtype!r}")
163+
164+
# Pre-flight check on the *raw* YAML so the user sees a clear loader-level error
165+
# rather than a generic pydantic missing-field error. Speculative recipes' body
166+
# sections have field-level defaults, so this check is what keeps their loader
167+
# semantics consistent with PTQ.
168+
required_section = _REQUIRED_SECTION_PER_RECIPE_TYPE.get(rtype)
169+
if required_section is not None:
170+
import yaml
171+
172+
raw = yaml.safe_load(recipe_file.read_text()) or {}
173+
if not isinstance(raw, dict) or required_section not in raw:
174+
kind = (
175+
rtype.value.split("_", 1)[-1].upper() if "_" in rtype.value else rtype.value.upper()
176+
)
177+
raise ValueError(f"{kind} recipe file {recipe_file} must contain {required_section!r}.")
178+
179+
# Passing ``schema_type=schema_class`` to ``load_config`` enables typed-list
180+
# ``$import`` resolution (e.g. ``$import: disable_all`` spliced into
181+
# ``quantize.quant_cfg`` needs to know the list's element schema is
182+
# :class:`QuantizerCfgEntry`). The return value is already a validated schema
183+
# instance.
157184
if overrides:
185+
# Overrides have to be applied before pydantic validation. Round-trip through
186+
# ``model_dump()`` so $imports are resolved and the dict has the resolved shape;
187+
# then splice the dotlist values and re-validate.
188+
recipe = load_config(recipe_file, schema_type=schema_class)
189+
data = recipe.model_dump()
158190
data = _apply_dotlist(data, overrides)
191+
return schema_class.model_validate(data)
159192

160-
metadata = data.get("metadata", {})
161-
if not isinstance(metadata, dict):
193+
recipe = load_config(recipe_file, schema_type=schema_class)
194+
if not isinstance(recipe, schema_class):
162195
raise ValueError(
163-
f"Recipe file {recipe_file} field 'metadata' must be a mapping, "
164-
f"got {type(metadata).__name__}."
165-
)
166-
recipe_type = metadata.get("recipe_type")
167-
if recipe_type is None:
168-
raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.")
169-
170-
if recipe_type == RecipeType.PTQ:
171-
if "quantize" not in data:
172-
raise ValueError(f"PTQ recipe file {recipe_file} must contain 'quantize'.")
173-
return ModelOptPTQRecipe(
174-
metadata=metadata,
175-
quantize=data["quantize"],
196+
f"Recipe file {recipe_file} must produce a {schema_class.__name__}, "
197+
f"got {type(recipe).__name__}."
176198
)
177-
if recipe_type == RecipeType.SPECULATIVE_EAGLE:
178-
if "eagle" not in data:
179-
raise ValueError(f"EAGLE recipe file {recipe_file} must contain 'eagle'.")
180-
return ModelOptEagleRecipe(
181-
metadata=metadata,
182-
model=data.get("model") or {},
183-
data=data.get("data") or {},
184-
training=data.get("training") or {},
185-
eagle=data["eagle"],
186-
)
187-
if recipe_type == RecipeType.SPECULATIVE_DFLASH:
188-
if "dflash" not in data:
189-
raise ValueError(f"DFlash recipe file {recipe_file} must contain 'dflash'.")
190-
return ModelOptDFlashRecipe(
191-
metadata=metadata,
192-
model=data.get("model") or {},
193-
data=data.get("data") or {},
194-
training=data.get("training") or {},
195-
dflash=data["dflash"],
196-
)
197-
if recipe_type == RecipeType.SPECULATIVE_MEDUSA:
198-
if "medusa" not in data:
199-
raise ValueError(f"Medusa recipe file {recipe_file} must contain 'medusa'.")
200-
return ModelOptMedusaRecipe(
201-
metadata=metadata,
202-
model=data.get("model") or {},
203-
data=data.get("data") or {},
204-
training=data.get("training") or {},
205-
medusa=data["medusa"],
206-
)
207-
raise ValueError(f"Unsupported recipe type: {recipe_type!r}")
199+
return recipe
208200

209201

210202
def _find_recipe_section_file(
@@ -229,25 +221,10 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase:
229221
quantize.
230222
"""
231223
metadata_file = _find_recipe_section_file(recipe_dir, "metadata")
232-
233224
metadata = load_config(metadata_file, schema_type=RecipeMetadataConfig)
234-
if not isinstance(metadata, dict):
235-
raise ValueError(
236-
f"Metadata file {metadata_file} must be a YAML mapping, got {type(metadata).__name__}."
237-
)
238-
recipe_type = metadata.get("recipe_type")
239-
if recipe_type is None:
240-
raise ValueError(f"Metadata file {metadata_file} must contain a 'recipe_type' field.")
241225

242-
if recipe_type == RecipeType.PTQ:
226+
if metadata.recipe_type == RecipeType.PTQ:
243227
quantize_file = _find_recipe_section_file(recipe_dir, "quantize")
244-
quantize_data = load_config(quantize_file, schema_type=QuantizeConfig)
245-
if not isinstance(quantize_data, dict):
246-
raise ValueError(
247-
f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}."
248-
)
249-
return ModelOptPTQRecipe(
250-
metadata=metadata,
251-
quantize=quantize_data,
252-
)
253-
raise ValueError(f"Unsupported recipe type: {recipe_type!r}")
228+
quantize_cfg = load_config(quantize_file, schema_type=QuantizeConfig)
229+
return ModelOptPTQRecipe(metadata=metadata, quantize=quantize_cfg)
230+
raise ValueError(f"Unsupported recipe type: {metadata.recipe_type!r}")

modelopt/torch/opt/config.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import fnmatch
1919
import json
20-
from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView
20+
from collections.abc import Callable, ItemsView, Iterator, KeysView, MutableMapping, ValuesView
2121
from typing import Any, TypeAlias
2222

2323
import torch
@@ -57,11 +57,18 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802
5757
# TODO: expand config classes to searcher
5858

5959

60-
class ModeloptBaseConfig(BaseModel):
60+
class ModeloptBaseConfig(BaseModel, MutableMapping):
6161
"""Our config base class for mode configuration.
6262
6363
The base class extends the capabilities of pydantic's BaseModel to provide additional methods
6464
and properties for easier access and manipulation of the configuration.
65+
66+
Inherits from :class:`collections.abc.MutableMapping` so instances satisfy
67+
``isinstance(cfg, Mapping)`` / ``isinstance(cfg, MutableMapping)`` checks and pick up the
68+
mixin methods (``pop``, ``popitem``, ``setdefault``, ``clear``). Schema fields are fixed,
69+
so ``__delitem__`` raises :class:`TypeError`; the inherited ``pop`` / ``clear`` /
70+
``popitem`` therefore also raise on any existing key, while ``pop(key, default)`` for a
71+
missing key still returns the default normally.
6572
"""
6673

6774
model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True)
@@ -110,18 +117,49 @@ def __contains__(self, key: str) -> bool:
110117
return False
111118

112119
def __getitem__(self, key: str) -> Any:
113-
"""Get the value for the given key (can be name or alias of field)."""
114-
return getattr(self, self.get_field_name_from_key(key))
120+
"""Get the value for the given key (can be name or alias of field).
121+
122+
Raises :class:`KeyError` for missing keys so the class behaves like a regular
123+
:class:`Mapping` — required for the inherited ``MutableMapping`` mixin methods
124+
(``pop``, ``setdefault``, ...) to dispatch correctly.
125+
"""
126+
try:
127+
return getattr(self, self.get_field_name_from_key(key))
128+
except AttributeError:
129+
raise KeyError(key) from None
115130

116131
def __setitem__(self, key: str, value: Any) -> None:
117-
"""Set the value for the given key (can be name or alias of field)."""
118-
setattr(self, self.get_field_name_from_key(key), value)
132+
"""Set the value for the given key (can be name or alias of field).
133+
134+
Raises :class:`KeyError` (not :class:`AttributeError`) for unknown keys so the
135+
class matches the :class:`MutableMapping` protocol — both for direct
136+
``cfg["unknown"] = value`` writes and for inherited mixin helpers like
137+
``setdefault`` that write through ``__setitem__``.
138+
"""
139+
try:
140+
setattr(self, self.get_field_name_from_key(key), value)
141+
except AttributeError:
142+
raise KeyError(key) from None
143+
144+
def __delitem__(self, key: str) -> None:
145+
"""Reject key deletion.
146+
147+
``ModeloptBaseConfig`` exposes a fixed pydantic schema, so removing a key is
148+
ill-defined: schema fields can't disappear, and silently resetting them to their
149+
defaults would surprise callers. Raise ``TypeError`` instead. Defined so the
150+
class fully satisfies the ``MutableMapping`` protocol (``__delitem__`` is
151+
required), without committing to actual deletion semantics.
152+
"""
153+
raise TypeError(
154+
f"{type(self).__name__} does not support key deletion; schema fields are "
155+
f"fixed (attempted to delete {key!r})."
156+
)
119157

120158
def get(self, key: str, default: Any = None) -> Any:
121159
"""Get the value for the given key (can be name or alias) or default if not found."""
122160
try:
123161
return self[key]
124-
except AttributeError:
162+
except KeyError:
125163
return default
126164

127165
def __len__(self) -> int:

0 commit comments

Comments
 (0)