Skip to content

Commit d33ee36

Browse files
committed
Tighten ModeloptBaseConfig mapping semantics
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent 0d31d46 commit d33ee36

5 files changed

Lines changed: 112 additions & 80 deletions

File tree

modelopt/torch/opt/config.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,15 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802
6060
class ModeloptBaseConfig(BaseModel, MutableMapping[str, Any]):
6161
"""Our config base class for mode configuration.
6262
63-
The base class extends the capabilities of pydantic's BaseModel to provide additional methods
64-
and properties for easier access and manipulation of the configuration.
63+
The base class extends pydantic's BaseModel with a mapping interface so schema-backed
64+
config objects can keep the dict-style access patterns used by older ModelOpt code.
65+
66+
This is intentionally a fixed-key mutable mapping instead of a general dict. The mapping
67+
keys are the model fields exposed through their aliases when present, and lookups accept
68+
either a field name or its alias. Values are read and written through the pydantic model, so
69+
assignment validation still applies. New keys cannot be inserted, and existing keys cannot
70+
be deleted because the schema defines the complete key set; callers that need omission
71+
semantics should use model_dump(exclude_unset=True) or the explicit_* helpers.
6572
"""
6673

6774
model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True)
@@ -111,44 +118,42 @@ def __contains__(self, key: str) -> bool:
111118

112119
def __getitem__(self, key: str) -> Any:
113120
"""Get the value for the given key (can be name or alias of field)."""
114-
return getattr(self, self.get_field_name_from_key(key))
121+
try:
122+
return getattr(self, self.get_field_name_from_key(key))
123+
except AttributeError as e:
124+
raise KeyError(key) from e
115125

116126
def __setitem__(self, key: str, value: Any) -> None:
117-
"""Set the value for the given key (can be name or alias of field)."""
118-
setattr(self, self.get_field_name_from_key(key), value)
127+
"""Set an existing field by name or alias, preserving pydantic assignment validation."""
128+
try:
129+
field_name = self.get_field_name_from_key(key)
130+
except AttributeError as e:
131+
raise KeyError(key) from e
132+
if field_name not in type(self).model_fields:
133+
raise KeyError(key)
134+
setattr(self, field_name, value)
119135

120136
def __delitem__(self, key: str) -> None:
121-
"""Unset the given key so exclude_unset dumps omit it."""
137+
"""Reject deletion because ModeloptBaseConfig exposes a fixed schema key set."""
122138
try:
123-
field_name = self.get_field_name_from_key(key)
139+
self.get_field_name_from_key(key)
124140
except AttributeError as e:
125141
raise KeyError(key) from e
126-
if field_name in self._iterable_model_extra:
127-
assert self.model_extra is not None
128-
del self.model_extra[field_name]
129-
self.model_fields_set.discard(field_name)
130-
return
131-
132-
field_info = type(self).model_fields[field_name]
133-
default = field_info.get_default(call_default_factory=True)
134-
if default is PydanticUndefined:
135-
raise KeyError(f"Key {key} cannot be unset because it has no default.")
136-
self.__dict__[field_name] = default
137-
self.model_fields_set.discard(field_name)
142+
raise TypeError("Config mapping keys are fixed and cannot be deleted.")
138143

139144
def get(self, key: str, default: Any = None) -> Any:
140145
"""Get the value for the given key (can be name or alias) or default if not found."""
141146
try:
142147
return self[key]
143-
except AttributeError:
148+
except KeyError:
144149
return default
145150

146151
def __len__(self) -> int:
147-
"""Return the length of the config."""
148-
return len(self.model_fields) + len(self._iterable_model_extra)
152+
"""Return the number of schema and extra keys exposed by the mapping."""
153+
return len(type(self).model_fields) + len(self._iterable_model_extra)
149154

150155
def __iter__(self) -> Iterator[str]:
151-
"""Iterate over aliases (or name if alias is not defined) of fields."""
156+
"""Iterate over schema keys, preferring aliases over field names."""
152157
for field_name, field_info in type(self).model_fields.items():
153158
yield field_info.alias or field_name
154159
yield from self._iterable_model_extra
@@ -157,6 +162,29 @@ def _get_kv_dict(self) -> dict[str, Any]:
157162
"""Return a dictionary with keys as aliases if possible."""
158163
return {k: self[k] for k in self}
159164

165+
def iter_explicit_keys(self) -> Iterator[str]:
166+
"""Iterate over explicitly set schema keys, preferring aliases over field names."""
167+
for field_name, field_info in type(self).model_fields.items():
168+
if field_name in self.model_fields_set:
169+
yield field_info.alias or field_name
170+
yield from self._iterable_model_extra
171+
172+
def _get_explicit_kv_dict(self) -> dict[str, Any]:
173+
"""Return explicitly set key-value pairs with keys as aliases if possible."""
174+
return {k: self[k] for k in self.iter_explicit_keys()}
175+
176+
def explicit_keys(self) -> KeysView[str]:
177+
"""Return the explicitly set keys of the config."""
178+
return self._get_explicit_kv_dict().keys()
179+
180+
def explicit_values(self) -> ValuesView[Any]:
181+
"""Return the explicitly set values of the config."""
182+
return self._get_explicit_kv_dict().values()
183+
184+
def explicit_items(self) -> ItemsView[str, Any]:
185+
"""Return the explicitly set items of the config with keys as aliases if possible."""
186+
return self._get_explicit_kv_dict().items()
187+
160188
def keys(self) -> KeysView[str]:
161189
"""Return the keys (aliases prioritized over names) of the config."""
162190
return self._get_kv_dict().keys()

modelopt/torch/quantization/config.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -543,8 +543,8 @@ class QuantizerCfgEntry(ModeloptBaseConfig):
543543
"Attributes to apply to matched quantizers. A list configures a sequential quantizer."
544544
),
545545
)
546-
enable: bool | None = ModeloptField(
547-
default=None,
546+
enable: bool = ModeloptField(
547+
default=True,
548548
title="Quantizer enable flag.",
549549
description="Optional on/off toggle for matched quantizers, independent of cfg.",
550550
)
@@ -556,11 +556,6 @@ def validate_quantizer_cfg_entry(cls, values):
556556
if not isinstance(values, Mapping):
557557
return values
558558

559-
if "cfg" not in values and "enable" not in values:
560-
raise ValueError(
561-
"Each quant_cfg entry must specify 'cfg', 'enable', or both. "
562-
"An entry with only 'quantizer_name' has no effect."
563-
)
564559
if "cfg" in values and values["cfg"] is None:
565560
raise ValueError("cfg must be omitted or a valid mapping/list, not null.")
566561
if "enable" in values and values["enable"] is None:
@@ -1038,25 +1033,21 @@ def normalize_quant_cfg_list(
10381033
- Legacy ``nn.*``-scoped format: ``{"nn.<Class>": {"<quantizer_name>": <cfg>}}`` - converted
10391034
to a new-format entry with ``parent_class`` set.
10401035
1041-
**Validation** - an entry is rejected if it carries no instruction, i.e. it specifies neither
1042-
``cfg`` nor ``enable``. Concretely, the following are invalid:
1036+
**Validation** - an entry is rejected if its shape is invalid. Concretely, the following
1037+
are invalid:
10431038
10441039
- An empty entry ``{}``.
1045-
- An entry with only ``quantizer_name`` and no other keys - the only effect would be an
1046-
implicit ``enable=True``, which must be stated explicitly.
10471040
- An entry with ``enable=True`` (explicit or implicit) whose ``cfg`` is not a non-empty
10481041
``dict`` or ``list`` - e.g. ``{"quantizer_name": "*", "cfg": {}}`` or
10491042
``{"quantizer_name": "*", "cfg": 42}``. An enabled quantizer must have a valid
10501043
configuration.
10511044
1052-
**Normalization** - after conversion and validation every entry is put into canonical form:
1053-
1054-
- ``enable`` is set to ``True`` if not explicitly specified.
1055-
- ``cfg`` is set to ``None`` if not present in the entry.
1056-
1057-
For dict and legacy inputs, every returned entry is guaranteed to have
1058-
``quantizer_name``, ``enable``, and ``cfg`` set (plus optionally ``parent_class``). Typed
1059-
:class:`QuantizerCfgEntry` inputs are assumed to be already parsed and are preserved.
1045+
**Normalization** - after conversion and validation every entry is parsed as a
1046+
:class:`QuantizerCfgEntry`. Schema defaults are available through mapping access, so ``enable``
1047+
defaults to ``True`` and ``cfg`` defaults to ``None`` when omitted. Omitted defaults are not
1048+
marked as explicitly set, so ``model_dump(exclude_unset=True)`` preserves the user's sparse
1049+
input shape. Typed :class:`QuantizerCfgEntry` inputs are assumed to be already parsed and are
1050+
preserved.
10601051
10611052
Args:
10621053
v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict.
@@ -1066,9 +1057,8 @@ def normalize_quant_cfg_list(
10661057
typed entries are preserved.
10671058
10681059
Raises:
1069-
ValueError: If any entry has only ``quantizer_name`` with neither ``cfg`` nor ``enable``,
1070-
if ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry format
1071-
is not recognized.
1060+
ValueError: If ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry
1061+
format is not recognized.
10721062
"""
10731063
if isinstance(v, list) and all(isinstance(raw, QuantizerCfgEntry) for raw in v):
10741064
return cast("list[QuantizerCfgEntry]", v)
@@ -1164,14 +1154,6 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]:
11641154
raise ValueError(f"Invalid quant_cfg entry: {raw!r}.")
11651155

11661156
for entry in entries:
1167-
# Validate: must carry at least one instruction beyond the path selector.
1168-
if "cfg" not in entry and "enable" not in entry:
1169-
raise ValueError(
1170-
f"Invalid quant_cfg entry: {raw!r} - each entry must specify 'cfg', 'enable', "
1171-
"or both. An entry with only 'quantizer_name' has no effect (implicit "
1172-
"enable=True is not allowed; set it explicitly)."
1173-
)
1174-
11751157
# Validate: when cfg is present and enable=True, cfg must be a non-empty
11761158
# dict or list. An empty cfg would attempt to create a
11771159
# QuantizerAttributeConfig with no actual configuration.
@@ -1211,9 +1193,6 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]:
12111193
"explicitly."
12121194
)
12131195

1214-
# Normalize: make enable explicit. cfg remains omitted when it is intentionally unset.
1215-
entry.setdefault("enable", True)
1216-
12171196
result.append(QuantizerCfgEntry.model_validate(entry))
12181197
return result
12191198

@@ -1981,22 +1960,20 @@ def _cfg_to_dict(cfg):
19811960
for entry in quant_cfg:
19821961
name = entry["quantizer_name"]
19831962
raw_cfg = entry.get("cfg")
1984-
enable = entry.get("enable")
1963+
enable = entry["enable"]
19851964
if "weight_quantizer" in name:
19861965
# We don't calibrate weight quantizer
19871966
continue
19881967
# Sequential quantizers (e.g. W4A8) have a list of cfg dicts
19891968
if isinstance(raw_cfg, list):
19901969
for _config in raw_cfg:
19911970
cfg = _cfg_to_dict(_config)
1992-
if enable is not None:
1993-
cfg["enable"] = enable
1971+
cfg["enable"] = enable
19941972
if _not_dynamic(cfg):
19951973
return True
19961974
continue
19971975
cfg = _cfg_to_dict(raw_cfg)
1998-
if enable is not None:
1999-
cfg["enable"] = enable
1976+
cfg["enable"] = enable
20001977
if _not_dynamic(cfg):
20011978
return True
20021979

modelopt/torch/quantization/conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInpu
254254
for entry in quant_cfg:
255255
quantizer_name: str = entry["quantizer_name"]
256256
cfg = entry["cfg"] # None, QuantizerAttributeConfig, or list after normalization
257-
enable = entry["enable"] if entry["enable"] is not None else True
257+
enable = entry["enable"]
258258
parent_class_name = entry.get("parent_class")
259259
if parent_class_name:
260260
try:

tests/unit/torch/opt/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _run_test(is_new_registered):
7272
assert config[lin_name] == lin_expected_value
7373
assert config[lin_alias] == lin_expected_value
7474
assert getattr(config, lin_name) == lin_expected_value
75-
with nullcontext() if is_new_registered else pytest.raises(AttributeError):
75+
with nullcontext() if is_new_registered else pytest.raises(KeyError):
7676
config[new_name]
7777

7878
# get

tests/unit/torch/quantization/test_config_validation.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_need_calibration():
5959

6060
def test_need_calibration_with_quantize_config_type():
6161
"""need_calibration accepts schema-backed QuantizeConfig objects."""
62+
assert need_calibration(QuantizeConfig())
6263
assert need_calibration(QuantizeConfig.model_validate(FP8_DEFAULT_CFG))
6364
assert not need_calibration(QuantizeConfig.model_validate(FP8_PER_CHANNEL_PER_TOKEN_CFG))
6465

@@ -100,16 +101,27 @@ def test_quantizer_cfg_entry_is_pydantic_and_dict_like():
100101
entry = QuantizerCfgEntry(quantizer_name="*", enable=False)
101102
assert isinstance(entry, ModeloptBaseConfig)
102103
assert entry["quantizer_name"] == "*"
103-
assert entry.get("cfg") is None
104+
assert entry["cfg"] is None
105+
assert "cfg" in entry
106+
assert list(entry) == ["quantizer_name", "parent_class", "cfg", "enable"]
107+
assert dict(entry.items()) == {
108+
"quantizer_name": "*",
109+
"parent_class": None,
110+
"cfg": None,
111+
"enable": False,
112+
}
113+
assert dict(entry.explicit_items()) == {"quantizer_name": "*", "enable": False}
114+
with pytest.raises(KeyError):
115+
entry["unknown"] = 1
104116
assert entry.model_dump(exclude_unset=True) == {"quantizer_name": "*", "enable": False}
105117

106118
cfg_entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8})
107119
assert isinstance(cfg_entry["cfg"], QuantizerAttributeConfig)
108120
assert _cfg_to_dict(cfg_entry["cfg"]) == {"num_bits": 8}
109121

110122

111-
def test_quantizer_cfg_entry_mutable_mapping_delitem_unsets_field():
112-
"""Deleting a config key resets it to unset for exclude_unset dumps."""
123+
def test_quantizer_cfg_entry_mutable_mapping_rejects_key_deletion():
124+
"""ModeloptBaseConfig mappings have a fixed key set and reject deletion."""
113125
entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8}, enable=True)
114126
assert isinstance(entry, MutableMapping)
115127
assert entry.model_dump(exclude_unset=True) == {
@@ -118,11 +130,14 @@ def test_quantizer_cfg_entry_mutable_mapping_delitem_unsets_field():
118130
"enable": True,
119131
}
120132

121-
del entry["cfg"]
133+
with pytest.raises(TypeError):
134+
del entry["cfg"]
122135

123-
assert entry["cfg"] is None
136+
assert "cfg" in entry
137+
assert entry["cfg"] is not None
124138
assert entry.model_dump(exclude_unset=True) == {
125139
"quantizer_name": "*weight_quantizer",
140+
"cfg": {"num_bits": 8},
126141
"enable": True,
127142
}
128143

@@ -196,10 +211,13 @@ def test_quantizer_cfg_entry_rejects_explicit_null_values(raw, match):
196211
normalize_quant_cfg_list([raw])
197212

198213

199-
def test_quantizer_cfg_entry_rejects_no_effect_entry():
200-
"""Direct QuantizerCfgEntry construction rejects entries with no cfg or enable."""
201-
with pytest.raises(ValidationError, match="must specify 'cfg', 'enable'"):
202-
QuantizerCfgEntry(quantizer_name="*")
214+
def test_quantizer_cfg_entry_defaults_enable_true():
215+
"""Direct QuantizerCfgEntry construction uses enable=True when omitted."""
216+
entry = QuantizerCfgEntry(quantizer_name="*")
217+
assert entry["enable"] is True
218+
assert entry["cfg"] is None
219+
assert dict(entry.explicit_items()) == {"quantizer_name": "*"}
220+
assert entry.model_dump(exclude_unset=True) == {"quantizer_name": "*"}
203221

204222

205223
def test_quantizer_cfg_entry_rejects_empty_name():
@@ -231,7 +249,8 @@ def test_new_format_passthrough(self):
231249
assert result[0]["quantizer_name"] == "*weight_quantizer"
232250
assert isinstance(result[0]["cfg"], QuantizerAttributeConfig)
233251
assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0}
234-
assert result[0]["enable"] is True # defaulted
252+
assert result[0]["enable"] is True # schema default
253+
assert "enable" not in dict(result[0].explicit_items())
235254

236255
def test_typed_entry_list_passthrough(self):
237256
"""Already-parsed QuantizerCfgEntry lists are returned unchanged."""
@@ -256,6 +275,7 @@ def test_mixed_typed_and_dict_entries_normalize_to_typed_entries(self):
256275
assert isinstance(result[1], QuantizerCfgEntry)
257276
assert _cfg_to_dict(result[1]["cfg"]) == {"num_bits": 8}
258277
assert result[1]["enable"] is True
278+
assert "enable" not in dict(result[1].explicit_items())
259279

260280
def test_new_format_enable_false(self):
261281
"""Explicit enable=False is preserved."""
@@ -277,7 +297,8 @@ def test_legacy_single_key_dict(self):
277297
result = normalize_quant_cfg_list(raw)
278298
assert result[0]["quantizer_name"] == "*weight_quantizer"
279299
assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0}
280-
assert result[0]["enable"] is True # defaulted
300+
assert result[0]["enable"] is True # schema default
301+
assert "enable" not in dict(result[0].explicit_items())
281302

282303
def test_legacy_single_key_dict_with_enable(self):
283304
"""Legacy {'*path': {'enable': False}} splits enable out from cfg."""
@@ -296,17 +317,19 @@ def test_legacy_nn_class_scoped(self):
296317
assert result[0]["enable"] is False
297318

298319
def test_normalization_cfg_defaults_to_none(self):
299-
"""Entries without cfg get cfg=None after normalization."""
320+
"""Entries without cfg expose the default mapping key but keep it unset."""
300321
raw = [{"quantizer_name": "*lm_head*", "enable": False}]
301322
result = normalize_quant_cfg_list(raw)
302323
assert "cfg" in result[0]
303324
assert result[0]["cfg"] is None
325+
assert "cfg" not in dict(result[0].explicit_items())
304326

305327
def test_normalization_enable_defaults_to_true(self):
306-
"""Entries with cfg but no enable get enable=True after normalization."""
328+
"""Entries with cfg but no enable read as enable=True without marking it explicit."""
307329
raw = [{"quantizer_name": "*", "cfg": {"num_bits": 4}}]
308330
result = normalize_quant_cfg_list(raw)
309331
assert result[0]["enable"] is True
332+
assert "enable" not in dict(result[0].explicit_items())
310333

311334
def test_empty_list(self):
312335
"""Empty list is returned unchanged."""
@@ -322,10 +345,13 @@ def test_multiple_entries_order_preserved(self):
322345
assert result[0]["quantizer_name"] == "*"
323346
assert result[1]["quantizer_name"] == "*weight_quantizer"
324347

325-
def test_error_on_quantizer_name_only(self):
326-
"""Entry with only quantizer_name and no cfg or enable is rejected."""
327-
with pytest.raises(ValueError, match="must specify 'cfg', 'enable'"):
328-
normalize_quant_cfg_list([{"quantizer_name": "*"}])
348+
def test_quantizer_name_only_defaults_enable_true(self):
349+
"""Entry with only quantizer_name uses enable=True from the schema default."""
350+
result = normalize_quant_cfg_list([{"quantizer_name": "*"}])
351+
assert result[0]["enable"] is True
352+
assert result[0]["cfg"] is None
353+
assert dict(result[0].explicit_items()) == {"quantizer_name": "*"}
354+
assert result[0].model_dump(exclude_unset=True) == {"quantizer_name": "*"}
329355

330356
def test_error_on_empty_dict(self):
331357
"""An empty dict entry is rejected."""
@@ -421,6 +447,7 @@ def test_legacy_flat_dict_conversion(self):
421447
assert result[1]["quantizer_name"] == "*weight_quantizer"
422448
assert _cfg_to_dict(result[1]["cfg"]) == {"num_bits": 8, "axis": 0}
423449
assert result[1]["enable"] is True
450+
assert "enable" not in dict(result[1].explicit_items())
424451

425452
def test_legacy_enable_only_produces_cfg_none(self):
426453
"""Legacy {'*': {'enable': False}} should produce cfg=None, not cfg={}."""

0 commit comments

Comments
 (0)