Skip to content

Commit a932284

Browse files
committed
fix(quant): make embedding input_quantizer absorb wildcard configs
The previous design raised in _QuantEmbedding.forward whenever input_quantizer.is_enabled was true, on the theory that any non-disable config was an explicit user mistake. That assumption was wrong for wildcard configs: the default QuantizeConfig is just [{"quantizer_name": "*", "cfg": {"num_bits": 8, ...}}] (no embedding opt-out), so the wildcard enables embed_tokens.input_quantizer for tiny Llama-style tests and the forward guard fires — breaking test_peft_save_load and test_transformers_save_load.

Switch _UnsettableInputQuantizer.set_from_attribute_config to absorb the incoming config like a normal quantizer, then force _disabled = True at the end. The "throw on explicit set" semantics are preserved via the .enable / .enable_quant / .enable_calib overrides, which catch the direct mistakes users would actually make.

The forward-time guard and its corresponding test are removed, since the invariant is now maintained at the configure step.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent f5f4227 commit a932284

2 files changed

Lines changed: 39 additions & 24 deletions

File tree

modelopt/torch/quantization/nn/modules/quant_embedding.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@ class _UnsettableInputQuantizer(TensorQuantizer):
3939
"""TensorQuantizer slot for nn.Embedding.input — present but not enable-able.
4040
4141
Embedding inputs are integer indices that cannot be fake-quantized. The attribute
42-
is kept so introspection code (export, calibration helpers) can find it. Wildcard
43-
configs (e.g. ``NVFP4_DEFAULT_CFG``'s ``*input_quantizer``) are accepted silently
44-
so that the standard "deny-all → enable wildcards → opt-out specific layers"
45-
pattern in the stock configs still works. Direct calls to ``enable*()`` raise
46-
immediately, and ``_QuantEmbedding.forward`` raises if the final state ends up
47-
enabled (e.g. a user explicitly targeted this quantizer).
42+
is kept so introspection code (export, calibration helpers) can find it.
43+
44+
Wildcard configs (e.g. the default ``QuantizeConfig`` ``"*"`` rule or
45+
``NVFP4_DEFAULT_CFG``'s ``*input_quantizer``) are accepted silently, then the
46+
quantizer is force-disabled — wildcards don't really mean "enable embedding
47+
input quant", they mean "enable input quant in general". Direct, explicit
48+
attempts (calling ``enable``/``enable_quant``/``enable_calib``) raise loudly.
4849
"""
4950

5051
def enable(self):
@@ -59,21 +60,34 @@ def enable_calib(self):
5960
"""Disallowed for embedding inputs."""
6061
raise RuntimeError(_INPUT_QUANTIZER_ERR)
6162

63+
def set_from_attribute_config(self, attribute_cfg):
64+
"""Apply the config like any quantizer, then force-disable this quantizer.
65+
66+
This absorbs wildcard configs from stock recipes without raising. The
67+
quantizer's other attributes (``num_bits``, ``axis``, etc.) take on the
68+
config values for introspection, but ``_disabled`` is forced back to
69+
``True`` so forward is always a no-op.
70+
"""
71+
super().set_from_attribute_config(attribute_cfg)
72+
self._disabled = True
73+
6274

6375
@QuantModuleRegistry.register({nn.Embedding: "nn.Embedding"})
6476
class _QuantEmbedding(QuantModule):
6577
"""Quantized version of ``nn.Embedding``.
6678
6779
The literal input to ``nn.Embedding`` is integer indices, which cannot be
6880
fake-quantized. The ``input_quantizer`` attribute is kept (for symmetry with
69-
other quant modules and for introspection by export/calibration code) but
70-
configuring it raises — see ``_UnsettableInputQuantizer``. Only the embedding
81+
other quant modules and for introspection by export/calibration code) but is
82+
permanently disabled — see ``_UnsettableInputQuantizer``. Only the embedding
7183
table (weight) and the lookup output (an activation feeding downstream layers)
7284
are quantizable.
7385
7486
Quantizer roles:
7587
- ``weight_quantizer``: quantizes the embedding table (``self.weight``).
76-
- ``input_quantizer``: permanently disabled placeholder — raises on configure.
88+
- ``input_quantizer``: permanently disabled placeholder — direct
89+
``enable*()`` calls raise; configs that target it are absorbed and the
90+
quantizer is force-disabled.
7791
- ``output_quantizer``: optional activation quantizer for the lookup output,
7892
disabled by default.
7993
"""
@@ -119,10 +133,13 @@ def _setup(self):
119133
self._register_dynamic_attribute("weight", self._get_quantized_weight)
120134

121135
def forward(self, input, *args, **kwargs):
122-
"""Quantize the embedding table, look up, then optionally quantize the output."""
123-
if self.input_quantizer.is_enabled:
124-
# Caught any config or call that managed to flip _disabled to False.
125-
raise RuntimeError(_INPUT_QUANTIZER_ERR)
136+
"""Quantize the embedding table, look up, then optionally quantize the output.
137+
138+
``input_quantizer`` is intentionally never applied — embedding inputs are
139+
integer indices. ``_UnsettableInputQuantizer.set_from_attribute_config``
140+
keeps that quantizer disabled regardless of what configs target it, so we
141+
rely on that invariant rather than a runtime check here.
142+
"""
126143
if is_torch_export_mode():
127144
# quantize_weight()'s attribute write is not allowed under torch.export;
128145
# weight quantization is still applied inline via _get_quantized_weight's

tests/unit/torch/quantization/test_quant_embedding.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,24 +92,22 @@ def test_input_quantizer_mutators_raise(self, method):
9292
with pytest.raises(RuntimeError, match="nn.Embedding"):
9393
getattr(qemb.input_quantizer, method)()
9494

95-
def test_forward_raises_if_input_quantizer_enabled(self):
96-
"""Forward catches back-door flips of input_quantizer._disabled."""
97-
qemb = _make_quant_embedding()
98-
qemb.input_quantizer._disabled = False
99-
with pytest.raises(RuntimeError, match="nn.Embedding"):
100-
qemb(torch.randint(0, VOCAB_SIZE, (4, 6)))
95+
def test_wildcard_config_keeps_input_quantizer_disabled(self):
96+
"""set_from_attribute_config absorbs any cfg but force-disables input_quantizer.
10197
102-
def test_wildcard_config_accepted_then_opt_out(self):
103-
"""Wildcard cfg on ``*input_quantizer`` must not raise — stock NVFP4_DEFAULT_CFG relies on it.
104-
A follow-up ``enable: false`` rule restores the disabled state."""
98+
Stock recipes' ``*input_quantizer`` wildcard (and the default ``QuantizeConfig``
99+
``"*"`` rule) target every quantizer including the embedding's input slot.
100+
The quantizer must end up disabled regardless of what the cfg said.
101+
"""
105102
qemb = _make_quant_embedding()
106103
set_quantizer_attributes_partial(
107104
qemb,
108105
"*input_quantizer",
109106
QuantizerAttributeConfig(num_bits=8, axis=None).model_dump(),
110107
)
111-
set_quantizer_attributes_partial(qemb, "*input_quantizer", {"enable": False})
112-
qemb(torch.randint(0, VOCAB_SIZE, (4, 6))) # forward succeeds
108+
assert not qemb.input_quantizer.is_enabled
109+
# Forward still works — input_quantizer is disabled and never applied.
110+
qemb(torch.randint(0, VOCAB_SIZE, (4, 6)))
113111

114112

115113
def _embedding_nvfp4_cfg() -> dict:

0 commit comments

Comments
 (0)