From f2cc1b7e9475a4a522336a288f969a48bedb50e4 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Thu, 14 May 2026 17:22:45 +0000 Subject: [PATCH 1/3] feat(quant): support quantized nn.Embedding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Register nn.Embedding in QuantModuleRegistry so the embedding table and the lookup activations participate in quantization. The literal input is integer indices, so input_quantizer is a non-configurable placeholder that raises on direct enable*() calls and at forward-time if its _disabled flag is flipped — wildcard configs (e.g. NVFP4_DEFAULT_CFG's *input_quantizer) are accepted silently so the stock deny-all → enable wildcards → opt-out pattern continues to work, and the opt-out is installed by default (parent_class: nn.Embedding in default_disabled_quantizers.yaml). export_hf_checkpoint packs quantized embedding weights through the same path as Linear layers. Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- CHANGELOG.rst | 1 + modelopt/torch/export/unified_export_hf.py | 10 ++ modelopt/torch/quantization/nn/__init__.py | 1 + .../nn/modules/quant_embedding.py | 134 ++++++++++++++++++ .../torch/quantization/utils/core_utils.py | 5 + .../units/default_disabled_quantizers.yaml | 3 + .../quantization/test_quant_embedding.py | 104 ++++++++++++++ 7 files changed, 258 insertions(+) create mode 100644 modelopt/torch/quantization/nn/modules/quant_embedding.py create mode 100644 tests/unit/torch/quantization/test_quant_embedding.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 997a1069401..4068cfdd0c8 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,6 +25,7 @@ Changelog - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. +- Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). **Bug Fixes** diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0626d0a8fd5..d7720f16198 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -679,6 +679,16 @@ def _process_quantized_modules( raise AssertionError( f"Failed to export module '{name}' (type={type(sub_module).__name__}): {e}" ) from e + elif isinstance(sub_module, nn.Embedding) and hasattr(sub_module, "weight_quantizer"): + # Quantized nn.Embedding: pack the embedding table the same way as Linear + # weights so downstream loaders see the NVFP4/FP8/INT-packed bytes + scales. + try: + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) + except AssertionError as e: + raise AssertionError( + f"Failed to export embedding '{name}' (type={type(sub_module).__name__}): {e}" + ) from e elif ( "Llama4TextExperts" in type(sub_module).__name__ or "GptOssExperts" in type(sub_module).__name__ diff --git a/modelopt/torch/quantization/nn/__init__.py b/modelopt/torch/quantization/nn/__init__.py index af9490c8311..2e6bc64054e 100644 --- a/modelopt/torch/quantization/nn/__init__.py +++ b/modelopt/torch/quantization/nn/__init__.py @@ -18,6 +18,7 @@ from .modules.quant_activations import * from .modules.quant_batchnorm import * from .modules.quant_conv import * +from .modules.quant_embedding import * from .modules.quant_instancenorm import * from .modules.quant_layernorm import * from .modules.quant_linear import * diff --git a/modelopt/torch/quantization/nn/modules/quant_embedding.py b/modelopt/torch/quantization/nn/modules/quant_embedding.py new file mode 100644 index 00000000000..c7ab42dfb92 --- /dev/null +++ b/modelopt/torch/quantization/nn/modules/quant_embedding.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantized Embedding.""" + +import contextlib + +import torch +import torch.nn as nn + +from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR +from ...utils import is_torch_export_mode +from .quant_module import QuantModule, QuantModuleRegistry +from .tensor_quantizer import SequentialQuantizer, TensorQuantizer + +__all__ = ["QuantEmbedding"] + + +_INPUT_QUANTIZER_ERR = ( + "Cannot configure input_quantizer on a quantized nn.Embedding: the input is integer " + "indices and cannot be fake-quantized. Configure weight_quantizer (and optionally " + "output_quantizer) instead." +) + + +class _UnsettableInputQuantizer(TensorQuantizer): + """TensorQuantizer slot for nn.Embedding.input — present but not enable-able. + + Embedding inputs are integer indices that cannot be fake-quantized. The attribute + is kept so introspection code (export, calibration helpers) can find it. Wildcard + configs (e.g. ``NVFP4_DEFAULT_CFG``'s ``*input_quantizer``) are accepted silently + so that the standard "deny-all → enable wildcards → opt-out specific layers" + pattern in the stock configs still works. Direct calls to ``enable*()`` raise + immediately, and ``_QuantEmbedding.forward`` raises if the final state ends up + enabled (e.g. a user explicitly targeted this quantizer). + """ + + def enable(self): + """Disallowed for embedding inputs.""" + raise RuntimeError(_INPUT_QUANTIZER_ERR) + + def enable_quant(self): + """Disallowed for embedding inputs.""" + raise RuntimeError(_INPUT_QUANTIZER_ERR) + + def enable_calib(self): + """Disallowed for embedding inputs.""" + raise RuntimeError(_INPUT_QUANTIZER_ERR) + + +@QuantModuleRegistry.register({nn.Embedding: "nn.Embedding"}) +class _QuantEmbedding(QuantModule): + """Quantized version of ``nn.Embedding``. + + The literal input to ``nn.Embedding`` is integer indices, which cannot be + fake-quantized. The ``input_quantizer`` attribute is kept (for symmetry with + other quant modules and for introspection by export/calibration code) but + configuring it raises — see ``_UnsettableInputQuantizer``. Only the embedding + table (weight) and the lookup output (an activation feeding downstream layers) + are quantizable. + + Quantizer roles: + - ``weight_quantizer``: quantizes the embedding table (``self.weight``). + - ``input_quantizer``: permanently disabled placeholder — raises on configure. + - ``output_quantizer``: optional activation quantizer for the lookup output, + disabled by default. + """ + + weight_quantizer: TensorQuantizer | SequentialQuantizer + input_quantizer: _UnsettableInputQuantizer + output_quantizer: TensorQuantizer + _enable_weight_quantization: bool + default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR + default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR + default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR + + @contextlib.contextmanager + def quantize_weight(self): + """Context in which ``self.weight`` is quantized via the dynamic attribute.""" + self._enable_weight_quantization = True + try: + yield + finally: + self._enable_weight_quantization = False + + @staticmethod + def _get_quantized_weight(module: "_QuantEmbedding", weight: torch.Tensor) -> torch.Tensor: + if module._enable_weight_quantization or is_torch_export_mode(): + return module.weight_quantizer(weight) + return weight + + def _setup(self): + """Register weight, (locked) input, and output quantizers.""" + self._register_temp_attribute( + "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight) + ) + # Build the input quantizer disabled. _UnsettableInputQuantizer's mutators raise, + # so we disable it once at construction via direct attribute assignment. + input_quantizer = _UnsettableInputQuantizer(self.default_quant_desc_input) + input_quantizer._disabled = True + self._register_temp_attribute("input_quantizer", input_quantizer) + self._register_temp_attribute( + "output_quantizer", TensorQuantizer(self.default_quant_desc_output) + ) + self.output_quantizer.disable() + self._register_temp_attribute("_enable_weight_quantization", False) + self._register_dynamic_attribute("weight", self._get_quantized_weight) + + def forward(self, input, *args, **kwargs): + """Quantize the embedding table, look up, then optionally quantize the output.""" + if self.input_quantizer.is_enabled: + # Caught any config or call that managed to flip _disabled to False. + raise RuntimeError(_INPUT_QUANTIZER_ERR) + if is_torch_export_mode(): + return super().forward(input, *args, **kwargs) + with self.quantize_weight(): + output = super().forward(input, *args, **kwargs) + return self.output_quantizer(output) + + +# Public alias consistent with quant_linear / quant_conv naming. +QuantEmbedding = _QuantEmbedding diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index cea3d4260e4..15c6504011b 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -297,6 +297,11 @@ def is_quantized_linear(module): """Check if a module is a quantized linear module.""" from ..nn import QuantModule, TensorQuantizer + # Embedding has a 2D weight but is not a GEMM op, so calibration passes that operate + # on linear activations (AWQ, SmoothQuant, SVDQuant) must skip it. + if isinstance(module, nn.Embedding): + return False + return ( isinstance(module, QuantModule) and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer) diff --git a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml index 1508f942776..4fddfdbbc6a 100644 --- a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml +++ b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml @@ -48,3 +48,6 @@ - parent_class: 'nn.LeakyReLU' quantizer_name: '*' enable: false + - parent_class: 'nn.Embedding' + quantizer_name: '*' + enable: false diff --git a/tests/unit/torch/quantization/test_quant_embedding.py b/tests/unit/torch/quantization/test_quant_embedding.py new file mode 100644 index 00000000000..2b588194506 --- /dev/null +++ b/tests/unit/torch/quantization/test_quant_embedding.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests of QuantEmbedding module.""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelopt.torch.quantization import tensor_quant +from modelopt.torch.quantization import utils as quant_utils +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial +from modelopt.torch.quantization.nn import QuantModuleRegistry +from modelopt.torch.quantization.nn.modules.quant_embedding import _UnsettableInputQuantizer + +VOCAB_SIZE = 16 +EMBED_DIM = 8 + + +def _make_quant_embedding(**kwargs) -> nn.Module: + return QuantModuleRegistry.convert(nn.Embedding(VOCAB_SIZE, EMBED_DIM, **kwargs)) + + +class TestQuantEmbedding: + def test_default_state_and_no_quant(self): + """Default state: input quant locked-disabled, output quant disabled, weight quant on; + with weight quant also off the wrapper matches plain F.embedding.""" + qemb = _make_quant_embedding() + assert isinstance(qemb.input_quantizer, _UnsettableInputQuantizer) + assert not qemb.input_quantizer.is_enabled + assert not qemb.output_quantizer.is_enabled + assert qemb.weight_quantizer.is_enabled + + qemb.weight_quantizer.disable() + ids = torch.randint(0, VOCAB_SIZE, (4, 6)) + assert torch.allclose(qemb(ids), F.embedding(ids, qemb.weight), rtol=0, atol=0) + + @pytest.mark.parametrize("axis", [None, 0]) + def test_weight_fake_quant(self, axis): + """Per-tensor (axis=None) and per-row (axis=0) weight fake quant match the manual ref.""" + qemb = _make_quant_embedding() + set_quantizer_attributes_partial( + qemb, "*weight_quantizer", QuantizerAttributeConfig(axis=axis).model_dump() + ) + + ids = torch.randint(0, VOCAB_SIZE, (4, 6)) + weight = qemb.weight.detach().clone() + amax = ( + torch.max(torch.abs(weight)) + if axis is None + else quant_utils.reduce_amax(weight, axis=1, keepdims=True) + ) + ref = F.embedding(ids, tensor_quant.fake_tensor_quant(weight, amax)) + assert torch.allclose(qemb(ids), ref, rtol=0, atol=0) + + def test_output_quantizer_applied_when_enabled(self): + qemb = _make_quant_embedding() + qemb.weight_quantizer.disable() + qemb.output_quantizer.enable() + ids = torch.randint(0, VOCAB_SIZE, (4, 6)) + with torch.no_grad(): + qemb(ids) # calibrate + + ref = qemb.output_quantizer(F.embedding(ids, qemb.weight)) + assert torch.allclose(qemb(ids), ref, rtol=0, atol=0) + + @pytest.mark.parametrize("method", ["enable", "enable_quant", "enable_calib"]) + def test_input_quantizer_mutators_raise(self, method): + qemb = _make_quant_embedding() + with pytest.raises(RuntimeError, match="nn.Embedding"): + getattr(qemb.input_quantizer, method)() + + def test_forward_raises_if_input_quantizer_enabled(self): + """Forward catches back-door flips of input_quantizer._disabled.""" + qemb = _make_quant_embedding() + qemb.input_quantizer._disabled = False + with pytest.raises(RuntimeError, match="nn.Embedding"): + qemb(torch.randint(0, VOCAB_SIZE, (4, 6))) + + def test_wildcard_config_accepted_then_opt_out(self): + """Wildcard cfg on ``*input_quantizer`` must not raise — stock NVFP4_DEFAULT_CFG relies on it. + A follow-up ``enable: false`` rule restores the disabled state.""" + qemb = _make_quant_embedding() + set_quantizer_attributes_partial( + qemb, + "*input_quantizer", + QuantizerAttributeConfig(num_bits=8, axis=None).model_dump(), + ) + set_quantizer_attributes_partial(qemb, "*input_quantizer", {"enable": False}) + qemb(torch.randint(0, VOCAB_SIZE, (4, 6))) # forward succeeds From f5f4227eda4ed37d50f3ca0e9b998fa1659e8a75 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Thu, 14 May 2026 19:17:06 +0000 Subject: [PATCH 2/3] fix(quant): address review feedback on quant_embedding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Apply output_quantizer in the torch.export branch of _QuantEmbedding.forward so users who opt into output activation quantization don't silently lose it during export. Matches QuantInputBase.forward's behavior. - Detect Python-level weight tying (e.g. tied_word_embeddings → lm_head) in _process_quantized_modules and skip packing the embedding when the .weight Parameter is shared, with a UserWarning. Packing would otherwise reassign the embedding's .weight to a new uint8 Parameter, severing the tie and leaving the tied module pointing at a stale float Parameter. - Add export-path tests covering the normal pack flow (weight → uint8 + weight_scale + weight_scale_2 buffers) and the tied-embedding skip path (weight unchanged, warning raised, tie preserved). Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/torch/export/unified_export_hf.py | 32 +++++-- .../nn/modules/quant_embedding.py | 10 +- .../quantization/test_quant_embedding.py | 94 ++++++++++++++++++- 3 files changed, 126 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index d7720f16198..21e5ba0638d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -682,13 +682,31 @@ def _process_quantized_modules( elif isinstance(sub_module, nn.Embedding) and hasattr(sub_module, "weight_quantizer"): # Quantized nn.Embedding: pack the embedding table the same way as Linear # weights so downstream loaders see the NVFP4/FP8/INT-packed bytes + scales. - try: - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_quantized_weight(sub_module, dtype) - except AssertionError as e: - raise AssertionError( - f"Failed to export embedding '{name}' (type={type(sub_module).__name__}): {e}" - ) from e + # Skip packing when the embedding's weight is tied to another module + # (e.g. tied_word_embeddings → lm_head): _export_quantized_weight reassigns + # the .weight attribute to a new uint8 Parameter, which severs the Python- + # level tie and leaves the other module pointing at a stale float Parameter. + tied_to = [ + other_name + for other_name, other_module in model.named_modules() + if other_module is not sub_module + and getattr(other_module, "weight", None) is sub_module.weight + ] + if tied_to: + warnings.warn( + f"Skipping quantized weight packing for embedding '{name}': its " + f"weight Parameter is shared with {tied_to} (weight tying). Packing " + "would break the tie and produce stale weights in the tied module(s). " + "The embedding will be exported as its fake-quantized float weight." + ) + else: + try: + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) + except AssertionError as e: + raise AssertionError( + f"Failed to export embedding '{name}' (type={type(sub_module).__name__}): {e}" + ) from e elif ( "Llama4TextExperts" in type(sub_module).__name__ or "GptOssExperts" in type(sub_module).__name__ diff --git a/modelopt/torch/quantization/nn/modules/quant_embedding.py b/modelopt/torch/quantization/nn/modules/quant_embedding.py index c7ab42dfb92..02ffb28fa86 100644 --- a/modelopt/torch/quantization/nn/modules/quant_embedding.py +++ b/modelopt/torch/quantization/nn/modules/quant_embedding.py @@ -124,9 +124,15 @@ def forward(self, input, *args, **kwargs): # Caught any config or call that managed to flip _disabled to False. raise RuntimeError(_INPUT_QUANTIZER_ERR) if is_torch_export_mode(): - return super().forward(input, *args, **kwargs) - with self.quantize_weight(): + # quantize_weight()'s attribute write is not allowed under torch.export; + # weight quantization is still applied inline via _get_quantized_weight's + # is_torch_export_mode() branch. Apply output_quantizer in this path too + # so users who opt into output activation quantization don't silently + # lose it during export — matches QuantInputBase.forward's behavior. output = super().forward(input, *args, **kwargs) + else: + with self.quantize_weight(): + output = super().forward(input, *args, **kwargs) return self.output_quantizer(output) diff --git a/tests/unit/torch/quantization/test_quant_embedding.py b/tests/unit/torch/quantization/test_quant_embedding.py index 2b588194506..d28a75f52b3 100644 --- a/tests/unit/torch/quantization/test_quant_embedding.py +++ b/tests/unit/torch/quantization/test_quant_embedding.py @@ -20,22 +20,28 @@ import torch.nn as nn import torch.nn.functional as F +import modelopt.torch.quantization as mtq +from modelopt.torch.export.unified_export_hf import _process_quantized_modules from modelopt.torch.quantization import tensor_quant from modelopt.torch.quantization import utils as quant_utils from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.nn.modules.quant_embedding import _UnsettableInputQuantizer +from modelopt.torch.quantization.utils import quantizer_attr_names VOCAB_SIZE = 16 -EMBED_DIM = 8 +EMBED_DIM = 32 # multiple of the NVFP4 block size (16) so export tests can pack def _make_quant_embedding(**kwargs) -> nn.Module: + """Build an nn.Embedding and convert it through QuantModuleRegistry.""" return QuantModuleRegistry.convert(nn.Embedding(VOCAB_SIZE, EMBED_DIM, **kwargs)) class TestQuantEmbedding: + """Forward-path behavior of the registered QuantEmbedding wrapper.""" + def test_default_state_and_no_quant(self): """Default state: input quant locked-disabled, output quant disabled, weight quant on; with weight quant also off the wrapper matches plain F.embedding.""" @@ -68,6 +74,7 @@ def test_weight_fake_quant(self, axis): assert torch.allclose(qemb(ids), ref, rtol=0, atol=0) def test_output_quantizer_applied_when_enabled(self): + """Enabling output_quantizer makes forward equivalent to applying it to the lookup.""" qemb = _make_quant_embedding() qemb.weight_quantizer.disable() qemb.output_quantizer.enable() @@ -80,6 +87,7 @@ def test_output_quantizer_applied_when_enabled(self): @pytest.mark.parametrize("method", ["enable", "enable_quant", "enable_calib"]) def test_input_quantizer_mutators_raise(self, method): + """Each public enable/enable_quant/enable_calib API on input_quantizer raises.""" qemb = _make_quant_embedding() with pytest.raises(RuntimeError, match="nn.Embedding"): getattr(qemb.input_quantizer, method)() @@ -102,3 +110,87 @@ def test_wildcard_config_accepted_then_opt_out(self): ) set_quantizer_attributes_partial(qemb, "*input_quantizer", {"enable": False}) qemb(torch.randint(0, VOCAB_SIZE, (4, 6))) # forward succeeds + + +def _embedding_nvfp4_cfg() -> dict: + """Stock-NVFP4-style cfg that opts the embedding's weight quantizer in.""" + nvfp4 = { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + } + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "parent_class": "nn.Embedding", + "quantizer_name": "*weight_quantizer", + "cfg": dict(nvfp4), + }, + ], + "algorithm": "max", + } + + +class _EmbeddingOnly(nn.Module): + """Single-embedding wrapper exposing forward + named_modules iteration.""" + + def __init__(self): + """Build the lone embedding submodule.""" + super().__init__() + self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM) + + def forward(self, ids): + """Look up embeddings for the given token IDs.""" + return self.embedding(ids) + + +class _TiedEmbeddingLM(nn.Module): + """Embedding + Linear lm_head with tied weights (lm_head.weight is embedding.weight).""" + + def __init__(self): + """Build embedding + lm_head and tie their weight Parameters.""" + super().__init__() + self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM) + self.lm_head = nn.Linear(EMBED_DIM, VOCAB_SIZE, bias=False) + self.lm_head.weight = self.embedding.weight # Python-level tie + + def forward(self, ids): + """Embed then project to vocab logits with the tied weight.""" + return self.lm_head(self.embedding(ids)) + + +class TestQuantEmbeddingExport: + """Export-path coverage: weight packing and tied-weight guard.""" + + def test_quantized_weight_is_packed_and_scales_registered(self): + """End-to-end: _process_quantized_modules packs the embedding weight and + registers ``weight_scale`` + ``weight_scale_2`` buffers.""" + model = _EmbeddingOnly() + model = mtq.quantize( + model, _embedding_nvfp4_cfg(), lambda m: m(torch.randint(0, VOCAB_SIZE, (2, 4))) + ) + _process_quantized_modules(model, dtype=torch.float16) + + attrs = quantizer_attr_names("weight") + assert model.embedding.weight.dtype == torch.uint8 + assert hasattr(model.embedding, attrs.weight_scale) + assert hasattr(model.embedding, attrs.weight_scale_2) + # input_scale is not registered (input_quantizer is permanently disabled). + assert not hasattr(model.embedding, attrs.input_scale) + + def test_tied_embedding_export_skips_packing(self): + """When the embedding weight is shared with lm_head, packing is skipped + with a warning so the tie survives the export.""" + model = _TiedEmbeddingLM() + assert model.lm_head.weight is model.embedding.weight # sanity + + model = mtq.quantize( + model, _embedding_nvfp4_cfg(), lambda m: m(torch.randint(0, VOCAB_SIZE, (2, 4))) + ) + orig_dtype = model.embedding.weight.dtype + with pytest.warns(UserWarning, match="tied"): + _process_quantized_modules(model, dtype=torch.float16) + + # Weight Parameter unchanged (not packed to uint8) and still tied. + assert model.embedding.weight.dtype == orig_dtype + assert model.lm_head.weight is model.embedding.weight From a932284fe68ff5123cc2b015f43953e13e327dfd Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Thu, 14 May 2026 19:26:20 +0000 Subject: [PATCH 3/3] fix(quant): make embedding input_quantizer absorb wildcard configs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous design raised in _QuantEmbedding.forward whenever input_quantizer.is_enabled, on the theory that any non-disable config was an explicit user mistake. That assumption was wrong for wildcard configs: the default QuantizeConfig is just [{"quantizer_name": "*", "cfg": {"num_bits": 8, ...}}] (no embedding opt-out), so the wildcard enables embed_tokens.input_quantizer for tiny Llama-style tests and the forward guard fires — breaking test_peft_save_load and test_transformers_save_load. Switch _UnsettableInputQuantizer.set_from_attribute_config to absorb the incoming config like a normal quantizer, then force _disabled = True at the end. The "throw on explicit set" semantics are preserved via the .enable / .enable_quant / .enable_calib overrides, which catch the direct mistakes users would actually make. The forward-time guard (and the corresponding test) are removed since the invariant is now maintained at the configure step. Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- .../nn/modules/quant_embedding.py | 43 +++++++++++++------ .../quantization/test_quant_embedding.py | 20 ++++----- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/quant_embedding.py b/modelopt/torch/quantization/nn/modules/quant_embedding.py index 02ffb28fa86..91e9f754c7d 100644 --- a/modelopt/torch/quantization/nn/modules/quant_embedding.py +++ b/modelopt/torch/quantization/nn/modules/quant_embedding.py @@ -39,12 +39,13 @@ class _UnsettableInputQuantizer(TensorQuantizer): """TensorQuantizer slot for nn.Embedding.input — present but not enable-able. Embedding inputs are integer indices that cannot be fake-quantized. The attribute - is kept so introspection code (export, calibration helpers) can find it. Wildcard - configs (e.g. ``NVFP4_DEFAULT_CFG``'s ``*input_quantizer``) are accepted silently - so that the standard "deny-all → enable wildcards → opt-out specific layers" - pattern in the stock configs still works. Direct calls to ``enable*()`` raise - immediately, and ``_QuantEmbedding.forward`` raises if the final state ends up - enabled (e.g. a user explicitly targeted this quantizer). + is kept so introspection code (export, calibration helpers) can find it. + + Wildcard configs (e.g. the default ``QuantizeConfig`` ``"*"`` rule or + ``NVFP4_DEFAULT_CFG``'s ``*input_quantizer``) are accepted silently, then the + quantizer is force-disabled — wildcards don't really mean "enable embedding + input quant", they mean "enable input quant in general". Direct, explicit + attempts (calling ``enable``/``enable_quant``/``enable_calib``) raise loudly. """ def enable(self): @@ -59,6 +60,17 @@ def enable_calib(self): """Disallowed for embedding inputs.""" raise RuntimeError(_INPUT_QUANTIZER_ERR) + def set_from_attribute_config(self, attribute_cfg): + """Apply the config like any quantizer, then force-disable us. + + This absorbs wildcard configs from stock recipes without raising. The + quantizer's other attributes (``num_bits``, ``axis``, etc.) take on the + config values for introspection, but ``_disabled`` is forced back to + ``True`` so forward is always a no-op. + """ + super().set_from_attribute_config(attribute_cfg) + self._disabled = True + @QuantModuleRegistry.register({nn.Embedding: "nn.Embedding"}) class _QuantEmbedding(QuantModule): @@ -66,14 +78,16 @@ class _QuantEmbedding(QuantModule): The literal input to ``nn.Embedding`` is integer indices, which cannot be fake-quantized. The ``input_quantizer`` attribute is kept (for symmetry with - other quant modules and for introspection by export/calibration code) but - configuring it raises — see ``_UnsettableInputQuantizer``. Only the embedding + other quant modules and for introspection by export/calibration code) but is + permanently disabled — see ``_UnsettableInputQuantizer``. Only the embedding table (weight) and the lookup output (an activation feeding downstream layers) are quantizable. Quantizer roles: - ``weight_quantizer``: quantizes the embedding table (``self.weight``). - - ``input_quantizer``: permanently disabled placeholder — raises on configure. + - ``input_quantizer``: permanently disabled placeholder — direct + ``enable*()`` calls raise; configs that target it are absorbed and the + quantizer is force-disabled. - ``output_quantizer``: optional activation quantizer for the lookup output, disabled by default. """ @@ -119,10 +133,13 @@ def _setup(self): self._register_dynamic_attribute("weight", self._get_quantized_weight) def forward(self, input, *args, **kwargs): - """Quantize the embedding table, look up, then optionally quantize the output.""" - if self.input_quantizer.is_enabled: - # Caught any config or call that managed to flip _disabled to False. - raise RuntimeError(_INPUT_QUANTIZER_ERR) + """Quantize the embedding table, look up, then optionally quantize the output. + + ``input_quantizer`` is intentionally never applied — embedding inputs are + integer indices. ``_UnsettableInputQuantizer.set_from_attribute_config`` + keeps that quantizer disabled regardless of what configs target it, so we + rely on that invariant rather than a runtime check here. + """ if is_torch_export_mode(): # quantize_weight()'s attribute write is not allowed under torch.export; # weight quantization is still applied inline via _get_quantized_weight's diff --git a/tests/unit/torch/quantization/test_quant_embedding.py b/tests/unit/torch/quantization/test_quant_embedding.py index d28a75f52b3..af3931b1892 100644 --- a/tests/unit/torch/quantization/test_quant_embedding.py +++ b/tests/unit/torch/quantization/test_quant_embedding.py @@ -92,24 +92,22 @@ def test_input_quantizer_mutators_raise(self, method): with pytest.raises(RuntimeError, match="nn.Embedding"): getattr(qemb.input_quantizer, method)() - def test_forward_raises_if_input_quantizer_enabled(self): - """Forward catches back-door flips of input_quantizer._disabled.""" - qemb = _make_quant_embedding() - qemb.input_quantizer._disabled = False - with pytest.raises(RuntimeError, match="nn.Embedding"): - qemb(torch.randint(0, VOCAB_SIZE, (4, 6))) + def test_wildcard_config_keeps_input_quantizer_disabled(self): + """set_from_attribute_config absorbs any cfg but force-disables input_quantizer. - def test_wildcard_config_accepted_then_opt_out(self): - """Wildcard cfg on ``*input_quantizer`` must not raise — stock NVFP4_DEFAULT_CFG relies on it. - A follow-up ``enable: false`` rule restores the disabled state.""" + Stock recipes' ``*input_quantizer`` wildcard (and the default ``QuantizeConfig`` + ``"*"`` rule) target every quantizer including the embedding's input slot. + The quantizer must end up disabled regardless of what the cfg said. + """ qemb = _make_quant_embedding() set_quantizer_attributes_partial( qemb, "*input_quantizer", QuantizerAttributeConfig(num_bits=8, axis=None).model_dump(), ) - set_quantizer_attributes_partial(qemb, "*input_quantizer", {"enable": False}) - qemb(torch.randint(0, VOCAB_SIZE, (4, 6))) # forward succeeds + assert not qemb.input_quantizer.is_enabled + # Forward still works — input_quantizer is disabled and never applied. + qemb(torch.randint(0, VOCAB_SIZE, (4, 6))) def _embedding_nvfp4_cfg() -> dict: