Skip to content

Commit fe6dbcf

Browse files
committed
Address quant config review feedback
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent 90a39be commit fe6dbcf

9 files changed

Lines changed: 110 additions & 20 deletions

File tree

examples/llm_ptq/example_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
ProcessorMixin,
4343
)
4444

45-
from modelopt.torch.quantization.config import QuantizeConfig
45+
from modelopt.torch.quantization.config import QuantizeConfig, QuantizerCfgEntry
4646

4747
try:
4848
from huggingface_hub import snapshot_download
@@ -249,10 +249,14 @@ def build_quant_cfg(
249249

250250
if model_type == "phi4mm":
251251
# Only quantize the language model
252-
quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*speech*", "enable": False})
253-
quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*audio*", "enable": False})
254-
quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*image*", "enable": False})
255-
quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False})
252+
quant_cfg_obj["quant_cfg"].extend(
253+
[
254+
QuantizerCfgEntry(quantizer_name="*speech*", enable=False),
255+
QuantizerCfgEntry(quantizer_name="*audio*", enable=False),
256+
QuantizerCfgEntry(quantizer_name="*image*", enable=False),
257+
QuantizerCfgEntry(quantizer_name="*vision*", enable=False),
258+
]
259+
)
256260

257261
return quant_cfg_obj
258262

examples/vllm_serve/vllm_ptq_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
import dataclasses
17-
from collections.abc import Callable
17+
from collections.abc import Callable, Mapping
1818
from typing import Any
1919

2020
import torch
@@ -119,7 +119,11 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list:
119119
return kv_quant_cfg
120120

121121
kv_entry = next(
122-
(e for e in kv_quant_cfg if e.get("quantizer_name") == "*[kv]_bmm_quantizer"),
122+
(
123+
e
124+
for e in kv_quant_cfg
125+
if isinstance(e, Mapping) and e.get("quantizer_name") == "*[kv]_bmm_quantizer"
126+
),
123127
None,
124128
)
125129
if kv_entry is not None:

modelopt/torch/opt/config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,10 @@ def __setitem__(self, key: str, value: Any) -> None:
119119

120120
def __delitem__(self, key: str) -> None:
121121
"""Unset the given key so exclude_unset dumps omit it."""
122-
field_name = self.get_field_name_from_key(key)
122+
try:
123+
field_name = self.get_field_name_from_key(key)
124+
except AttributeError as e:
125+
raise KeyError(key) from e
123126
if field_name in self._iterable_model_extra:
124127
assert self.model_extra is not None
125128
del self.model_extra[field_name]
@@ -129,8 +132,8 @@ def __delitem__(self, key: str) -> None:
129132
field_info = type(self).model_fields[field_name]
130133
default = field_info.get_default(call_default_factory=True)
131134
if default is PydanticUndefined:
132-
raise AttributeError(f"Key {key} cannot be unset because it has no default.")
133-
setattr(self, field_name, default)
135+
raise KeyError(f"Key {key} cannot be unset because it has no default.")
136+
self.__dict__[field_name] = default
134137
self.model_fields_set.discard(field_name)
135138

136139
def get(self, key: str, default: Any = None) -> Any:

modelopt/torch/quantization/config.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,10 @@ def validate_quantizer_cfg_entry(cls, values):
561561
"Each quant_cfg entry must specify 'cfg', 'enable', or both. "
562562
"An entry with only 'quantizer_name' has no effect."
563563
)
564+
if "cfg" in values and values["cfg"] is None:
565+
raise ValueError("cfg must be omitted or a valid mapping/list, not null.")
566+
if "enable" in values and values["enable"] is None:
567+
raise ValueError("enable must be a boolean when provided, not null.")
564568

565569
cfg = values.get("cfg")
566570
enable = values.get("enable", True)
@@ -1008,15 +1012,15 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig):
10081012

10091013

10101014
QuantizerCfgListConfig = list[QuantizerCfgEntry]
1011-
QuantizeQuantCfgInputType = Sequence[QuantizerCfgEntry | Mapping[str, Any]]
1015+
QuantizeQuantCfgInputType = Mapping[str, Any] | Sequence[QuantizerCfgEntry | Mapping[str, Any]]
10121016

10131017
_QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None
10141018

10151019
QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None
10161020

10171021

10181022
def normalize_quant_cfg_list(
1019-
v: Mapping[str, Any] | list[QuantizerCfgEntry | Mapping[str, Any]],
1023+
v: Mapping[str, Any] | Sequence[QuantizerCfgEntry | Mapping[str, Any]],
10201024
) -> list[QuantizerCfgEntry]:
10211025
"""Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` objects.
10221026
@@ -1099,15 +1103,19 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]:
10991103
if isinstance(sub_cfg, QuantizerAttributeConfig):
11001104
enable = None
11011105
cfg = sub_cfg
1102-
else:
1106+
elif isinstance(sub_cfg, Mapping):
11031107
sub_cfg = dict(sub_cfg)
11041108
enable = sub_cfg.pop("enable", None)
11051109
cfg = sub_cfg or None
1110+
else:
1111+
enable = None
1112+
cfg = sub_cfg
11061113
entry: dict[str, Any] = {
11071114
"parent_class": key,
11081115
"quantizer_name": q_path,
1109-
"cfg": cfg,
11101116
}
1117+
if cfg is not None:
1118+
entry["cfg"] = cfg
11111119
if enable is not None:
11121120
entry["enable"] = enable
11131121
entries.append(entry)
@@ -1119,7 +1127,9 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]:
11191127
else:
11201128
cfg = value
11211129
enable = None
1122-
entry = {"quantizer_name": key, "cfg": cfg}
1130+
entry = {"quantizer_name": key}
1131+
if cfg is not None:
1132+
entry["cfg"] = cfg
11231133
if enable is not None:
11241134
entry["enable"] = enable
11251135
return [entry]
@@ -1165,6 +1175,17 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]:
11651175
# Validate: when cfg is present and enable=True, cfg must be a non-empty
11661176
# dict or list. An empty cfg would attempt to create a
11671177
# QuantizerAttributeConfig with no actual configuration.
1178+
if "cfg" in entry and entry["cfg"] is None:
1179+
raise ValueError(
1180+
f"Invalid quant_cfg entry: {raw!r} - 'cfg' must be omitted or a "
1181+
"valid mapping/list, not null."
1182+
)
1183+
if "enable" in entry and entry["enable"] is None:
1184+
raise ValueError(
1185+
f"Invalid quant_cfg entry: {raw!r} - 'enable' must be a boolean "
1186+
"when provided, not null."
1187+
)
1188+
11681189
cfg = entry.get("cfg")
11691190
enable = entry.get("enable", True)
11701191
if enable and cfg is not None:
@@ -1190,9 +1211,8 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]:
11901211
"explicitly."
11911212
)
11921213

1193-
# Normalize: make enable and cfg always explicit.
1214+
# Normalize: make enable explicit. cfg remains omitted when it is intentionally unset.
11941215
entry.setdefault("enable", True)
1195-
entry.setdefault("cfg", None)
11961216

11971217
result.append(QuantizerCfgEntry.model_validate(entry))
11981218
return result
@@ -1201,6 +1221,18 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]:
12011221
class QuantizeConfig(ModeloptBaseConfig):
12021222
"""Default configuration for ``quantize`` mode."""
12031223

1224+
def model_dump(self, **kwargs):
1225+
"""Dump quant_cfg entries without unset optional fields."""
1226+
data = super().model_dump(**kwargs)
1227+
if "quant_cfg" in data:
1228+
data["quant_cfg"] = [
1229+
entry.model_dump(exclude_unset=True)
1230+
if isinstance(entry, QuantizerCfgEntry)
1231+
else {k: v for k, v in entry.items() if v is not None}
1232+
for entry in self.quant_cfg
1233+
]
1234+
return data
1235+
12041236
quant_cfg: QuantizerCfgListConfig = ModeloptField(
12051237
default=[{"quantizer_name": "*", "cfg": {"num_bits": 8, "axis": None}}],
12061238
title="Quantization configuration",

modelopt/torch/quantization/conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInpu
249249
250250
See :ref:`quant-cfg` for the full format reference and common patterns.
251251
"""
252-
quant_cfg = normalize_quant_cfg_list(list(quant_cfg))
252+
quant_cfg = normalize_quant_cfg_list(quant_cfg)
253253

254254
for entry in quant_cfg:
255255
quantizer_name: str = entry["quantizer_name"]
@@ -496,7 +496,7 @@ def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuan
496496
Yields:
497497
None — the context body runs with the new quantizer attributes active.
498498
"""
499-
quant_cfg = normalize_quant_cfg_list(list(quant_cfg))
499+
quant_cfg = normalize_quant_cfg_list(quant_cfg)
500500

501501
for entry in quant_cfg:
502502
if isinstance(entry.get("cfg"), list):

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,8 @@ def set_from_attribute_config(
14311431
if not isinstance(attributes, (list, tuple)):
14321432
assert isinstance(attributes, Mapping), "attributes must be a list or a mapping."
14331433
attributes = [attributes] * len(self)
1434+
elif len(attributes) != len(self):
1435+
raise ValueError(f"Expected {len(self)} attribute configs, but got {len(attributes)}.")
14341436

14351437
for attribute, quantizer in zip(attributes, self):
14361438
quantizer.set_from_attribute_config(attribute)

modelopt/torch/quantization/utils/core_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ def update_quant_cfg_with_kv_cache_quant(
935935
inner = list(
936936
updated_quant_cfg.get("quant_cfg") or [QuantizerCfgEntry(quantizer_name="*", enable=False)]
937937
)
938-
updated_quant_cfg["quant_cfg"] = inner + list(kv_cache_quant_cfg)
938+
updated_quant_cfg["quant_cfg"] = inner + copy.deepcopy(list(kv_cache_quant_cfg))
939939

940940
# Set default algorithm for kv cache quantization if not provided.
941941
if not updated_quant_cfg.get("algorithm"):

tests/unit/torch/quantization/test_config_validation.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ def test_quantizer_cfg_entry_mutable_mapping_delitem_unsets_field():
126126
"enable": True,
127127
}
128128

129+
with pytest.raises(KeyError):
130+
del entry["missing"]
131+
129132

130133
def test_public_preset_quant_cfg_entries_are_typed_and_dict_like():
131134
"""Public preset constants are typed but keep dict-style entry access."""
@@ -177,6 +180,22 @@ def test_mixed_raw_dict_and_modelopt_config_entries_normalize_after_mutation():
177180
assert weight_entry["cfg"]["num_bits"] == "e4m3"
178181

179182

183+
@pytest.mark.parametrize(
184+
("raw", "match"),
185+
[
186+
({"quantizer_name": "*", "cfg": None}, "'?cfg'? must be omitted"),
187+
({"quantizer_name": "*", "enable": None}, "'?enable'? must be a boolean"),
188+
],
189+
)
190+
def test_quantizer_cfg_entry_rejects_explicit_null_values(raw, match):
191+
"""Explicit null cfg/enable values are rejected instead of treated as omitted."""
192+
with pytest.raises(ValidationError, match=match):
193+
QuantizerCfgEntry.model_validate(raw)
194+
195+
with pytest.raises(ValueError, match=match):
196+
normalize_quant_cfg_list([raw])
197+
198+
180199
def test_quantizer_cfg_entry_rejects_no_effect_entry():
181200
"""Direct QuantizerCfgEntry construction rejects entries with no cfg or enable."""
182201
with pytest.raises(ValidationError, match="must specify 'cfg', 'enable'"):
@@ -469,6 +488,26 @@ def test_legacy_nn_class_with_cfg(self):
469488
assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 4, "axis": 0}
470489
assert result[0]["enable"] is True
471490

491+
def test_legacy_nn_class_with_list_valued_cfg(self):
492+
"""Legacy nn.* scoped format preserves list-valued SequentialQuantizer cfg."""
493+
raw = [
494+
{
495+
"nn.Linear": {
496+
"*weight_quantizer": [
497+
{"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}},
498+
{"num_bits": 8, "axis": 0},
499+
]
500+
}
501+
}
502+
]
503+
result = normalize_quant_cfg_list(raw)
504+
assert len(result) == 1
505+
assert result[0]["parent_class"] == "nn.Linear"
506+
assert result[0]["quantizer_name"] == "*weight_quantizer"
507+
assert isinstance(result[0]["cfg"], list)
508+
assert _cfg_to_dict(result[0]["cfg"]) == raw[0]["nn.Linear"]["*weight_quantizer"]
509+
assert result[0]["enable"] is True
510+
472511
def test_legacy_list_valued_cfg(self):
473512
"""Legacy dict format with list-valued cfg (SequentialQuantizer) normalizes correctly."""
474513
raw = [

tests/unit/torch/quantization/test_quantize_cpu.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,12 @@ def test_list_attributes_creates_sequential_quantizer(self):
401401
assert isinstance(module, SequentialQuantizer)
402402
assert len(module) == 2
403403

404+
def test_sequential_quantizer_rejects_mismatched_attribute_list_length(self):
405+
"""SequentialQuantizer rejects partial list configs instead of silently zipping."""
406+
quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer())
407+
with pytest.raises(ValueError, match="Expected 2 attribute configs, but got 1"):
408+
quantizer.set_from_attribute_config([QuantizerAttributeConfig(num_bits=8)])
409+
404410

405411
def test_ordering_later_entry_overrides_earlier():
406412
"""Later entries in quant_cfg override earlier ones for the same quantizer."""

0 commit comments

Comments (0)